package orchestrator import ( "bufio" "bytes" "encoding/json" "fmt" "io" "net/http" "regexp" "strings" "sync" "time" "github.com/muyue/muyue/internal/config" ) var thinkRegex = regexp.MustCompile(`(?s)<[Tt]hink[^>]*>.*?`) const maxHistorySize = 100 type Message struct { Role string `json:"role"` Content string `json:"content"` } type ChatRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` Stream bool `json:"stream"` } type ChatResponse struct { Choices []struct { Message struct { Content string `json:"content"` } `json:"message"` Delta struct { Content string `json:"content"` } `json:"delta"` } `json:"choices"` Usage struct { TotalTokens int `json:"total_tokens"` } `json:"usage"` } type Orchestrator struct { config *config.MuyueConfig provider *config.AIProvider client *http.Client history []Message histMu sync.Mutex systemPrompt string } var sharedHTTPClient = &http.Client{ Timeout: 120 * time.Second, } func New(cfg *config.MuyueConfig) (*Orchestrator, error) { var provider *config.AIProvider for i := range cfg.AI.Providers { if cfg.AI.Providers[i].Active { provider = &cfg.AI.Providers[i] break } } if provider == nil { return nil, fmt.Errorf("no active AI provider configured") } if provider.APIKey == "" { return nil, fmt.Errorf("API key not set for %s", provider.Name) } return &Orchestrator{ config: cfg, provider: provider, client: sharedHTTPClient, history: []Message{}, }, nil } func (o *Orchestrator) SetSystemPrompt(prompt string) { o.systemPrompt = prompt } func (o *Orchestrator) Send(userMessage string) (string, error) { o.histMu.Lock() o.history = append(o.history, Message{ Role: "user", Content: userMessage, }) if len(o.history) > maxHistorySize { o.history = o.history[len(o.history)-maxHistorySize:] } messages := make([]Message, 0, len(o.history)+1) if o.systemPrompt != "" { messages = append(messages, Message{Role: "system", Content: o.systemPrompt}) } messages = append(messages, o.history...) reqBody := ChatRequest{ Model: o.provider.Model, Messages: messages, Stream: false, } o.histMu.Unlock() body, err := json.Marshal(reqBody) if err != nil { return "", fmt.Errorf("marshal request: %w", err) } baseURL := o.provider.BaseURL if baseURL == "" { baseURL = getProviderBaseURL(o.provider.Name) } url := strings.TrimRight(baseURL, "/") + "/chat/completions" req, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { return "", fmt.Errorf("create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+o.provider.APIKey) resp, err := o.client.Do(req) if err != nil { return "", fmt.Errorf("send request: %w", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("read response: %w", err) } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody)) } var chatResp ChatResponse if err := json.Unmarshal(respBody, &chatResp); err != nil { return "", fmt.Errorf("parse response: %w", err) } if len(chatResp.Choices) == 0 { return "", fmt.Errorf("no response from AI") } content := cleanAIResponse(chatResp.Choices[0].Message.Content) o.histMu.Lock() o.history = append(o.history, Message{ Role: "assistant", Content: content, }) o.histMu.Unlock() return content, nil } func (o *Orchestrator) SendStream(userMessage string, onChunk func(string)) (string, error) { o.histMu.Lock() o.history = append(o.history, Message{ Role: "user", Content: userMessage, }) if len(o.history) > maxHistorySize { o.history = o.history[len(o.history)-maxHistorySize:] } messages := make([]Message, 0, len(o.history)+1) if o.systemPrompt != "" { messages = append(messages, Message{Role: "system", Content: o.systemPrompt}) } messages = append(messages, o.history...) reqBody := ChatRequest{ Model: o.provider.Model, Messages: messages, Stream: true, } o.histMu.Unlock() body, err := json.Marshal(reqBody) if err != nil { return "", fmt.Errorf("marshal request: %w", err) } baseURL := o.provider.BaseURL if baseURL == "" { baseURL = getProviderBaseURL(o.provider.Name) } url := strings.TrimRight(baseURL, "/") + "/chat/completions" req, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { return "", fmt.Errorf("create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+o.provider.APIKey) resp, err := o.client.Do(req) if err != nil { return "", fmt.Errorf("send request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody)) } var fullContent strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { continue } data := strings.TrimPrefix(line, "data: ") if data == "[DONE]" { break } var chatResp ChatResponse if err := json.Unmarshal([]byte(data), &chatResp); err != nil { continue } if len(chatResp.Choices) > 0 { chunk := chatResp.Choices[0].Delta.Content if chunk != "" { fullContent.WriteString(chunk) onChunk(chunk) } } } if err := scanner.Err(); err != nil { return fullContent.String(), fmt.Errorf("read stream: %w", err) } content := cleanAIResponse(fullContent.String()) o.histMu.Lock() o.history = append(o.history, Message{ Role: "assistant", Content: content, }) o.histMu.Unlock() return content, nil } func cleanAIResponse(content string) string { content = thinkRegex.ReplaceAllString(content, "") lines := strings.Split(content, "\n") var clean []string inBlock := false for _, line := range lines { trimmed := strings.TrimSpace(line) if trimmed == "<<" || trimmed == "<<<" { inBlock = true continue } if trimmed == ">>" || trimmed == ">>>" { inBlock = false continue } if inBlock { continue } clean = append(clean, line) } result := strings.TrimSpace(strings.Join(clean, "\n")) return result } func getProviderBaseURL(name string) string { switch name { case "minimax": return "https://api.minimax.io/v1" case "anthropic": return "https://api.anthropic.com/v1" case "openai": return "https://api.openai.com/v1" case "zai": return "https://api.z.ai/v1" default: return "" } }