Files
MuyueWorkspace/internal/orchestrator/orchestrator.go
Augustin 7d0f807fb0
All checks were successful
Beta Release / beta (push) Successful in 57s
feat(ai): add Xiaomi MiMo provider, ZAI as last-resort fallback
Add MiMo-V2.5-Pro from Xiaomi Token Plan as a new AI provider with
base URL https://token-plan-ams.xiaomimimo.com/v1. The /model change
command now switches between MiniMax and MiMo only. ZAI is always
placed last in the fallback chain as the provider of ultimate resort.
Config panel shows MiniMax and MiMo cards.

💘 Generated with Crush

Assisted-by: GLM-5.1 via Crush <crush@charm.land>
2026-04-24 21:22:34 +02:00

594 lines
14 KiB
Go

package orchestrator
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"regexp"
"strings"
"sync"
"time"
"github.com/muyue/muyue/internal/config"
)
// thinkRegex matches a <think ...>...</think> element (either capitalisation
// of the leading "t"), including attributes on the opening tag. The (?s) flag
// lets the non-greedy body span multiple lines. cleanAIResponse uses it to
// strip these spans from model output.
var thinkRegex = regexp.MustCompile(`(?s)<[Tt]hink[^>]*>.*?</[Tt]hink>`)
// maxHistorySize is the cap applied when the Orchestrator's history is
// trimmed: only the most recent maxHistorySize messages are kept.
const maxHistorySize = 100
// Message is a single chat message as serialized into ChatRequest.Messages
// and stored in the Orchestrator's history.
type Message struct {
// Role identifies the author; this file produces "system", "user" and
// "assistant".
Role string `json:"role"`
// Content is the message text; omitted from JSON when empty.
Content string `json:"content,omitempty"`
// ToolCalls lists tool invocations attached to the message; omitted when
// empty.
ToolCalls []ToolCallMsg `json:"tool_calls,omitempty"`
// ToolCallID is serialized as "tool_call_id"; presumably it links a tool
// result back to the call it answers (not set anywhere in this file).
ToolCallID string `json:"tool_call_id,omitempty"`
// Name is an optional name serialized as "name"; omitted when empty.
Name string `json:"name,omitempty"`
}
// ToolCallMsg is one tool call as it appears in a message or in a streamed
// delta.
type ToolCallMsg struct {
// ID identifies the call; SendWithToolsStream uses it to merge streamed
// fragments of the same call.
ID string `json:"id"`
// Type is the call type as reported by the provider.
Type string `json:"type"`
// Function carries the function name and its arguments.
Function ToolCallFuncMsg `json:"function"`
}
// ToolCallFuncMsg names the function a tool call targets and holds its
// arguments.
type ToolCallFuncMsg struct {
// Name is the function name.
Name string `json:"name"`
// Arguments is the argument payload kept as a string; streamed fragments
// are concatenated by SendWithToolsStream. Presumably JSON-encoded —
// confirm against the provider.
Arguments string `json:"arguments"`
}
// ChatRequest is the JSON body POSTed to a provider's /chat/completions
// endpoint.
type ChatRequest struct {
// Model is the model identifier to use.
Model string `json:"model"`
// Messages is the conversation, with the optional system prompt first.
Messages []Message `json:"messages"`
// Stream requests a streamed (line-by-line "data:" events) response when
// true.
Stream bool `json:"stream"`
// Tools is the raw, pre-encoded tool definition JSON passed through
// untouched; omitted when empty.
Tools json.RawMessage `json:"tools,omitempty"`
}
// ChatResponse is the decoded provider reply. The same type is used for a
// complete (non-streaming) response and for each streamed chunk.
type ChatResponse struct {
// Choices holds the returned alternatives; this file only reads index 0.
Choices []struct {
// Message is read for non-streaming responses.
Message struct {
Content string `json:"content"`
ToolCalls []ToolCallMsg `json:"tool_calls"`
} `json:"message"`
// Delta is read for streamed chunks; it carries the incremental content
// and tool-call fragments.
Delta struct {
Content string `json:"content"`
ToolCalls []ToolCallMsg `json:"tool_calls"`
} `json:"delta"`
// FinishReason is nil when the provider sends no finish_reason.
FinishReason *string `json:"finish_reason"`
} `json:"choices"`
// Usage reports token accounting; only the total is decoded.
Usage struct {
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// Orchestrator sends chat requests to the configured AI providers and keeps
// the running conversation history.
type Orchestrator struct {
// config is the application configuration the providers are read from.
config *config.MuyueConfig
// provider is the provider currently in use; sendWithFallback replaces it
// with whichever provider last answered successfully.
provider *config.AIProvider
// client performs the HTTP requests.
client *http.Client
// history is the stored conversation; guarded by histMu.
history []Message
// histMu protects history.
histMu sync.Mutex
// systemPrompt, when non-empty, is sent as a leading "system" message.
systemPrompt string
// tools is the raw tool definition JSON attached to every request.
tools json.RawMessage
}
// sharedHTTPClient is the HTTP client handed to every Orchestrator created by
// New. Its 120-second timeout bounds the whole request, including reading the
// response body.
var sharedHTTPClient = &http.Client{
Timeout: 120 * time.Second,
}
// requestClient returns a new HTTP client whose overall request timeout is
// set to timeout. Every call yields a distinct client.
func requestClient(timeout time.Duration) *http.Client {
	client := new(http.Client)
	client.Timeout = timeout
	return client
}
// New builds an Orchestrator bound to the first provider in cfg.AI.Providers
// that is marked Active. It returns an error when no provider is active or
// when the active provider has no API key. The returned Orchestrator keeps a
// pointer into cfg's provider list and uses the shared HTTP client.
func New(cfg *config.MuyueConfig) (*Orchestrator, error) {
	var selected *config.AIProvider
	for idx := 0; idx < len(cfg.AI.Providers) && selected == nil; idx++ {
		if candidate := &cfg.AI.Providers[idx]; candidate.Active {
			selected = candidate
		}
	}

	switch {
	case selected == nil:
		return nil, fmt.Errorf("no active AI provider configured")
	case selected.APIKey == "":
		return nil, fmt.Errorf("API key not set for %s", selected.Name)
	}

	orch := &Orchestrator{
		config:   cfg,
		provider: selected,
		client:   sharedHTTPClient,
		history:  make([]Message, 0),
	}
	return orch, nil
}
// SetSystemPrompt sets the prompt sent as the leading "system" message of
// every request; an empty prompt disables it. The field is written without
// taking histMu.
func (o *Orchestrator) SetSystemPrompt(prompt string) {
o.systemPrompt = prompt
}
// SetTools sets the raw tool definition JSON attached to every request. The
// slice is stored as-is (not copied), so the caller must not modify it
// afterwards.
func (o *Orchestrator) SetTools(tools json.RawMessage) {
o.tools = tools
}
// ProviderName reports the name of the provider currently in use, or the
// empty string when no provider is set.
func (o *Orchestrator) ProviderName() string {
	if current := o.provider; current != nil {
		return current.Name
	}
	return ""
}
// AppendHistory adds msg to the end of the conversation history and, if the
// history then exceeds maxHistorySize, drops the oldest entries so that only
// the most recent maxHistorySize remain. It holds histMu for the duration.
func (o *Orchestrator) AppendHistory(msg Message) {
	o.histMu.Lock()
	defer o.histMu.Unlock()

	updated := append(o.history, msg)
	if overflow := len(updated) - maxHistorySize; overflow > 0 {
		updated = updated[overflow:]
	}
	o.history = updated
}
// GetHistory returns a snapshot of the conversation history. The returned
// slice is a copy (never nil), so callers may modify it without affecting the
// Orchestrator.
func (o *Orchestrator) GetHistory() []Message {
	o.histMu.Lock()
	defer o.histMu.Unlock()

	snapshot := make([]Message, 0, len(o.history))
	return append(snapshot, o.history...)
}
// Send records userMessage in the history, sends the whole conversation
// (preceded by the system prompt, if any) as a non-streaming request through
// sendWithFallback, and returns the cleaned assistant reply.
//
// The reply is stored in the history via AppendHistory, so the history stays
// within maxHistorySize. If the request fails, the error is returned and the
// user message remains in the history.
func (o *Orchestrator) Send(userMessage string) (string, error) {
	// Build the request under the lock so it sees a consistent history, but
	// release the lock before the (slow) network call.
	o.histMu.Lock()
	o.history = append(o.history, Message{
		Role:    "user",
		Content: userMessage,
	})
	if len(o.history) > maxHistorySize {
		o.history = o.history[len(o.history)-maxHistorySize:]
	}
	messages := make([]Message, 0, len(o.history)+1)
	if o.systemPrompt != "" {
		messages = append(messages, Message{Role: "system", Content: o.systemPrompt})
	}
	messages = append(messages, o.history...)
	reqBody := ChatRequest{
		Model:    o.provider.Model,
		Messages: messages,
		Stream:   false,
		Tools:    o.tools,
	}
	o.histMu.Unlock()

	chatResp, _, err := o.sendWithFallback(reqBody, "")
	if err != nil {
		return "", err
	}
	// Guard locally rather than relying on the helper's internal check.
	if len(chatResp.Choices) == 0 {
		return "", fmt.Errorf("no response from AI")
	}

	content := cleanAIResponse(chatResp.Choices[0].Message.Content)
	// AppendHistory locks histMu and enforces maxHistorySize.
	o.AppendHistory(Message{
		Role:    "assistant",
		Content: content,
	})
	return content, nil
}
// SendStream records userMessage in the history and sends the conversation
// (preceded by the system prompt, if any) as a streaming request to the
// current provider only; no fallback to other providers is attempted.
//
// Each non-empty content fragment is passed to onChunk as it arrives; a nil
// onChunk is allowed. On success the cleaned full reply is stored in the
// history (kept within maxHistorySize) and returned.
//
// Errors: a non-200 status yields "API error (<code>): <body>". If reading
// the stream fails, the raw content received so far is returned together with
// the error and nothing is added to the history. In every failure case the
// user message remains in the history.
func (o *Orchestrator) SendStream(userMessage string, onChunk func(string)) (string, error) {
	// Build the request under the lock, then release it before network I/O.
	o.histMu.Lock()
	o.history = append(o.history, Message{
		Role:    "user",
		Content: userMessage,
	})
	if len(o.history) > maxHistorySize {
		o.history = o.history[len(o.history)-maxHistorySize:]
	}
	messages := make([]Message, 0, len(o.history)+1)
	if o.systemPrompt != "" {
		messages = append(messages, Message{Role: "system", Content: o.systemPrompt})
	}
	messages = append(messages, o.history...)
	reqBody := ChatRequest{
		Model:    o.provider.Model,
		Messages: messages,
		Stream:   true,
		Tools:    o.tools,
	}
	o.histMu.Unlock()

	body, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("marshal request: %w", err)
	}

	baseURL := o.provider.BaseURL
	if baseURL == "" {
		baseURL = getProviderBaseURL(o.provider.Name)
	}
	url := strings.TrimRight(baseURL, "/") + "/chat/completions"

	req, err := http.NewRequest("POST", url, bytes.NewReader(body))
	if err != nil {
		return "", fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+o.provider.APIKey)

	resp, err := o.client.Do(req)
	if err != nil {
		return "", fmt.Errorf("send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: a failed body read still reports the status code.
		respBody, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
	}

	var fullContent strings.Builder
	scanner := bufio.NewScanner(resp.Body)
	// Allow event lines of up to 1 MiB; the default scanner limit is 64 KiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		// The space after "data:" is optional in the event-stream format.
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if data == "[DONE]" {
			break
		}

		var chatResp ChatResponse
		if err := json.Unmarshal([]byte(data), &chatResp); err != nil {
			// Skip payloads that are not valid chunk JSON.
			continue
		}
		if len(chatResp.Choices) == 0 {
			continue
		}
		chunk := chatResp.Choices[0].Delta.Content
		if chunk == "" {
			continue
		}
		fullContent.WriteString(chunk)
		if onChunk != nil {
			onChunk(chunk)
		}
	}
	if err := scanner.Err(); err != nil {
		return fullContent.String(), fmt.Errorf("read stream: %w", err)
	}

	content := cleanAIResponse(fullContent.String())
	// AppendHistory locks histMu and enforces maxHistorySize.
	o.AppendHistory(Message{
		Role:    "assistant",
		Content: content,
	})
	return content, nil
}
// SendWithTools sends the caller-supplied messages (preceded by the system
// prompt, if any) as a non-streaming request through sendWithFallback and
// returns the decoded response. The Orchestrator's own history is neither
// read nor modified. An error is returned when the request fails or the
// response contains no choices.
func (o *Orchestrator) SendWithTools(messages []Message) (*ChatResponse, error) {
	withSystem := make([]Message, 0, len(messages)+1)
	if prompt := o.systemPrompt; prompt != "" {
		withSystem = append(withSystem, Message{Role: "system", Content: prompt})
	}
	withSystem = append(withSystem, messages...)

	resp, _, err := o.sendWithFallback(ChatRequest{
		Model:    o.provider.Model,
		Messages: withSystem,
		Stream:   false,
		Tools:    o.tools,
	}, "")

	switch {
	case err != nil:
		return nil, err
	case len(resp.Choices) == 0:
		return nil, fmt.Errorf("no response from AI")
	}
	return resp, nil
}
// ChunkCallback is invoked by SendWithToolsStream for each streamed update.
// It is called with a text fragment and nil toolCalls when content arrives,
// and with an empty content string and the tool calls accumulated so far when
// tool-call data arrives.
type ChunkCallback func(content string, toolCalls []ToolCallMsg)
// SendWithToolsStream sends the caller-supplied messages (preceded by the
// system prompt, if any) as a streaming request to the current provider only;
// no fallback to other providers is attempted and the Orchestrator's history
// is neither read nor modified.
//
// onChunk receives each content fragment and, whenever tool-call data
// arrives, the tool calls accumulated so far; a nil onChunk is allowed.
//
// On success the returned response has exactly one choice whose Message holds
// the cleaned full content and the accumulated tool calls, plus the last
// non-zero total token count seen in the stream. A non-200 status yields
// "API error (<code>): <body>"; any failure returns a nil response.
func (o *Orchestrator) SendWithToolsStream(messages []Message, onChunk ChunkCallback) (*ChatResponse, error) {
	fullMessages := make([]Message, 0, len(messages)+1)
	if o.systemPrompt != "" {
		fullMessages = append(fullMessages, Message{Role: "system", Content: o.systemPrompt})
	}
	fullMessages = append(fullMessages, messages...)

	reqBody := ChatRequest{
		Model:    o.provider.Model,
		Messages: fullMessages,
		Stream:   true,
		Tools:    o.tools,
	}
	body, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("marshal request: %w", err)
	}

	provider := o.provider
	baseURL := provider.BaseURL
	if baseURL == "" {
		baseURL = getProviderBaseURL(provider.Name)
	}
	url := strings.TrimRight(baseURL, "/") + "/chat/completions"

	req, err := http.NewRequest("POST", url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+provider.APIKey)

	resp, err := o.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: a failed body read still reports the status code.
		respBody, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
	}

	var fullContent strings.Builder
	var accumulatedToolCalls []ToolCallMsg
	var totalTokens int

	// mergeDelta folds one streamed tool-call fragment into the accumulated
	// list. A fragment with a known ID extends that call. A fragment without
	// an ID is treated as a continuation of the most recent call, because
	// OpenAI-style streams send the ID only on a call's first fragment.
	// NOTE(review): ToolCallMsg has no "index" field, so fragments of
	// different calls that arrive interleaved cannot be told apart.
	mergeDelta := func(tc ToolCallMsg) {
		target := -1
		if tc.ID == "" {
			target = len(accumulatedToolCalls) - 1 // stays -1 when nothing is accumulated yet
		} else {
			for i := range accumulatedToolCalls {
				if accumulatedToolCalls[i].ID == tc.ID {
					target = i
					break
				}
			}
		}
		if target < 0 {
			accumulatedToolCalls = append(accumulatedToolCalls, tc)
			return
		}
		call := &accumulatedToolCalls[target]
		if call.Type == "" {
			call.Type = tc.Type
		}
		if call.Function.Name == "" {
			call.Function.Name = tc.Function.Name
		}
		call.Function.Arguments += tc.Function.Arguments
	}

	scanner := bufio.NewScanner(resp.Body)
	// Allow event lines of up to 1 MiB; the default scanner limit is 64 KiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := scanner.Text()
		// The space after "data:" is optional in the event-stream format.
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if data == "[DONE]" {
			break
		}

		var chatResp ChatResponse
		if err := json.Unmarshal([]byte(data), &chatResp); err != nil {
			// Skip payloads that are not valid chunk JSON.
			continue
		}

		// Read usage before looking at choices: the chunk that carries usage
		// may have an empty choices array.
		if chatResp.Usage.TotalTokens > 0 {
			totalTokens = chatResp.Usage.TotalTokens
		}
		if len(chatResp.Choices) == 0 {
			continue
		}

		delta := chatResp.Choices[0].Delta
		if delta.Content != "" {
			fullContent.WriteString(delta.Content)
			if onChunk != nil {
				onChunk(delta.Content, nil)
			}
		}
		if len(delta.ToolCalls) == 0 {
			continue
		}
		for _, tc := range delta.ToolCalls {
			mergeDelta(tc)
		}
		if onChunk != nil {
			onChunk("", accumulatedToolCalls)
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("read stream: %w", err)
	}

	// Build the final response with exactly one (zero-valued) choice so the
	// assignments below have an element to write to.
	finalResp := &ChatResponse{
		Usage: struct {
			TotalTokens int `json:"total_tokens"`
		}{TotalTokens: totalTokens},
		Choices: []struct {
			Message struct {
				Content   string        `json:"content"`
				ToolCalls []ToolCallMsg `json:"tool_calls"`
			} `json:"message"`
			Delta struct {
				Content   string        `json:"content"`
				ToolCalls []ToolCallMsg `json:"tool_calls"`
			} `json:"delta"`
			FinishReason *string `json:"finish_reason"`
		}{{}},
	}
	finalResp.Choices[0].Message.Content = cleanAIResponse(fullContent.String())
	finalResp.Choices[0].Message.ToolCalls = accumulatedToolCalls
	return finalResp, nil
}
// cleanAIResponse removes non-answer material from model output. It strips
// every span matched by thinkRegex, then drops each block of lines delimited
// by a line that is exactly "<<" or "<<<" (opening) and a line that is exactly
// ">>" or ">>>" (closing), delimiters included and ignoring surrounding
// whitespace on those lines. The result has leading and trailing whitespace
// trimmed.
func cleanAIResponse(content string) string {
	stripped := thinkRegex.ReplaceAllString(content, "")

	var kept []string
	skipping := false
	for _, raw := range strings.Split(stripped, "\n") {
		switch strings.TrimSpace(raw) {
		case "<<", "<<<":
			skipping = true
		case ">>", ">>>":
			skipping = false
		default:
			if !skipping {
				kept = append(kept, raw)
			}
		}
	}
	return strings.TrimSpace(strings.Join(kept, "\n"))
}
// getProviderBaseURL returns the built-in API base URL for the named
// provider, or the empty string when the name is not one of the known
// providers. The match is case-sensitive.
func getProviderBaseURL(name string) string {
	var endpoint string
	switch name {
	case "openai":
		endpoint = "https://api.openai.com/v1"
	case "anthropic":
		endpoint = "https://api.anthropic.com/v1"
	case "minimax":
		endpoint = "https://api.minimax.io/v1"
	case "mimo":
		endpoint = "https://token-plan-ams.xiaomimimo.com/v1"
	case "zai":
		endpoint = "https://api.z.ai/v1"
	}
	return endpoint
}
// getAvailableProviders returns pointers to every configured provider that
// has a non-empty API key, in configuration order. The pointers refer to the
// entries inside the configuration itself, not to copies. The result is nil
// when no provider qualifies.
func (o *Orchestrator) getAvailableProviders() []*config.AIProvider {
	var usable []*config.AIProvider
	ai := &o.config.AI
	for idx := range ai.Providers {
		if ai.Providers[idx].APIKey == "" {
			continue
		}
		usable = append(usable, &ai.Providers[idx])
	}
	return usable
}
// sendWithFallback POSTs reqBody to the /chat/completions endpoint of each
// usable provider in turn until one returns a valid response.
//
// Order of attempts: the current provider first, then every other provider
// that has an API key in configuration order, with "zai" always last.
// baseURLOverride, when non-empty, replaces the base URL of every provider;
// otherwise the provider's own BaseURL is used, falling back to the built-in
// URL for its name.
//
// When a fallback provider (one other than the current provider) declares a
// model, that model replaces reqBody.Model for the attempt, since another
// provider's model name is unlikely to be valid there.
//
// On success the answering provider becomes the current provider and its name
// is returned alongside the response, which is guaranteed to contain at least
// one choice. If every provider fails, the error of the last attempt is
// returned.
func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride string) (*ChatResponse, string, error) {
	providers := o.getAvailableProviders()
	if len(providers) == 0 {
		return nil, "", fmt.Errorf("no providers available")
	}

	// Build the attempt order: current provider, others, then zai.
	providerOrder := make([]*config.AIProvider, 0, len(providers)+1)
	if o.provider != nil {
		providerOrder = append(providerOrder, o.provider)
	}
	var zaiProvider *config.AIProvider
	for _, p := range providers {
		if o.provider != nil && p.Name == o.provider.Name {
			continue // already first in the list
		}
		if p.Name == "zai" {
			zaiProvider = p
			continue
		}
		providerOrder = append(providerOrder, p)
	}
	if zaiProvider != nil {
		providerOrder = append(providerOrder, zaiProvider)
	}

	primary := o.provider

	// attempt performs one request against prov. It is a closure so that the
	// deferred body close runs at the end of each attempt rather than piling
	// up until sendWithFallback returns.
	attempt := func(prov *config.AIProvider) (*ChatResponse, error) {
		baseURL := baseURLOverride
		if baseURL == "" {
			baseURL = prov.BaseURL
		}
		if baseURL == "" {
			baseURL = getProviderBaseURL(prov.Name)
		}
		if baseURL == "" {
			return nil, fmt.Errorf("no base URL configured for %s", prov.Name)
		}
		url := strings.TrimRight(baseURL, "/") + "/chat/completions"

		payload := reqBody
		if prov != primary && prov.Model != "" {
			payload.Model = prov.Model
		}
		body, err := json.Marshal(payload)
		if err != nil {
			return nil, fmt.Errorf("marshal request: %w", err)
		}

		req, err := http.NewRequest("POST", url, bytes.NewReader(body))
		if err != nil {
			return nil, fmt.Errorf("create request: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")
		// Provider-specific authentication headers.
		if prov.Name == "anthropic" {
			req.Header.Set("x-api-key", prov.APIKey)
			req.Header.Set("anthropic-version", "2023-06-01")
		} else {
			req.Header.Set("Authorization", "Bearer "+prov.APIKey)
		}

		resp, err := o.client.Do(req)
		if err != nil {
			return nil, fmt.Errorf("send request to %s: %w", prov.Name, err)
		}
		defer resp.Body.Close()

		respBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("read response: %w", err)
		}
		if resp.StatusCode != http.StatusOK {
			return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
		}

		var chatResp ChatResponse
		if err := json.Unmarshal(respBody, &chatResp); err != nil {
			return nil, fmt.Errorf("parse response: %w", err)
		}
		if len(chatResp.Choices) == 0 {
			return nil, fmt.Errorf("no response from AI")
		}
		return &chatResp, nil
	}

	var lastErr error
	triedProviders := make([]string, 0, len(providerOrder))
	for _, prov := range providerOrder {
		triedProviders = append(triedProviders, prov.Name)
		chatResp, err := attempt(prov)
		if err != nil {
			lastErr = err
			continue
		}
		// Stick with the provider that worked for subsequent requests.
		o.provider = prov
		return chatResp, prov.Name, nil
	}

	log.Printf("[orchestrator] all providers failed, tried %v", triedProviders)
	return nil, "", lastErr
}