package orchestrator import ( "bufio" "bytes" "encoding/json" "fmt" "io" "net/http" "regexp" "strings" "sync" "time" "github.com/muyue/muyue/internal/config" ) var thinkRegex = regexp.MustCompile(`(?s)<[Tt]hink[^>]*>.*?`) const maxHistorySize = 100 type Message struct { Role string `json:"role"` Content string `json:"content,omitempty"` ToolCalls []ToolCallMsg `json:"tool_calls,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"` Name string `json:"name,omitempty"` } type ToolCallMsg struct { ID string `json:"id"` Type string `json:"type"` Function ToolCallFuncMsg `json:"function"` } type ToolCallFuncMsg struct { Name string `json:"name"` Arguments string `json:"arguments"` } type ChatRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` Stream bool `json:"stream"` Tools json.RawMessage `json:"tools,omitempty"` } type ChatResponse struct { Choices []struct { Message struct { Content string `json:"content"` ToolCalls []ToolCallMsg `json:"tool_calls"` } `json:"message"` Delta struct { Content string `json:"content"` ToolCalls []ToolCallMsg `json:"tool_calls"` } `json:"delta"` FinishReason *string `json:"finish_reason"` } `json:"choices"` Usage struct { TotalTokens int `json:"total_tokens"` } `json:"usage"` } type Orchestrator struct { config *config.MuyueConfig provider *config.AIProvider client *http.Client history []Message histMu sync.Mutex systemPrompt string tools json.RawMessage } var sharedHTTPClient = &http.Client{ Timeout: 120 * time.Second, } func New(cfg *config.MuyueConfig) (*Orchestrator, error) { var provider *config.AIProvider for i := range cfg.AI.Providers { if cfg.AI.Providers[i].Active { provider = &cfg.AI.Providers[i] break } } if provider == nil { return nil, fmt.Errorf("no active AI provider configured") } if provider.APIKey == "" { return nil, fmt.Errorf("API key not set for %s", provider.Name) } return &Orchestrator{ config: cfg, provider: provider, client: sharedHTTPClient, history: []Message{}, }, nil } func (o *Orchestrator) SetSystemPrompt(prompt string) { o.systemPrompt = prompt } func (o *Orchestrator) SetTools(tools json.RawMessage) { o.tools = tools } func (o *Orchestrator) ProviderName() string { if o.provider == nil { return "" } return o.provider.Name } func (o *Orchestrator) AppendHistory(msg Message) { o.histMu.Lock() defer o.histMu.Unlock() o.history = append(o.history, msg) if len(o.history) > maxHistorySize { o.history = o.history[len(o.history)-maxHistorySize:] } } func (o *Orchestrator) GetHistory() []Message { o.histMu.Lock() defer o.histMu.Unlock() out := make([]Message, len(o.history)) copy(out, o.history) return out } func (o *Orchestrator) Send(userMessage string) (string, error) { o.histMu.Lock() o.history = append(o.history, Message{ Role: "user", Content: userMessage, }) if len(o.history) > maxHistorySize { o.history = o.history[len(o.history)-maxHistorySize:] } messages := make([]Message, 0, len(o.history)+1) if o.systemPrompt != "" { messages = append(messages, Message{Role: "system", Content: o.systemPrompt}) } messages = append(messages, o.history...) reqBody := ChatRequest{ Model: o.provider.Model, Messages: messages, Stream: false, Tools: o.tools, } o.histMu.Unlock() chatResp, providerName, err := o.sendWithFallback(reqBody, "") if err != nil { return "", err } content := cleanAIResponse(chatResp.Choices[0].Message.Content) o.histMu.Lock() o.history = append(o.history, Message{ Role: "assistant", Content: content, }) _ = providerName o.histMu.Unlock() return content, nil } func (o *Orchestrator) SendStream(userMessage string, onChunk func(string)) (string, error) { o.histMu.Lock() o.history = append(o.history, Message{ Role: "user", Content: userMessage, }) if len(o.history) > maxHistorySize { o.history = o.history[len(o.history)-maxHistorySize:] } messages := make([]Message, 0, len(o.history)+1) if o.systemPrompt != "" { messages = append(messages, Message{Role: "system", Content: o.systemPrompt}) } messages = append(messages, o.history...) reqBody := ChatRequest{ Model: o.provider.Model, Messages: messages, Stream: true, Tools: o.tools, } o.histMu.Unlock() body, err := json.Marshal(reqBody) if err != nil { return "", fmt.Errorf("marshal request: %w", err) } baseURL := o.provider.BaseURL if baseURL == "" { baseURL = getProviderBaseURL(o.provider.Name) } url := strings.TrimRight(baseURL, "/") + "/chat/completions" req, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { return "", fmt.Errorf("create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+o.provider.APIKey) resp, err := o.client.Do(req) if err != nil { return "", fmt.Errorf("send request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody)) } var fullContent strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { break } var chatResp ChatResponse if err := json.Unmarshal([]byte(data), &chatResp); err != nil { continue } if len(chatResp.Choices) > 0 { chunk := chatResp.Choices[0].Delta.Content if chunk != "" { fullContent.WriteString(chunk) onChunk(chunk) } } } if err := scanner.Err(); err != nil { return fullContent.String(), fmt.Errorf("read stream: %w", err) } content := cleanAIResponse(fullContent.String()) o.histMu.Lock() o.history = append(o.history, Message{ Role: "assistant", Content: content, }) o.histMu.Unlock() return content, nil } func (o *Orchestrator) SendWithTools(messages []Message) (*ChatResponse, error) { fullMessages := make([]Message, 0, len(messages)+1) if o.systemPrompt != "" { fullMessages = append(fullMessages, Message{Role: "system", Content: o.systemPrompt}) } fullMessages = append(fullMessages, messages...) reqBody := ChatRequest{ Model: o.provider.Model, Messages: fullMessages, Stream: false, Tools: o.tools, } chatResp, _, err := o.sendWithFallback(reqBody, "") if err != nil { return nil, err } if len(chatResp.Choices) == 0 { return nil, fmt.Errorf("no response from AI") } return chatResp, nil } func cleanAIResponse(content string) string { content = thinkRegex.ReplaceAllString(content, "") lines := strings.Split(content, "\n") var clean []string inBlock := false for _, line := range lines { trimmed := strings.TrimSpace(line) if trimmed == "<<" || trimmed == "<<<" { inBlock = true continue } if trimmed == ">>" || trimmed == ">>>" { inBlock = false continue } if inBlock { continue } clean = append(clean, line) } result := strings.TrimSpace(strings.Join(clean, "\n")) return result } func getProviderBaseURL(name string) string { switch name { case "minimax": return "https://api.minimax.io/v1" case "anthropic": return "https://api.anthropic.com/v1" case "openai": return "https://api.openai.com/v1" case "zai": return "https://api.z.ai/v1" default: return "" } } func (o *Orchestrator) getAvailableProviders() []*config.AIProvider { var providers []*config.AIProvider for i := range o.config.AI.Providers { prov := &o.config.AI.Providers[i] if prov.APIKey != "" { providers = append(providers, prov) } } return providers } func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride string) (*ChatResponse, string, error) { providers := o.getAvailableProviders() if len(providers) == 0 { return nil, "", fmt.Errorf("no providers available") } providerOrder := make([]*config.AIProvider, 0, len(providers)) if o.provider != nil { providerOrder = append(providerOrder, o.provider) } for _, p := range providers { if o.provider == nil || p.Name != o.provider.Name { providerOrder = append(providerOrder, p) } } var lastErr error for _, prov := range providerOrder { baseURL := baseURLOverride if baseURL == "" { baseURL = prov.BaseURL if baseURL == "" { baseURL = getProviderBaseURL(prov.Name) } } url := strings.TrimRight(baseURL, "/") + "/chat/completions" body, err := json.Marshal(reqBody) if err != nil { lastErr = fmt.Errorf("marshal request: %w", err) continue } req, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { lastErr = fmt.Errorf("create request: %w", err) continue } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+prov.APIKey) resp, err := o.client.Do(req) if err != nil { lastErr = fmt.Errorf("send request to %s: %w", prov.Name, err) continue } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { lastErr = fmt.Errorf("read response: %w", err) continue } if resp.StatusCode != http.StatusOK { lastErr = fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody)) continue } var chatResp ChatResponse if err := json.Unmarshal(respBody, &chatResp); err != nil { lastErr = fmt.Errorf("parse response: %w", err) continue } if len(chatResp.Choices) == 0 { lastErr = fmt.Errorf("no response from AI") continue } o.provider = prov return &chatResp, prov.Name, nil } return nil, "", lastErr }