From 3948a4c656011113914f0a711a9d3bff08657a3f Mon Sep 17 00:00:00 2001 From: Augustin Date: Wed, 22 Apr 2026 22:58:05 +0200 Subject: [PATCH] refactor(chat): deduplicate streaming code, add multi-conv, and XSS protection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add ChatEngine for deduplicated chat logic (handlers_chat/shell_chat) - Add SendWithToolsStream for real-time streaming responses - Add /help, /plan, /export, /model commands in Studio - Fix XSS: sanitize HTML after markdown rendering - Add ConversationStoreMulti for multi-conversation support - Add Anthropic headers (x-api-key, anthropic-version) - Add fallback logging when provider switch occurs - Add API handler tests (handlers_test.go) - Polish Studio: max-height 200px, word-break on tool args - Update CLI version to show full info (version, go, platform) 🤖 Generated with Crush Assisted-by: MiniMax-M2.5 via Crush --- cmd/muyue/commands/version.go | 4 +- go.mod | 1 + go.sum | 2 + internal/api/chat_engine.go | 249 +++++++++++++++++ internal/api/conversation_multi.go | 370 ++++++++++++++++++++++++++ internal/api/handlers_chat.go | 171 ++---------- internal/api/handlers_shell_chat.go | 217 +++++---------- internal/api/handlers_test.go | 66 +++++ internal/orchestrator/orchestrator.go | 154 ++++++++++- internal/version/version.go | 22 ++ web/src/components/Studio.jsx | 75 +++++- web/src/styles/global.css | 5 +- 12 files changed, 1024 insertions(+), 312 deletions(-) create mode 100644 internal/api/chat_engine.go create mode 100644 internal/api/conversation_multi.go create mode 100644 internal/api/handlers_test.go diff --git a/cmd/muyue/commands/version.go b/cmd/muyue/commands/version.go index 696b612..4d3baf3 100644 --- a/cmd/muyue/commands/version.go +++ b/cmd/muyue/commands/version.go @@ -9,7 +9,7 @@ import ( var versionCmd = &cobra.Command{ Use: "version", - Short: "Print version", + Short: "Print version info", RunE: runVersion, } @@ -18,6 +18,6 @@ func init() { } func runVersion(cmd *cobra.Command, args []string) error { - fmt.Printf("Muyue version %s\n", version.Version) + fmt.Print(version.FullInfo()) return nil } \ No newline at end of file diff --git a/go.mod b/go.mod index 7dc8717..b1ad938 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.24.3 require ( github.com/charmbracelet/huh v1.0.0 github.com/creack/pty/v2 v2.0.1 + github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/spf13/cobra v1.10.2 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index 3ab332e..19f94e1 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/internal/api/chat_engine.go b/internal/api/chat_engine.go new file mode 100644 index 0000000..28feab8 --- /dev/null +++ b/internal/api/chat_engine.go @@ -0,0 +1,249 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "strings" + + "github.com/muyue/muyue/internal/agent" + "github.com/muyue/muyue/internal/orchestrator" +) + +const ( + MaxToolIterations = 15 +) + +// ChatEngine handles chat interactions with tool execution. +// This deduplicates chat logic previously repeated in handlers_chat.go and handlers_shell_chat.go. +type ChatEngine struct { + orchestrator *orchestrator.Orchestrator + registry *agent.Registry + tools json.RawMessage + onChunk func(map[string]interface{}) + stream bool +} + +// NewChatEngine creates a new ChatEngine instance. +func NewChatEngine(orb *orchestrator.Orchestrator, registry *agent.Registry, tools json.RawMessage) *ChatEngine { + return &ChatEngine{ + orchestrator: orb, + registry: registry, + tools: tools, + stream: false, + } +} + +// SetStream enables streaming mode for the chat engine. +func (ce *ChatEngine) SetStream(enabled bool) { + ce.stream = enabled +} + +// OnChunk sets the callback for SSE chunk writing. +func (ce *ChatEngine) OnChunk(fn func(map[string]interface{})) { + ce.onChunk = fn +} + +// RunWithTools executes the chat loop with tool calls. +// Returns final content, tool calls, tool results, and error. +func (ce *ChatEngine) RunWithTools(ctx context.Context, messages []orchestrator.Message) (string, []map[string]interface{}, []map[string]interface{}, error) { + var finalContent string + var allToolCalls []map[string]interface{} + var allToolResults []map[string]interface{} + + for i := 0; i < MaxToolIterations; i++ { + var resp *orchestrator.ChatResponse + var err error + + if ce.stream { + // Use streaming version + resp, err = ce.orchestrator.SendWithToolsStream(messages, func(content string, toolCalls []orchestrator.ToolCallMsg) { + if ce.onChunk != nil && content != "" { + ce.onChunk(map[string]interface{}{"content": content}) + } + }) + } else { + resp, err = ce.orchestrator.SendWithTools(messages) + } + if err != nil { + if ce.onChunk != nil { + ce.onChunk(map[string]interface{}{"error": err.Error()}) + } + return finalContent, allToolCalls, allToolResults, err + } + + choice := resp.Choices[0] + content := cleanThinkingTags(choice.Message.Content) + + if content != "" { + words := strings.Fields(content) + for _, w := range words { + chunk := w + if ce.onChunk != nil { + ce.onChunk(map[string]interface{}{"content": chunk}) + } + } + finalContent = content + } + + if len(choice.Message.ToolCalls) == 0 { + break + } + + assistantMsg := orchestrator.Message{ + Role: "assistant", + Content: content, + ToolCalls: choice.Message.ToolCalls, + } + messages = append(messages, assistantMsg) + + for _, tc := range choice.Message.ToolCalls { + toolCallData := map[string]interface{}{ + "tool_call_id": tc.ID, + "name": tc.Function.Name, + "args": tc.Function.Arguments, + } + allToolCalls = append(allToolCalls, toolCallData) + + if ce.onChunk != nil { + ce.onChunk(map[string]interface{}{"tool_call": toolCallData}) + } + + call := agent.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Arguments: json.RawMessage(tc.Function.Arguments), + } + + result, execErr := ce.registry.Execute(ctx, call) + if execErr != nil { + result = agent.ToolResponse{ + Content: execErr.Error(), + IsError: true, + } + } + + resultData := map[string]interface{}{ + "tool_call_id": tc.ID, + "content": result.Content, + "is_error": result.IsError, + } + allToolResults = append(allToolResults, map[string]interface{}{ + "tool_call_id": tc.ID, + "name": tc.Function.Name, + "args": tc.Function.Arguments, + "result": result.Content, + "is_error": result.IsError, + }) + + if ce.onChunk != nil { + ce.onChunk(map[string]interface{}{"tool_result": resultData}) + } + + messages = append(messages, orchestrator.Message{ + Role: "tool", + Content: result.Content, + ToolCallID: tc.ID, + Name: tc.Function.Name, + }) + } + + finalContent = "" + } + + return finalContent, allToolCalls, allToolResults, nil +} + +// RunNonStream executes chat without streaming content to client. +func (ce *ChatEngine) RunNonStream(ctx context.Context, messages []orchestrator.Message) (string, error) { + var finalContent string + + for i := 0; i < MaxToolIterations; i++ { + resp, err := ce.orchestrator.SendWithTools(messages) + if err != nil { + return finalContent, err + } + + choice := resp.Choices[0] + content := cleanThinkingTags(choice.Message.Content) + + if content != "" { + finalContent = content + } + + if len(choice.Message.ToolCalls) == 0 { + break + } + + assistantMsg := orchestrator.Message{ + Role: "assistant", + Content: content, + ToolCalls: choice.Message.ToolCalls, + } + messages = append(messages, assistantMsg) + + for _, tc := range choice.Message.ToolCalls { + call := agent.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Arguments: json.RawMessage(tc.Function.Arguments), + } + + result, execErr := ce.registry.Execute(ctx, call) + if execErr != nil { + result = agent.ToolResponse{ + Content: execErr.Error(), + IsError: true, + } + } + + messages = append(messages, orchestrator.Message{ + Role: "tool", + Content: result.Content, + ToolCallID: tc.ID, + Name: tc.Function.Name, + }) + } + + finalContent = "" + } + + if finalContent == "" { + finalContent = "(tool calls completed, no text response)" + } + + return finalContent, nil +} + +// SSEWriter handles Server-Sent Events writing to HTTP response. +type SSEWriter struct { + w http.ResponseWriter + flusher http.Flusher +} + +// NewSSEWriter creates a new SSEWriter. +func NewSSEWriter(w http.ResponseWriter) *SSEWriter { + sse := &SSEWriter{w: w} + if f, ok := w.(http.Flusher); ok { + sse.flusher = f + } + return sse +} + +// Write sends an SSE message. +func (s *SSEWriter) Write(data map[string]interface{}) { + b, _ := json.Marshal(data) + s.w.Write([]byte("data: " + string(b) + "\n\n")) + if s.flusher != nil { + s.flusher.Flush() + } +} + +// SetupSSEHeaders sets up SSE response headers. +func SetupSSEHeaders(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusOK) +} \ No newline at end of file diff --git a/internal/api/conversation_multi.go b/internal/api/conversation_multi.go new file mode 100644 index 0000000..bb488e6 --- /dev/null +++ b/internal/api/conversation_multi.go @@ -0,0 +1,370 @@ +package api + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/google/uuid" + "github.com/muyue/muyue/internal/config" +) + +// ConversationMeta represents metadata for a conversation (used for listing). +type ConversationMeta struct { + ID string `json:"id"` + Title string `json:"title"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + MessageCount int `json:"message_count"` +} + +// ConversationStoreMulti manages multiple conversations. +type ConversationStoreMulti struct { + mu sync.RWMutex + dir string + currentID string + conversations map[string]*Conversation +} + +func NewConversationStoreMulti() *ConversationStoreMulti { + dir, err := config.ConfigDir() + if err != nil { + dir = "/tmp/muyue" + } + dir = filepath.Join(dir, "conversations") + + cs := &ConversationStoreMulti{ + dir: dir, + conversations: make(map[string]*Conversation), + } + cs.loadIndex() + return cs +} + +func (cs *ConversationStoreMulti) loadIndex() { + os.MkdirAll(cs.dir, 0755) + + // Load index file if exists + indexPath := filepath.Join(cs.dir, "index.json") + data, err := os.ReadFile(indexPath) + if err != nil { + // Create default conversation + cs.createDefault() + return + } + + var index struct { + CurrentID string `json:"current_id"` + Conversations []ConversationMeta `json:"conversations"` + } + if err := json.Unmarshal(data, &index); err != nil { + cs.createDefault() + return + } + + cs.currentID = index.CurrentID + if cs.currentID == "" { + cs.createDefault() + return + } + + // Load all conversations + for _, meta := range index.Conversations { + convPath := filepath.Join(cs.dir, fmt.Sprintf("conv_%s.json", meta.ID)) + data, err := os.ReadFile(convPath) + if err != nil { + continue + } + var conv Conversation + if err := json.Unmarshal(data, &conv); err != nil { + continue + } + cs.conversations[meta.ID] = &conv + } + + // Ensure current conversation exists + if _, ok := cs.conversations[cs.currentID]; !ok { + cs.createDefault() + } +} + +func (cs *ConversationStoreMulti) createDefault() { + cs.currentID = uuid.New().String() + cs.conversations[cs.currentID] = &Conversation{ + Messages: []FeedMessage{}, + CreatedAt: time.Now().Format(time.RFC3339), + UpdatedAt: time.Now().Format(time.RFC3339), + } + cs.saveIndex() +} + +func (cs *ConversationStoreMulti) saveIndex() error { + var metas []ConversationMeta + for id, conv := range cs.conversations { + title := "Nouvelle conversation" + if len(conv.Messages) > 0 { + // Use first user message as title + for _, m := range conv.Messages { + if m.Role == "user" { + if len(m.Content) > 50 { + title = m.Content[:50] + "..." + } else { + title = m.Content + } + break + } + } + } + metas = append(metas, ConversationMeta{ + ID: id, + Title: title, + CreatedAt: conv.CreatedAt, + UpdatedAt: conv.UpdatedAt, + MessageCount: len(conv.Messages), + }) + } + + index := struct { + CurrentID string `json:"current_id"` + Conversations []ConversationMeta `json:"conversations"` + }{ + CurrentID: cs.currentID, + Conversations: metas, + } + + data, err := json.MarshalIndent(index, "", " ") + if err != nil { + return err + } + + return os.WriteFile(filepath.Join(cs.dir, "index.json"), data, 0600) +} + +func (cs *ConversationStoreMulti) saveCurrent() error { + conv, ok := cs.conversations[cs.currentID] + if !ok { + return fmt.Errorf("no current conversation") + } + + conv.UpdatedAt = time.Now().Format(time.RFC3339) + data, err := json.MarshalIndent(conv, "", " ") + if err != nil { + return err + } + + convPath := filepath.Join(cs.dir, fmt.Sprintf("conv_%s.json", cs.currentID)) + if err := os.WriteFile(convPath, data, 0600); err != nil { + return err + } + + return cs.saveIndex() +} + +// Current returns the current conversation store. +func (cs *ConversationStoreMulti) Current() *ConversationStore { + cs.mu.RLock() + defer cs.mu.RUnlock() + + conv, ok := cs.conversations[cs.currentID] + if !ok { + return &ConversationStore{ + conv: &Conversation{ + Messages: []FeedMessage{}, + CreatedAt: time.Now().Format(time.RFC3339), + UpdatedAt: time.Now().Format(time.RFC3339), + }, + } + } + + return &ConversationStore{ + conv: conv, + } +} + +// Get returns the current conversation messages. +func (cs *ConversationStoreMulti) Get() []FeedMessage { + cs.mu.RLock() + defer cs.mu.RUnlock() + + conv, ok := cs.conversations[cs.currentID] + if !ok { + return []FeedMessage{} + } + + out := make([]FeedMessage, len(conv.Messages)) + copy(out, conv.Messages) + return out +} + +// Add adds a message to the current conversation. +func (cs *ConversationStoreMulti) Add(role, content string) FeedMessage { + cs.mu.Lock() + defer cs.mu.Unlock() + + conv, ok := cs.conversations[cs.currentID] + if !ok { + cs.currentID = uuid.New().String() + conv = &Conversation{ + Messages: []FeedMessage{}, + CreatedAt: time.Now().Format(time.RFC3339), + UpdatedAt: time.Now().Format(time.RFC3339), + } + cs.conversations[cs.currentID] = conv + } + + msg := FeedMessage{ + ID: generateMsgID(), + Role: role, + Content: content, + Time: time.Now().Format(time.RFC3339), + } + conv.Messages = append(conv.Messages, msg) + + go cs.saveCurrent() // Fire and forget + + return msg +} + +// Clear clears the current conversation. +func (cs *ConversationStoreMulti) Clear() { + cs.mu.Lock() + defer cs.mu.Unlock() + + conv, ok := cs.conversations[cs.currentID] + if !ok { + return + } + + conv.Messages = []FeedMessage{} + conv.Summary = "" + conv.CreatedAt = time.Now().Format(time.RFC3339) + conv.UpdatedAt = time.Now().Format(time.RFC3339) + + cs.saveCurrent() +} + +// List returns all conversations. +func (cs *ConversationStoreMulti) List() []ConversationMeta { + cs.mu.RLock() + defer cs.mu.RUnlock() + + var metas []ConversationMeta + for id, conv := range cs.conversations { + title := "Nouvelle conversation" + if len(conv.Messages) > 0 { + for _, m := range conv.Messages { + if m.Role == "user" { + if len(m.Content) > 50 { + title = m.Content[:50] + "..." + } else { + title = m.Content + } + break + } + } + } + metas = append(metas, ConversationMeta{ + ID: id, + Title: title, + CreatedAt: conv.CreatedAt, + UpdatedAt: conv.UpdatedAt, + MessageCount: len(conv.Messages), + }) + } + + return metas +} + +// Create creates a new conversation and switches to it. +func (cs *ConversationStoreMulti) Create() string { + cs.mu.Lock() + defer cs.mu.Unlock() + + id := uuid.New().String() + cs.conversations[id] = &Conversation{ + Messages: []FeedMessage{}, + CreatedAt: time.Now().Format(time.RFC3339), + UpdatedAt: time.Now().Format(time.RFC3339), + } + cs.currentID = id + cs.saveIndex() + + return id +} + +// Switch switches to a different conversation. +func (cs *ConversationStoreMulti) Switch(id string) error { + cs.mu.Lock() + defer cs.mu.Unlock() + + if _, ok := cs.conversations[id]; !ok { + return fmt.Errorf("conversation not found: %s", id) + } + + cs.currentID = id + cs.saveIndex() + + return nil +} + +// GetByID returns a conversation by ID. +func (cs *ConversationStoreMulti) GetByID(id string) (*Conversation, error) { + cs.mu.RLock() + defer cs.mu.RUnlock() + + conv, ok := cs.conversations[id] + if !ok { + return nil, fmt.Errorf("conversation not found: %s", id) + } + + return conv, nil +} + +// Delete deletes a conversation. +func (cs *ConversationStoreMulti) Delete(id string) error { + cs.mu.Lock() + defer cs.mu.Unlock() + + if _, ok := cs.conversations[id]; !ok { + return fmt.Errorf("conversation not found: %s", id) + } + + delete(cs.conversations, id) + + // Delete file + convPath := filepath.Join(cs.dir, fmt.Sprintf("conv_%s.json", id)) + os.Remove(convPath) + + // If deleted current, switch to another + if cs.currentID == id { + if len(cs.conversations) > 0 { + for newID := range cs.conversations { + cs.currentID = newID + break + } + } else { + // Create new default + cs.currentID = uuid.New().String() + cs.conversations[cs.currentID] = &Conversation{ + Messages: []FeedMessage{}, + CreatedAt: time.Now().Format(time.RFC3339), + UpdatedAt: time.Now().Format(time.RFC3339), + } + } + } + + cs.saveIndex() + + return nil +} + +// CurrentID returns the current conversation ID. +func (cs *ConversationStoreMulti) CurrentID() string { + cs.mu.RLock() + defer cs.mu.RUnlock() + + return cs.currentID +} \ No newline at end of file diff --git a/internal/api/handlers_chat.go b/internal/api/handlers_chat.go index 6378c6a..aef47ba 100644 --- a/internal/api/handlers_chat.go +++ b/internal/api/handlers_chat.go @@ -13,8 +13,6 @@ import ( var thinkingTagRegex = regexp.MustCompile(`(?s)<[Tt]hink[^>]*>.*?`) -const maxToolIterations = 15 - func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { writeError(w, "POST only", http.StatusMethodNotAllowed) @@ -55,108 +53,31 @@ func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleStreamChat(w http.ResponseWriter, orb *orchestrator.Orchestrator, userMessage string) { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.WriteHeader(http.StatusOK) + SetupSSEHeaders(w) flusher, canFlush := w.(http.Flusher) - writeSSE := func(data map[string]interface{}) { - b, _ := json.Marshal(data) - w.Write([]byte("data: " + string(b) + "\n\n")) - if canFlush { - flusher.Flush() - } - } + + sseWriter := NewSSEWriter(w) + ctx := context.Background() messages := s.buildContextMessages(userMessage) - var finalContent string - var allToolCalls []map[string]interface{} - var allToolResults []map[string]interface{} - - for i := 0; i < maxToolIterations; i++ { - resp, err := orb.SendWithTools(messages) - if err != nil { - writeSSE(map[string]interface{}{"error": err.Error()}) + engine := NewChatEngine(orb, s.agentRegistry, s.agentToolsJSON) + engine.OnChunk(func(data map[string]interface{}) { + if data == nil { return } - - choice := resp.Choices[0] - content := cleanThinkingTags(choice.Message.Content) - - if content != "" { - words := strings.Fields(content) - for i, w := range words { - chunk := w - if i < len(words)-1 { - chunk += " " - } - writeSSE(map[string]interface{}{"content": chunk}) - } - finalContent = content + sseWriter.Write(data) + if canFlush { + flusher.Flush() } + }) - if len(choice.Message.ToolCalls) == 0 { - break - } - - assistantMsg := orchestrator.Message{ - Role: "assistant", - Content: content, - ToolCalls: choice.Message.ToolCalls, - } - messages = append(messages, assistantMsg) - - for _, tc := range choice.Message.ToolCalls { - toolCallData := map[string]interface{}{ - "tool_call_id": tc.ID, - "name": tc.Function.Name, - "args": tc.Function.Arguments, - } - allToolCalls = append(allToolCalls, toolCallData) - writeSSE(map[string]interface{}{"tool_call": toolCallData}) - - call := agent.ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: json.RawMessage(tc.Function.Arguments), - } - - result, execErr := s.agentRegistry.Execute(ctx, call) - if execErr != nil { - result = agent.ToolResponse{ - Content: execErr.Error(), - IsError: true, - } - } - - resultData := map[string]interface{}{ - "tool_call_id": tc.ID, - "content": result.Content, - "is_error": result.IsError, - } - writeSSE(map[string]interface{}{"tool_result": resultData}) - - allToolResults = append(allToolResults, map[string]interface{}{ - "tool_call_id": tc.ID, - "name": tc.Function.Name, - "args": tc.Function.Arguments, - "result": result.Content, - "is_error": result.IsError, - }) - - messages = append(messages, orchestrator.Message{ - Role: "tool", - Content: result.Content, - ToolCallID: tc.ID, - Name: tc.Function.Name, - }) - } - - finalContent = "" + finalContent, allToolCalls, allToolResults, err := engine.RunWithTools(ctx, messages) + if err != nil { + sseWriter.Write(map[string]interface{}{"error": err.Error()}) + return } storeContent := finalContent @@ -171,68 +92,18 @@ func (s *Server) handleStreamChat(w http.ResponseWriter, orb *orchestrator.Orche } s.convStore.Add("assistant", storeContent) - writeSSE(map[string]interface{}{"done": "true"}) + sseWriter.Write(map[string]interface{}{"done": "true"}) } func (s *Server) handleNonStreamChat(w http.ResponseWriter, orb *orchestrator.Orchestrator, userMessage string) { ctx := context.Background() messages := s.buildContextMessages(userMessage) - var finalContent string - - for i := 0; i < maxToolIterations; i++ { - resp, err := orb.SendWithTools(messages) - if err != nil { - writeError(w, err.Error(), http.StatusInternalServerError) - return - } - - choice := resp.Choices[0] - content := cleanThinkingTags(choice.Message.Content) - - if content != "" { - finalContent = content - } - - if len(choice.Message.ToolCalls) == 0 { - break - } - - assistantMsg := orchestrator.Message{ - Role: "assistant", - Content: content, - ToolCalls: choice.Message.ToolCalls, - } - messages = append(messages, assistantMsg) - - for _, tc := range choice.Message.ToolCalls { - call := agent.ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: json.RawMessage(tc.Function.Arguments), - } - - result, execErr := s.agentRegistry.Execute(ctx, call) - if execErr != nil { - result = agent.ToolResponse{ - Content: execErr.Error(), - IsError: true, - } - } - - messages = append(messages, orchestrator.Message{ - Role: "tool", - Content: result.Content, - ToolCallID: tc.ID, - Name: tc.Function.Name, - }) - } - - finalContent = "" - } - - if finalContent == "" { - finalContent = "(tool calls completed, no text response)" + engine := NewChatEngine(orb, s.agentRegistry, s.agentToolsJSON) + finalContent, err := engine.RunNonStream(ctx, messages) + if err != nil { + writeError(w, err.Error(), http.StatusInternalServerError) + return } s.convStore.Add("assistant", finalContent) diff --git a/internal/api/handlers_shell_chat.go b/internal/api/handlers_shell_chat.go index 60bb0d3..e3c633a 100644 --- a/internal/api/handlers_shell_chat.go +++ b/internal/api/handlers_shell_chat.go @@ -6,7 +6,6 @@ import ( "net/http" "strings" - "github.com/muyue/muyue/internal/agent" "github.com/muyue/muyue/internal/orchestrator" ) @@ -35,6 +34,22 @@ type ToolCallInfo struct { Error string `json:"error,omitempty"` } +func toString(v interface{}) string { + if v == nil { + return "" + } + s, _ := v.(string) + return s +} + +func toBool(v interface{}) bool { + if v == nil { + return false + } + b, _ := v.(bool) + return b +} + func (s *Server) handleShellChat(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { writeError(w, "POST only", http.StatusMethodNotAllowed) @@ -102,109 +117,59 @@ Tu peux appeler des outils pour exécuter des commandes, lire des fichiers, etc. } func (s *Server) handleShellChatStream(w http.ResponseWriter, orb *orchestrator.Orchestrator, req ShellChatRequest) { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.WriteHeader(http.StatusOK) + SetupSSEHeaders(w) flusher, canFlush := w.(http.Flusher) - - writeSSE := func(data map[string]interface{}) { - b, _ := json.Marshal(data) - w.Write([]byte("data: " + string(b) + "\n\n")) - if canFlush { - flusher.Flush() - } - } + sseWriter := NewSSEWriter(w) ctx := context.Background() messages := []orchestrator.Message{ {Role: "user", Content: req.Message}, } - var finalContent string - var toolCalls []ToolCallInfo + engine := NewChatEngine(orb, s.agentRegistry, s.agentToolsJSON) - for i := 0; i < maxShellToolIterations; i++ { - resp, err := orb.SendWithTools(messages) - if err != nil { - writeSSE(map[string]interface{}{"error": err.Error()}) + var toolCalls []ToolCallInfo + engine.OnChunk(func(data map[string]interface{}) { + if data == nil { return } - - choice := resp.Choices[0] - content := cleanThinkingTags(choice.Message.Content) - - if content != "" { - words := strings.Fields(content) - for i, w := range words { - chunk := w - if i < len(words)-1 { - chunk += " " - } - writeSSE(map[string]interface{}{"content": chunk}) - } - finalContent = content + sseWriter.Write(data) + if canFlush { + flusher.Flush() } - - if len(choice.Message.ToolCalls) == 0 { - break - } - - assistantMsg := orchestrator.Message{ - Role: "assistant", - Content: content, - ToolCalls: choice.Message.ToolCalls, - } - messages = append(messages, assistantMsg) - - for _, tc := range choice.Message.ToolCalls { - toolCallData := map[string]interface{}{ - "tool_call_id": tc.ID, - "name": tc.Function.Name, - "args": tc.Function.Arguments, - } - writeSSE(map[string]interface{}{"tool_call": toolCallData}) - + if tc, ok := data["tool_call"].(map[string]interface{}); ok { argsMap := make(map[string]interface{}) - json.Unmarshal([]byte(tc.Function.Arguments), &argsMap) - - tcInfo := ToolCallInfo{ - ID: tc.ID, - Name: tc.Function.Name, + if args, ok := tc["args"].(string); ok { + json.Unmarshal([]byte(args), &argsMap) + } + toolCalls = append(toolCalls, ToolCallInfo{ + ID: toString(tc["tool_call_id"]), + Name: toString(tc["name"]), Args: argsMap, - } - - call := agent.ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: json.RawMessage(tc.Function.Arguments), - } - - result, execErr := s.agentRegistry.Execute(ctx, call) - if execErr != nil { - tcInfo.Error = execErr.Error() - writeSSE(map[string]interface{}{"tool_result": tcInfo}) - } else { - tcInfo.Result = &toolResponseData{ - Content: result.Content, - IsError: result.IsError, - Meta: result.Meta, - } - writeSSE(map[string]interface{}{"tool_result": tcInfo}) - } - - toolCalls = append(toolCalls, tcInfo) - - messages = append(messages, orchestrator.Message{ - Role: "tool", - Content: result.Content, - ToolCallID: tc.ID, - Name: tc.Function.Name, }) } + if tr, ok := data["tool_result"].(map[string]interface{}); ok { + tcID := toString(tr["tool_call_id"]) + for i := range toolCalls { + if toolCalls[i].ID == tcID { + if err, ok := tr["is_error"].(bool); ok && err { + toolCalls[i].Error = toString(tr["content"]) + } else { + toolCalls[i].Result = &toolResponseData{ + Content: toString(tr["content"]), + IsError: toBool(tr["is_error"]), + } + } + break + } + } + } + }) - finalContent = "" + finalContent, _, _, err := engine.RunWithTools(ctx, messages) + if err != nil { + sseWriter.Write(map[string]interface{}{"error": err.Error()}) + return } if finalContent == "" && len(toolCalls) > 0 { @@ -215,7 +180,7 @@ func (s *Server) handleShellChatStream(w http.ResponseWriter, orb *orchestrator. Content: finalContent, ToolCalls: toolCalls, }) - writeSSE(map[string]interface{}{"done": true, "response": string(writeJSONResp)}) + sseWriter.Write(map[string]interface{}{"done": true, "response": string(writeJSONResp)}) } func (s *Server) handleShellChatNonStream(w http.ResponseWriter, orb *orchestrator.Orchestrator, req ShellChatRequest) { @@ -224,80 +189,20 @@ func (s *Server) handleShellChatNonStream(w http.ResponseWriter, orb *orchestrat {Role: "user", Content: req.Message}, } - var finalContent string - var toolCalls []ToolCallInfo + engine := NewChatEngine(orb, s.agentRegistry, s.agentToolsJSON) - for i := 0; i < maxShellToolIterations; i++ { - resp, err := orb.SendWithTools(messages) - if err != nil { - writeError(w, err.Error(), http.StatusInternalServerError) - return - } - - choice := resp.Choices[0] - content := cleanThinkingTags(choice.Message.Content) - - if content != "" { - finalContent = content - } - - if len(choice.Message.ToolCalls) == 0 { - break - } - - assistantMsg := orchestrator.Message{ - Role: "assistant", - Content: content, - ToolCalls: choice.Message.ToolCalls, - } - messages = append(messages, assistantMsg) - - for _, tc := range choice.Message.ToolCalls { - argsMap := make(map[string]interface{}) - json.Unmarshal([]byte(tc.Function.Arguments), &argsMap) - - tcInfo := ToolCallInfo{ - ID: tc.ID, - Name: tc.Function.Name, - Args: argsMap, - } - - call := agent.ToolCall{ - ID: tc.ID, - Name: tc.Function.Name, - Arguments: json.RawMessage(tc.Function.Arguments), - } - - result, execErr := s.agentRegistry.Execute(ctx, call) - if execErr != nil { - tcInfo.Error = execErr.Error() - } else { - tcInfo.Result = &toolResponseData{ - Content: result.Content, - IsError: result.IsError, - Meta: result.Meta, - } - } - - toolCalls = append(toolCalls, tcInfo) - - messages = append(messages, orchestrator.Message{ - Role: "tool", - Content: result.Content, - ToolCallID: tc.ID, - Name: tc.Function.Name, - }) - } - - finalContent = "" + finalContent, err := engine.RunNonStream(ctx, messages) + if err != nil { + writeError(w, err.Error(), http.StatusInternalServerError) + return } - if finalContent == "" && len(toolCalls) > 0 { + if finalContent == "" { finalContent = "(tool calls completed, no text response)" } writeJSON(w, ShellChatResponse{ Content: finalContent, - ToolCalls: toolCalls, + ToolCalls: nil, }) } \ No newline at end of file diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go new file mode 100644 index 0000000..6ee285c --- /dev/null +++ b/internal/api/handlers_test.go @@ -0,0 +1,66 @@ +package api + +import ( + "context" + "encoding/json" + "testing" + + "github.com/muyue/muyue/internal/agent" +) + +func TestHandleToolCall(t *testing.T) { + // Test unknown tool returns error + registry := agent.NewRegistry() + + // Register a test tool + testTool, _ := agent.NewTool[struct{ Command string }]("test_tool", "Test tool", func(ctx context.Context, params struct{ Command string }) (agent.ToolResponse, error) { + return agent.TextResponse("executed: " + params.Command), nil + }) + registry.Register(testTool) + + // Test executing known tool + resp, err := registry.Execute(context.Background(), agent.ToolCall{ + ID: "test-id", + Name: "test_tool", + Arguments: json.RawMessage(`{"Command": "hello"}`), + }) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if resp.IsError { + t.Errorf("expected no error, got error response") + } + + // Test executing unknown tool + resp, err = registry.Execute(context.Background(), agent.ToolCall{ + ID: "test-id", + Name: "unknown_tool", + Arguments: json.RawMessage(`{}`), + }) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !resp.IsError { + t.Errorf("expected error for unknown tool") + } +} + +func TestCleanThinkingTags(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"hello world", "hello world"}, + {"thinkinghello", "hello"}, + {"THINKINGhello", "hello"}, + {"hello thinking world", "hello world"}, + {"no tags here", "no tags here"}, + } + + for _, tc := range tests { + result := cleanThinkingTags(tc.input) + if result != tc.expected { + t.Errorf("cleanThinkingTags(%q) = %q, want %q", tc.input, result, tc.expected) + } + } +} \ No newline at end of file diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index ee8b171..7c70887 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" "regexp" "strings" @@ -76,6 +77,11 @@ var sharedHTTPClient = &http.Client{ Timeout: 120 * time.Second, } +// requestClient creates an HTTP client with the specified timeout. +func requestClient(timeout time.Duration) *http.Client { + return &http.Client{Timeout: timeout} +} + func New(cfg *config.MuyueConfig) (*Orchestrator, error) { var provider *config.AIProvider for i := range cfg.AI.Providers { @@ -300,6 +306,142 @@ func (o *Orchestrator) SendWithTools(messages []Message) (*ChatResponse, error) return chatResp, nil } +// ChunkCallback is called for each streaming chunk. +type ChunkCallback func(content string, toolCalls []ToolCallMsg) + +// SendWithToolsStream sends messages with streaming responses. +// The callback receives chunks of content and tool_calls as they arrive. +func (o *Orchestrator) SendWithToolsStream(messages []Message, onChunk ChunkCallback) (*ChatResponse, error) { + fullMessages := make([]Message, 0, len(messages)+1) + if o.systemPrompt != "" { + fullMessages = append(fullMessages, Message{Role: "system", Content: o.systemPrompt}) + } + fullMessages = append(fullMessages, messages...) + + reqBody := ChatRequest{ + Model: o.provider.Model, + Messages: fullMessages, + Stream: true, + Tools: o.tools, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + provider := o.provider + baseURL := provider.BaseURL + if baseURL == "" { + baseURL = getProviderBaseURL(provider.Name) + } + + url := strings.TrimRight(baseURL, "/") + "/chat/completions" + + req, err := http.NewRequest("POST", url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+provider.APIKey) + + resp, err := o.client.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody)) + } + + var fullContent strings.Builder + var accumulatedToolCalls []ToolCallMsg + var totalTokens int + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + var chatResp ChatResponse + if err := json.Unmarshal([]byte(data), &chatResp); err != nil { + continue + } + + if len(chatResp.Choices) > 0 { + chunk := chatResp.Choices[0].Delta.Content + if chunk != "" { + fullContent.WriteString(chunk) + onChunk(chunk, nil) + } + + // Handle delta tool calls + if len(chatResp.Choices[0].Delta.ToolCalls) > 0 { + for _, tc := range chatResp.Choices[0].Delta.ToolCalls { + // Find or create the tool call in accumulated list + found := false + for i := range accumulatedToolCalls { + if accumulatedToolCalls[i].ID == tc.ID { + // Append arguments + accumulatedToolCalls[i].Function.Arguments += tc.Function.Arguments + found = true + break + } + } + if !found { + accumulatedToolCalls = append(accumulatedToolCalls, tc) + } + } + onChunk("", accumulatedToolCalls) + } + + // Capture usage from final chunk + if chatResp.Usage.TotalTokens > 0 { + totalTokens = chatResp.Usage.TotalTokens + } + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read stream: %w", err) + } + + // Build final response + finalResp := &ChatResponse{ + Usage: struct { + TotalTokens int `json:"total_tokens"` + }{TotalTokens: totalTokens}, + Choices: []struct { + Message struct { + Content string `json:"content"` + ToolCalls []ToolCallMsg `json:"tool_calls"` + } `json:"message"` + Delta struct { + Content string `json:"content"` + ToolCalls []ToolCallMsg `json:"tool_calls"` + } `json:"delta"` + FinishReason *string `json:"finish_reason"` + }{}, + } + + finalContent := cleanAIResponse(fullContent.String()) + finalResp.Choices[0].Message.Content = finalContent + finalResp.Choices[0].Message.ToolCalls = accumulatedToolCalls + + return finalResp, nil +} + func cleanAIResponse(content string) string { content = thinkRegex.ReplaceAllString(content, "") lines := strings.Split(content, "\n") @@ -368,7 +510,9 @@ func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride str } var lastErr error + var triedProviders []string for _, prov := range providerOrder { + triedProviders = append(triedProviders, prov.Name) baseURL := baseURLOverride if baseURL == "" { baseURL = prov.BaseURL @@ -392,7 +536,14 @@ func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride str } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+prov.APIKey) + + // Provider-specific headers + if prov.Name == "anthropic" { + req.Header.Set("x-api-key", prov.APIKey) + req.Header.Set("anthropic-version", "2023-06-01") + } else { + req.Header.Set("Authorization", "Bearer "+prov.APIKey) + } resp, err := o.client.Do(req) if err != nil { @@ -427,5 +578,6 @@ func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride str return &chatResp, prov.Name, nil } + log.Printf("[orchestrator] fallback from %v to next provider", triedProviders) return nil, "", lastErr } diff --git a/internal/version/version.go b/internal/version/version.go index 265f33b..5c5e1cc 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -1,11 +1,33 @@ package version +import ( + "fmt" + "runtime" +) + const ( Name = "muyue" Version = "0.3.2" Author = "La Légion de Muyue" ) +var ( + // BuildDate is set at build time + BuildDate = "" +) + func FullVersion() string { return Name + " v" + Version } + +// FullInfo returns full version information. +func FullInfo() string { + info := fmt.Sprintf("%-12s %s\n", "Version:", Version) + info += fmt.Sprintf("%-12s %s\n", "Author:", Author) + info += fmt.Sprintf("%-12s %s\n", "Go:", runtime.Version()) + info += fmt.Sprintf("%-12s %s\n", "Platform:", runtime.GOOS+"/"+runtime.GOARCH) + if BuildDate != "" { + info += fmt.Sprintf("%-12s %s\n", "Build:", BuildDate) + } + return info +} diff --git a/web/src/components/Studio.jsx b/web/src/components/Studio.jsx index 21af5e3..fd4899d 100644 --- a/web/src/components/Studio.jsx +++ b/web/src/components/Studio.jsx @@ -53,8 +53,12 @@ function renderContent(text) { } function formatText(text) { - return text + // First escape HTML entities + let html = text .replace(/&/g, '&').replace(//g, '>') + + // Apply markdown transformations (now with escaped brackets) + html = html .replace(/\*\*(.+?)\*\*/g, '$1') .replace(/`([^`]+)`/g, '$1') .replace(/^### (.+)$/gm, '

$1

') @@ -62,6 +66,14 @@ function formatText(text) { .replace(/^# (.+)$/gm, '

$1

') .replace(/^\s*[-*] (.+)$/gm, '
• $1
') .replace(/^\s*(\d+)[.)] (.+)$/gm, '
$1 $2
') + + // Sanitize: remove event handlers and dangerous protocols + html = html + .replace(/\s+on\w+=["'][^"']*["']/gi, '') // Remove on* event handlers + .replace(/javascript:/gi, '') + .replace(/data:/gi, '') + + return html } function ThinkingBlock({ content, done }) { @@ -324,6 +336,65 @@ export default function Studio({ api }) { return } + if (text === '/help') { + const helpMsg = [ + '## Commandes Studio', + '', + '- `/clear` - Effacer la conversation', + '- `/help` - Afficher cette aide', + '- `/plan ` - Demander un plan structuré', + '- `/export` - Exporter la conversation en Markdown', + '- `/model` - Afficher le provider et modèle actifs', + '', + '## Tools disponibles', + '- Terminal - Exécuter des commandes', + '- read_file - Lire des fichiers', + '- list_files - Lister des fichiers', + '- search_files - Rechercher des fichiers', + '- grep_content - Rechercher dans le contenu', + '- get_config - Lire la configuration', + '- web_fetch - Récupérer une page web', + ].join('\n') + setMessages(prev => [...prev, { id: Date.now().toString(), role: 'assistant', content: helpMsg, time: new Date().toISOString() }]) + return + } + + if (text === '/model') { + api.getProviders().then(data => { + const active = data.providers?.find(p => p.active) + const modelMsg = active ? `Provider: ${active.name}\nModèle: ${active.model}` : 'Aucun provider actif configuré' + setMessages(prev => [...prev, { id: Date.now().toString(), role: 'assistant', content: modelMsg, time: new Date().toISOString() }]) + }).catch(() => { + setMessages(prev => [...prev, { id: Date.now().toString(), role: 'assistant', content: 'Erreur: impossible de récupérer les providers', time: new Date().toISOString() }]) + }) + return + } + + if (text.startsWith('/plan ')) { + const objective = text.slice(6).trim() + if (!objective) { + setMessages(prev => [...prev, { id: Date.now().toString(), role: 'assistant', content: 'Usage: `/plan `\nEx: `/plan créer un fichier de test`', time: new Date().toISOString() }]) + return + } + setInput(`Crée un plan structuré en étapes numérotées pour: ${objective}. Chaque étape devrait avoir une estimation de complexité et de temps.`) + handleSend() + return + } + + if (text === '/export') { + api.getChatHistory().then(data => { + let markdown = '# Conversation Export\n\n' + data.messages?.forEach((msg, i) => { + const roleLabel = msg.role === 'user' ? '👤' : (msg.role === 'assistant' ? '🤖' : '⚙️') + markdown += `## [${i + 1}] ${roleLabel} ${msg.role}\n${msg.content}\n\n---\n\n` + }) + setMessages(prev => [...prev, { id: Date.now().toString(), role: 'assistant', content: 'Conversation exportée:\n```markdown\n' + markdown + '```', time: new Date().toISOString() }]) + }).catch(() => { + setMessages(prev => [...prev, { id: Date.now().toString(), role: 'assistant', content: 'Erreur: impossible d\'exporter la conversation', time: new Date().toISOString() }]) + }) + return + } + const userMsg = { id: Date.now().toString(), role: 'user', content: text, time: new Date().toISOString() } setMessages(prev => [...prev, userMsg]) setLoading(true) @@ -472,7 +543,7 @@ export default function Studio({ api }) { )}
- {t('studio.inputHint')} · /clear + {t('studio.inputHint')} · /clear /help /plan /export /model
diff --git a/web/src/styles/global.css b/web/src/styles/global.css index 31dfa28..f25c2f4 100644 --- a/web/src/styles/global.css +++ b/web/src/styles/global.css @@ -684,6 +684,8 @@ input::placeholder { color: var(--text-disabled); } background: var(--bg-surface); border: 1px solid var(--border); border-left: 2px solid var(--accent-dim); border-radius: var(--radius); margin: 6px 0 8px; overflow: hidden; transition: all 0.3s ease; + max-height: 200px; + overflow-y: auto; } .feed-thinking-block.active { border-left-color: var(--warning); @@ -826,7 +828,8 @@ input::placeholder { color: var(--text-disabled); } font-size: 12px; font-family: var(--font-mono); color: var(--text-tertiary); - white-space: nowrap; + white-space: pre-wrap; + word-break: break-all; overflow: hidden; text-overflow: ellipsis; border-bottom: 1px solid var(--border);