package api

import (
	"context"
	"encoding/json"
	"net/http"

	"github.com/muyue/muyue/internal/agent"
	"github.com/muyue/muyue/internal/orchestrator"
)

// MaxToolIterations caps how many provider round-trips a single Run call may
// make while the model keeps asking for tool calls.
const MaxToolIterations = 15

// ChatEngine runs chat conversations in which the model may invoke tools from
// an agent registry. It centralizes the chat loop that handlers_chat.go and
// handlers_shell_chat.go previously each implemented separately.
type ChatEngine struct {
	orchestrator *orchestrator.Orchestrator
	registry     *agent.Registry
	// tools holds the raw JSON tools payload given to NewChatEngine.
	// NOTE(review): it is not read anywhere in this file — confirm it is
	// still needed.
	tools json.RawMessage
	// onChunk receives per-event payloads; nil means nothing is emitted.
	onChunk func(map[string]interface{})
	// stream selects the orchestrator's streaming send in RunWithTools.
	stream bool

	// TotalTokens is the running sum of the token usage the provider reports
	// across every call made by this engine.
	TotalTokens int
}

// NewChatEngine returns a ChatEngine bound to orb, registry and tools.
// Streaming starts disabled and no chunk callback is installed; use SetStream
// and OnChunk to change that.
func NewChatEngine(orb *orchestrator.Orchestrator, registry *agent.Registry, tools json.RawMessage) *ChatEngine {
	ce := new(ChatEngine)
	ce.orchestrator = orb
	ce.registry = registry
	ce.tools = tools
	return ce
}

// SetStream chooses between the orchestrator's streaming and non-streaming
// send calls for subsequent RunWithTools invocations.
func (ce *ChatEngine) SetStream(enabled bool) { ce.stream = enabled }

// OnChunk installs fn as the sink for the payloads RunWithTools emits
// ("content", "tool_call", "tool_result", "error"). A nil fn disables them.
func (ce *ChatEngine) OnChunk(fn func(map[string]interface{})) { ce.onChunk = fn }

// RunWithTools drives the model/tool loop: each round sends messages to the
// provider, executes any requested tool calls through the registry and feeds
// their results back, for at most MaxToolIterations rounds. It returns the
// final assistant content, the tool calls made, their results, and an error.
func (ce *ChatEngine) RunWithTools(ctx context.Context, messages []orchestrator.Message) (string, []map[string]interface{}, []map[string]interface{}, error) {
	var (
		finalContent   string
		allToolCalls   []map[string]interface{}
		allToolResults []map[string]interface{}
	)

	// fail reports err through the chunk callback and returns it together
	// with everything accumulated so far.
	fail := func(err error) (string, []map[string]interface{}, []map[string]interface{}, error) {
		ce.emitChunk(map[string]interface{}{"error": err.Error()})
		return finalContent, allToolCalls, allToolResults, err
	}

	for i := 0; i < MaxToolIterations; i++ {
		// The orchestrator send calls take no context, so cancellation of
		// ctx is honoured between rounds instead of being ignored.
		if err := ctx.Err(); err != nil {
			return fail(err)
		}

		var resp *orchestrator.ChatResponse
		var err error
		if ce.stream {
			resp, err = ce.orchestrator.SendWithToolsStream(messages, func(content string, _ []orchestrator.ToolCallMsg) {
				if content != "" {
					ce.emitChunk(map[string]interface{}{"content": content})
				}
			})
		} else {
			resp, err = ce.orchestrator.SendWithTools(messages)
		}
		if err != nil {
			return fail(err)
		}
		// Guard the Choices[0] access below: an empty reply used to panic.
		if resp == nil || len(resp.Choices) == 0 {
			return fail(emptyChoicesError{})
		}

		if resp.Usage.TotalTokens > 0 {
			ce.TotalTokens += resp.Usage.TotalTokens
		}

		choice := resp.Choices[0]
		content := cleanThinkingTags(choice.Message.Content)
		if content != "" {
			// NOTE(review): in stream mode the callback above already emitted
			// the raw deltas (not passed through cleanThinkingTags); this then
			// sends the full cleaned content again. Confirm the client treats
			// this chunk as a replacement rather than appending it.
			ce.emitChunk(map[string]interface{}{"content": content})
			finalContent = content
		}
		if len(choice.Message.ToolCalls) == 0 {
			break
		}

		messages = append(messages, orchestrator.Message{
			Role:      "assistant",
			Content:   orchestrator.TextContent(content),
			ToolCalls: choice.Message.ToolCalls,
		})

		for _, tc := range choice.Message.ToolCalls {
			toolCallData := map[string]interface{}{
				"tool_call_id": tc.ID,
				"name":         tc.Function.Name,
				"args":         tc.Function.Arguments,
			}
			allToolCalls = append(allToolCalls, toolCallData)
			ce.emitChunk(map[string]interface{}{"tool_call": toolCallData})

			result := ce.runTool(ctx, agent.ToolCall{
				ID:        tc.ID,
				Name:      tc.Function.Name,
				Arguments: json.RawMessage(tc.Function.Arguments),
			})

			resultData := map[string]interface{}{
				"tool_call_id": tc.ID,
				"content":      result.Content,
				"is_error":     result.IsError,
			}
			// Meta entries are merged last, so a Meta key with the same name
			// as one of the fields above overrides it in the client payload.
			for k, v := range result.Meta {
				resultData[k] = v
			}

			allToolResults = append(allToolResults, map[string]interface{}{
				"tool_call_id": tc.ID,
				"name":         tc.Function.Name,
				"args":         tc.Function.Arguments,
				"result":       result.Content,
				"is_error":     result.IsError,
			})
			ce.emitChunk(map[string]interface{}{"tool_result": resultData})

			messages = append(messages, orchestrator.Message{
				Role:       "tool",
				Content:    orchestrator.TextContent(result.Content),
				ToolCallID: tc.ID,
				Name:       tc.Function.Name,
			})
		}
		// Text produced alongside tool calls is intermediate; only the
		// content of the round that ends the loop is returned.
		finalContent = ""
	}
	return finalContent, allToolCalls, allToolResults, nil
}

// ProviderName returns the name of the active provider used by the engine.
func (ce *ChatEngine) ProviderName() string {
	return ce.orchestrator.ProviderName()
}

// RunNonStream runs the same model/tool loop as RunWithTools but emits no
// chunks and does not record tool calls or results. It returns the final
// assistant text, or the placeholder "(tool calls completed, no text
// response)" when the last round produced no text. It stops early with an
// error if ctx is cancelled, a send fails, or the provider returns no choices.
func (ce *ChatEngine) RunNonStream(ctx context.Context, messages []orchestrator.Message) (string, error) {
	var finalContent string
	for i := 0; i < MaxToolIterations; i++ {
		if err := ctx.Err(); err != nil {
			return finalContent, err
		}

		resp, err := ce.orchestrator.SendWithTools(messages)
		if err != nil {
			return finalContent, err
		}
		if resp == nil || len(resp.Choices) == 0 {
			return finalContent, emptyChoicesError{}
		}

		if resp.Usage.TotalTokens > 0 {
			ce.TotalTokens += resp.Usage.TotalTokens
		}

		choice := resp.Choices[0]
		content := cleanThinkingTags(choice.Message.Content)
		if content != "" {
			finalContent = content
		}
		if len(choice.Message.ToolCalls) == 0 {
			break
		}

		messages = append(messages, orchestrator.Message{
			Role:      "assistant",
			Content:   orchestrator.TextContent(content),
			ToolCalls: choice.Message.ToolCalls,
		})

		for _, tc := range choice.Message.ToolCalls {
			result := ce.runTool(ctx, agent.ToolCall{
				ID:        tc.ID,
				Name:      tc.Function.Name,
				Arguments: json.RawMessage(tc.Function.Arguments),
			})
			messages = append(messages, orchestrator.Message{
				Role:       "tool",
				Content:    orchestrator.TextContent(result.Content),
				ToolCallID: tc.ID,
				Name:       tc.Function.Name,
			})
		}
		finalContent = ""
	}
	if finalContent == "" {
		finalContent = "(tool calls completed, no text response)"
	}
	return finalContent, nil
}

// emitChunk forwards payload to the registered chunk callback, if any.
func (ce *ChatEngine) emitChunk(payload map[string]interface{}) {
	if ce.onChunk != nil {
		ce.onChunk(payload)
	}
}

// runTool executes call through the registry. An execution error is turned
// into an error-flagged ToolResponse so the failure is fed back to the model
// as the tool's output instead of aborting the chat.
func (ce *ChatEngine) runTool(ctx context.Context, call agent.ToolCall) agent.ToolResponse {
	result, err := ce.registry.Execute(ctx, call)
	if err != nil {
		return agent.ToolResponse{Content: err.Error(), IsError: true}
	}
	return result
}

// emptyChoicesError reports a provider response that carried no choices.
type emptyChoicesError struct{}

func (emptyChoicesError) Error() string {
	return "provider returned a response with no choices"
}

// SSEWriter writes Server-Sent Events to an HTTP response.
type SSEWriter struct {
	w       http.ResponseWriter
	flusher http.Flusher // nil when w does not support flushing
}

// NewSSEWriter wraps w, using its http.Flusher implementation when present so
// each event is pushed to the client as soon as it is written.
func NewSSEWriter(w http.ResponseWriter) *SSEWriter {
	sse := &SSEWriter{w: w}
	if f, ok := w.(http.Flusher); ok {
		sse.flusher = f
	}
	return sse
}

// Write encodes data as JSON and sends it as one "data:" event, flushing when
// possible. If data cannot be encoded, an {"error": ...} event is sent in its
// place rather than an empty event. Its signature matches the callback type
// taken by ChatEngine.OnChunk, so write errors cannot be returned and are
// dropped.
func (s *SSEWriter) Write(data map[string]interface{}) {
	payload, err := json.Marshal(data)
	if err != nil {
		// A map of plain strings always marshals, so this cannot fail.
		payload, _ = json.Marshal(map[string]string{"error": "sse: encoding event: " + err.Error()})
	}

	frame := make([]byte, 0, len(payload)+len("data: ")+len("\n\n"))
	frame = append(frame, "data: "...)
	frame = append(frame, payload...)
	frame = append(frame, "\n\n"...)
	_, _ = s.w.Write(frame)

	if s.flusher != nil {
		s.flusher.Flush()
	}
}

// SetupSSEHeaders writes the event-stream response headers and a 200 status.
// It must be called before the first SSEWriter.Write.
// NOTE(review): Access-Control-Allow-Origin is "*", so any origin may read
// the stream — confirm this is intended.
func SetupSSEHeaders(w http.ResponseWriter) {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.WriteHeader(http.StatusOK)
}