package orchestrator import ( "bufio" "bytes" "encoding/json" "fmt" "io" "log" "net/http" "regexp" "strings" "sync" "time" "github.com/muyue/muyue/internal/config" ) var thinkRegex = regexp.MustCompile(`(?s)<[Tt]hink[^>]*>.*?`) const maxHistorySize = 100 type ContentPart struct { Type string `json:"type"` Text string `json:"text,omitempty"` ImageURL *ImageURL `json:"image_url,omitempty"` } type ImageURL struct { URL string `json:"url"` } type Message struct { Role string `json:"role"` Content json.RawMessage `json:"content,omitempty"` ToolCalls []ToolCallMsg `json:"tool_calls,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"` Name string `json:"name,omitempty"` } func TextContent(s string) json.RawMessage { b, _ := json.Marshal(s) return b } func PartsContent(parts []ContentPart) json.RawMessage { b, _ := json.Marshal(parts) return b } func (m Message) ContentString() string { var s string if json.Unmarshal(m.Content, &s) == nil { return s } return string(m.Content) } type ToolCallMsg struct { ID string `json:"id"` Type string `json:"type"` Function ToolCallFuncMsg `json:"function"` } type ToolCallFuncMsg struct { Name string `json:"name"` Arguments string `json:"arguments"` } type ChatRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` Stream bool `json:"stream"` Tools json.RawMessage `json:"tools,omitempty"` } type ChatResponse struct { Choices []struct { Message struct { Content string `json:"content"` ToolCalls []ToolCallMsg `json:"tool_calls"` } `json:"message"` Delta struct { Content string `json:"content"` ToolCalls []ToolCallMsg `json:"tool_calls"` } `json:"delta"` FinishReason *string `json:"finish_reason"` } `json:"choices"` Usage struct { TotalTokens int `json:"total_tokens"` } `json:"usage"` } type Orchestrator struct { config *config.MuyueConfig provider *config.AIProvider client *http.Client history []Message histMu sync.Mutex systemPrompt string tools json.RawMessage } var sharedHTTPClient = &http.Client{ Timeout: 120 * time.Second, } // requestClient creates an HTTP client with the specified timeout. func requestClient(timeout time.Duration) *http.Client { return &http.Client{Timeout: timeout} } func New(cfg *config.MuyueConfig) (*Orchestrator, error) { var provider *config.AIProvider for i := range cfg.AI.Providers { if cfg.AI.Providers[i].Active { provider = &cfg.AI.Providers[i] break } } if provider == nil { return nil, fmt.Errorf("no active AI provider configured") } if provider.APIKey == "" { return nil, fmt.Errorf("API key not set for %s", provider.Name) } return &Orchestrator{ config: cfg, provider: provider, client: sharedHTTPClient, history: []Message{}, }, nil } func (o *Orchestrator) SetSystemPrompt(prompt string) { o.systemPrompt = prompt } func (o *Orchestrator) SetTools(tools json.RawMessage) { o.tools = tools } func (o *Orchestrator) ProviderName() string { if o.provider == nil { return "" } return o.provider.Name } func (o *Orchestrator) AppendHistory(msg Message) { o.histMu.Lock() defer o.histMu.Unlock() o.history = append(o.history, msg) if len(o.history) > maxHistorySize { o.history = o.history[len(o.history)-maxHistorySize:] } } func (o *Orchestrator) GetHistory() []Message { o.histMu.Lock() defer o.histMu.Unlock() out := make([]Message, len(o.history)) copy(out, o.history) return out } func (o *Orchestrator) Send(userMessage string) (string, error) { o.histMu.Lock() o.history = append(o.history, Message{ Role: "user", Content: TextContent(userMessage), }) if len(o.history) > maxHistorySize { o.history = o.history[len(o.history)-maxHistorySize:] } messages := make([]Message, 0, len(o.history)+1) if o.systemPrompt != "" { messages = append(messages, Message{Role: "system", Content: TextContent(o.systemPrompt)}) } messages = append(messages, o.history...) reqBody := ChatRequest{ Model: o.provider.Model, Messages: messages, Stream: false, Tools: o.tools, } o.histMu.Unlock() chatResp, providerName, err := o.sendWithFallback(reqBody, "") if err != nil { return "", err } content := cleanAIResponse(chatResp.Choices[0].Message.Content) o.histMu.Lock() o.history = append(o.history, Message{ Role: "assistant", Content: TextContent(content), }) _ = providerName o.histMu.Unlock() return content, nil } func (o *Orchestrator) SendStream(userMessage string, onChunk func(string)) (string, error) { o.histMu.Lock() o.history = append(o.history, Message{ Role: "user", Content: TextContent(userMessage), }) if len(o.history) > maxHistorySize { o.history = o.history[len(o.history)-maxHistorySize:] } messages := make([]Message, 0, len(o.history)+1) if o.systemPrompt != "" { messages = append(messages, Message{Role: "system", Content: TextContent(o.systemPrompt)}) } messages = append(messages, o.history...) reqBody := ChatRequest{ Model: o.provider.Model, Messages: messages, Stream: true, Tools: o.tools, } o.histMu.Unlock() body, err := json.Marshal(reqBody) if err != nil { return "", fmt.Errorf("marshal request: %w", err) } baseURL := o.provider.BaseURL if baseURL == "" { baseURL = getProviderBaseURL(o.provider.Name) } url := strings.TrimRight(baseURL, "/") + "/chat/completions" req, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { return "", fmt.Errorf("create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+o.provider.APIKey) resp, err := o.client.Do(req) if err != nil { return "", fmt.Errorf("send request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody)) } var fullContent strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { break } var chatResp ChatResponse if err := json.Unmarshal([]byte(data), &chatResp); err != nil { continue } if len(chatResp.Choices) > 0 { chunk := chatResp.Choices[0].Delta.Content if chunk != "" { fullContent.WriteString(chunk) onChunk(chunk) } } } if err := scanner.Err(); err != nil { return fullContent.String(), fmt.Errorf("read stream: %w", err) } content := cleanAIResponse(fullContent.String()) o.histMu.Lock() o.history = append(o.history, Message{ Role: "assistant", Content: TextContent(content), }) o.histMu.Unlock() return content, nil } func (o *Orchestrator) SendWithTools(messages []Message) (*ChatResponse, error) { fullMessages := make([]Message, 0, len(messages)+1) if o.systemPrompt != "" { fullMessages = append(fullMessages, Message{Role: "system", Content: TextContent(o.systemPrompt)}) } fullMessages = append(fullMessages, messages...) reqBody := ChatRequest{ Model: o.provider.Model, Messages: fullMessages, Stream: false, Tools: o.tools, } chatResp, _, err := o.sendWithFallback(reqBody, "") if err != nil { return nil, err } if len(chatResp.Choices) == 0 { return nil, fmt.Errorf("no response from AI") } return chatResp, nil } // ChunkCallback is called for each streaming chunk. type ChunkCallback func(content string, toolCalls []ToolCallMsg) // SendWithToolsStream sends messages with streaming responses. // The callback receives chunks of content and tool_calls as they arrive. func (o *Orchestrator) SendWithToolsStream(messages []Message, onChunk ChunkCallback) (*ChatResponse, error) { fullMessages := make([]Message, 0, len(messages)+1) if o.systemPrompt != "" { fullMessages = append(fullMessages, Message{Role: "system", Content: TextContent(o.systemPrompt)}) } fullMessages = append(fullMessages, messages...) reqBody := ChatRequest{ Model: o.provider.Model, Messages: fullMessages, Stream: true, Tools: o.tools, } body, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("marshal request: %w", err) } provider := o.provider baseURL := provider.BaseURL if baseURL == "" { baseURL = getProviderBaseURL(provider.Name) } url := strings.TrimRight(baseURL, "/") + "/chat/completions" req, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+provider.APIKey) resp, err := o.client.Do(req) if err != nil { return nil, fmt.Errorf("send request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody)) } var fullContent strings.Builder var accumulatedToolCalls []ToolCallMsg var totalTokens int scanner := bufio.NewScanner(resp.Body) scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { break } var chatResp ChatResponse if err := json.Unmarshal([]byte(data), &chatResp); err != nil { continue } if len(chatResp.Choices) > 0 { chunk := chatResp.Choices[0].Delta.Content if chunk != "" { fullContent.WriteString(chunk) onChunk(chunk, nil) } // Handle delta tool calls if len(chatResp.Choices[0].Delta.ToolCalls) > 0 { for _, tc := range chatResp.Choices[0].Delta.ToolCalls { // Find or create the tool call in accumulated list found := false for i := range accumulatedToolCalls { if accumulatedToolCalls[i].ID == tc.ID { // Append arguments accumulatedToolCalls[i].Function.Arguments += tc.Function.Arguments found = true break } } if !found { accumulatedToolCalls = append(accumulatedToolCalls, tc) } } onChunk("", accumulatedToolCalls) } // Capture usage from final chunk if chatResp.Usage.TotalTokens > 0 { totalTokens = chatResp.Usage.TotalTokens } } } if err := scanner.Err(); err != nil { return nil, fmt.Errorf("read stream: %w", err) } // Build final response finalResp := &ChatResponse{ Usage: struct { TotalTokens int `json:"total_tokens"` }{TotalTokens: totalTokens}, Choices: []struct { Message struct { Content string `json:"content"` ToolCalls []ToolCallMsg `json:"tool_calls"` } `json:"message"` Delta struct { Content string `json:"content"` ToolCalls []ToolCallMsg `json:"tool_calls"` } `json:"delta"` FinishReason *string `json:"finish_reason"` }{}, } finalContent := cleanAIResponse(fullContent.String()) finalResp.Choices[0].Message.Content = finalContent finalResp.Choices[0].Message.ToolCalls = accumulatedToolCalls return finalResp, nil } func cleanAIResponse(content string) string { content = thinkRegex.ReplaceAllString(content, "") lines := strings.Split(content, "\n") var clean []string inBlock := false for _, line := range lines { trimmed := strings.TrimSpace(line) if trimmed == "<<" || trimmed == "<<<" { inBlock = true continue } if trimmed == ">>" || trimmed == ">>>" { inBlock = false continue } if inBlock { continue } clean = append(clean, line) } result := strings.TrimSpace(strings.Join(clean, "\n")) return result } func getProviderBaseURL(name string) string { switch name { case "minimax": return "https://api.minimax.io/v1" case "anthropic": return "https://api.anthropic.com/v1" case "openai": return "https://api.openai.com/v1" case "zai": return "https://api.z.ai/v1" case "mimo": return "https://token-plan-ams.xiaomimimo.com/v1" default: return "" } } func (o *Orchestrator) getAvailableProviders() []*config.AIProvider { var providers []*config.AIProvider for i := range o.config.AI.Providers { prov := &o.config.AI.Providers[i] if prov.APIKey != "" { providers = append(providers, prov) } } return providers } func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride string) (*ChatResponse, string, error) { providers := o.getAvailableProviders() if len(providers) == 0 { return nil, "", fmt.Errorf("no providers available") } providerOrder := make([]*config.AIProvider, 0, len(providers)) if o.provider != nil { providerOrder = append(providerOrder, o.provider) } var zaiProvider *config.AIProvider for _, p := range providers { if o.provider == nil || p.Name != o.provider.Name { if p.Name == "zai" { zaiProvider = p } else { providerOrder = append(providerOrder, p) } } } if zaiProvider != nil { providerOrder = append(providerOrder, zaiProvider) } var lastErr error var triedProviders []string for _, prov := range providerOrder { triedProviders = append(triedProviders, prov.Name) baseURL := baseURLOverride if baseURL == "" { baseURL = prov.BaseURL if baseURL == "" { baseURL = getProviderBaseURL(prov.Name) } } url := strings.TrimRight(baseURL, "/") + "/chat/completions" body, err := json.Marshal(reqBody) if err != nil { lastErr = fmt.Errorf("marshal request: %w", err) continue } req, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { lastErr = fmt.Errorf("create request: %w", err) continue } req.Header.Set("Content-Type", "application/json") // Provider-specific headers if prov.Name == "anthropic" { req.Header.Set("x-api-key", prov.APIKey) req.Header.Set("anthropic-version", "2023-06-01") } else { req.Header.Set("Authorization", "Bearer "+prov.APIKey) } resp, err := o.client.Do(req) if err != nil { lastErr = fmt.Errorf("send request to %s: %w", prov.Name, err) continue } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { lastErr = fmt.Errorf("read response: %w", err) continue } if resp.StatusCode != http.StatusOK { lastErr = fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody)) continue } var chatResp ChatResponse if err := json.Unmarshal(respBody, &chatResp); err != nil { lastErr = fmt.Errorf("parse response: %w", err) continue } if len(chatResp.Choices) == 0 { lastErr = fmt.Errorf("no response from AI") continue } o.provider = prov return &chatResp, prov.Name, nil } log.Printf("[orchestrator] fallback from %v to next provider", triedProviders) return nil, "", lastErr }