All checks were successful
Beta Release / beta (push) Successful in 57s
Add MiMo-V2.5-Pro from Xiaomi Token Plan as a new AI provider with base URL https://token-plan-ams.xiaomimimo.com/v1. The /model change command now switches between MiniMax and MiMo only. ZAI is always placed last in the fallback chain as the provider of ultimate resort. Config panel shows MiniMax and MiMo cards. 💘 Generated with Crush Assisted-by: GLM-5.1 via Crush <crush@charm.land>
594 lines
14 KiB
Go
594 lines
14 KiB
Go
package orchestrator
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/muyue/muyue/internal/config"
|
|
)
|
|
|
|
// thinkRegex matches a complete <think>...</think> (or <Think>...</Think>)
// element, attributes and embedded newlines included. Matching is non-greedy,
// so several separate elements in one string are matched individually.
var (
	thinkRegex = regexp.MustCompile(`(?s)<[Tt]hink[^>]*>.*?</[Tt]hink>`)
)
|
|
|
|
// maxHistorySize caps how many messages the orchestrator keeps in its
// conversation history; older entries are dropped first.
const (
	maxHistorySize = 100
)
|
|
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content,omitempty"`
|
|
ToolCalls []ToolCallMsg `json:"tool_calls,omitempty"`
|
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
|
Name string `json:"name,omitempty"`
|
|
}
|
|
|
|
type ToolCallMsg struct {
|
|
ID string `json:"id"`
|
|
Type string `json:"type"`
|
|
Function ToolCallFuncMsg `json:"function"`
|
|
}
|
|
|
|
// ToolCallFuncMsg is the function part of a tool call.
type ToolCallFuncMsg struct {
	// Name is the function to invoke.
	Name string `json:"name"`
	// Arguments is kept as a plain string; streaming code concatenates
	// successive fragments into it.
	Arguments string `json:"arguments"`
}
|
|
|
|
type ChatRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
Stream bool `json:"stream"`
|
|
Tools json.RawMessage `json:"tools,omitempty"`
|
|
}
|
|
|
|
type ChatResponse struct {
|
|
Choices []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCallMsg `json:"tool_calls"`
|
|
} `json:"message"`
|
|
Delta struct {
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCallMsg `json:"tool_calls"`
|
|
} `json:"delta"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
Usage struct {
|
|
TotalTokens int `json:"total_tokens"`
|
|
} `json:"usage"`
|
|
}
|
|
|
|
type Orchestrator struct {
|
|
config *config.MuyueConfig
|
|
provider *config.AIProvider
|
|
client *http.Client
|
|
history []Message
|
|
histMu sync.Mutex
|
|
systemPrompt string
|
|
tools json.RawMessage
|
|
}
|
|
|
|
// sharedHTTPClient is the HTTP client handed to every Orchestrator created by
// New. Its timeout is two minutes (120 seconds).
var sharedHTTPClient = &http.Client{Timeout: 2 * time.Minute}
|
|
|
|
// requestClient returns a freshly allocated HTTP client whose Timeout is set
// to the given duration. All other client fields keep their zero values.
func requestClient(timeout time.Duration) *http.Client {
	client := new(http.Client)
	client.Timeout = timeout
	return client
}
|
|
|
|
func New(cfg *config.MuyueConfig) (*Orchestrator, error) {
|
|
var provider *config.AIProvider
|
|
for i := range cfg.AI.Providers {
|
|
if cfg.AI.Providers[i].Active {
|
|
provider = &cfg.AI.Providers[i]
|
|
break
|
|
}
|
|
}
|
|
|
|
if provider == nil {
|
|
return nil, fmt.Errorf("no active AI provider configured")
|
|
}
|
|
|
|
if provider.APIKey == "" {
|
|
return nil, fmt.Errorf("API key not set for %s", provider.Name)
|
|
}
|
|
|
|
return &Orchestrator{
|
|
config: cfg,
|
|
provider: provider,
|
|
client: sharedHTTPClient,
|
|
history: []Message{},
|
|
}, nil
|
|
}
|
|
|
|
func (o *Orchestrator) SetSystemPrompt(prompt string) {
|
|
o.systemPrompt = prompt
|
|
}
|
|
|
|
func (o *Orchestrator) SetTools(tools json.RawMessage) {
|
|
o.tools = tools
|
|
}
|
|
|
|
func (o *Orchestrator) ProviderName() string {
|
|
if o.provider == nil {
|
|
return ""
|
|
}
|
|
return o.provider.Name
|
|
}
|
|
|
|
func (o *Orchestrator) AppendHistory(msg Message) {
|
|
o.histMu.Lock()
|
|
defer o.histMu.Unlock()
|
|
o.history = append(o.history, msg)
|
|
if len(o.history) > maxHistorySize {
|
|
o.history = o.history[len(o.history)-maxHistorySize:]
|
|
}
|
|
}
|
|
|
|
func (o *Orchestrator) GetHistory() []Message {
|
|
o.histMu.Lock()
|
|
defer o.histMu.Unlock()
|
|
out := make([]Message, len(o.history))
|
|
copy(out, o.history)
|
|
return out
|
|
}
|
|
|
|
func (o *Orchestrator) Send(userMessage string) (string, error) {
|
|
o.histMu.Lock()
|
|
o.history = append(o.history, Message{
|
|
Role: "user",
|
|
Content: userMessage,
|
|
})
|
|
|
|
if len(o.history) > maxHistorySize {
|
|
o.history = o.history[len(o.history)-maxHistorySize:]
|
|
}
|
|
|
|
messages := make([]Message, 0, len(o.history)+1)
|
|
if o.systemPrompt != "" {
|
|
messages = append(messages, Message{Role: "system", Content: o.systemPrompt})
|
|
}
|
|
messages = append(messages, o.history...)
|
|
|
|
reqBody := ChatRequest{
|
|
Model: o.provider.Model,
|
|
Messages: messages,
|
|
Stream: false,
|
|
Tools: o.tools,
|
|
}
|
|
o.histMu.Unlock()
|
|
|
|
chatResp, providerName, err := o.sendWithFallback(reqBody, "")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
content := cleanAIResponse(chatResp.Choices[0].Message.Content)
|
|
o.histMu.Lock()
|
|
o.history = append(o.history, Message{
|
|
Role: "assistant",
|
|
Content: content,
|
|
})
|
|
_ = providerName
|
|
o.histMu.Unlock()
|
|
|
|
return content, nil
|
|
}
|
|
|
|
func (o *Orchestrator) SendStream(userMessage string, onChunk func(string)) (string, error) {
|
|
o.histMu.Lock()
|
|
o.history = append(o.history, Message{
|
|
Role: "user",
|
|
Content: userMessage,
|
|
})
|
|
|
|
if len(o.history) > maxHistorySize {
|
|
o.history = o.history[len(o.history)-maxHistorySize:]
|
|
}
|
|
|
|
messages := make([]Message, 0, len(o.history)+1)
|
|
if o.systemPrompt != "" {
|
|
messages = append(messages, Message{Role: "system", Content: o.systemPrompt})
|
|
}
|
|
messages = append(messages, o.history...)
|
|
|
|
reqBody := ChatRequest{
|
|
Model: o.provider.Model,
|
|
Messages: messages,
|
|
Stream: true,
|
|
Tools: o.tools,
|
|
}
|
|
o.histMu.Unlock()
|
|
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
|
|
baseURL := o.provider.BaseURL
|
|
if baseURL == "" {
|
|
baseURL = getProviderBaseURL(o.provider.Name)
|
|
}
|
|
|
|
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return "", fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+o.provider.APIKey)
|
|
|
|
resp, err := o.client.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var fullContent strings.Builder
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
var chatResp ChatResponse
|
|
if err := json.Unmarshal([]byte(data), &chatResp); err != nil {
|
|
continue
|
|
}
|
|
|
|
if len(chatResp.Choices) > 0 {
|
|
chunk := chatResp.Choices[0].Delta.Content
|
|
if chunk != "" {
|
|
fullContent.WriteString(chunk)
|
|
onChunk(chunk)
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return fullContent.String(), fmt.Errorf("read stream: %w", err)
|
|
}
|
|
|
|
content := cleanAIResponse(fullContent.String())
|
|
o.histMu.Lock()
|
|
o.history = append(o.history, Message{
|
|
Role: "assistant",
|
|
Content: content,
|
|
})
|
|
o.histMu.Unlock()
|
|
|
|
return content, nil
|
|
}
|
|
|
|
func (o *Orchestrator) SendWithTools(messages []Message) (*ChatResponse, error) {
|
|
fullMessages := make([]Message, 0, len(messages)+1)
|
|
if o.systemPrompt != "" {
|
|
fullMessages = append(fullMessages, Message{Role: "system", Content: o.systemPrompt})
|
|
}
|
|
fullMessages = append(fullMessages, messages...)
|
|
|
|
reqBody := ChatRequest{
|
|
Model: o.provider.Model,
|
|
Messages: fullMessages,
|
|
Stream: false,
|
|
Tools: o.tools,
|
|
}
|
|
|
|
chatResp, _, err := o.sendWithFallback(reqBody, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(chatResp.Choices) == 0 {
|
|
return nil, fmt.Errorf("no response from AI")
|
|
}
|
|
|
|
return chatResp, nil
|
|
}
|
|
|
|
// ChunkCallback is called for each streaming chunk.
|
|
type ChunkCallback func(content string, toolCalls []ToolCallMsg)
|
|
|
|
// SendWithToolsStream sends messages with streaming responses.
|
|
// The callback receives chunks of content and tool_calls as they arrive.
|
|
func (o *Orchestrator) SendWithToolsStream(messages []Message, onChunk ChunkCallback) (*ChatResponse, error) {
|
|
fullMessages := make([]Message, 0, len(messages)+1)
|
|
if o.systemPrompt != "" {
|
|
fullMessages = append(fullMessages, Message{Role: "system", Content: o.systemPrompt})
|
|
}
|
|
fullMessages = append(fullMessages, messages...)
|
|
|
|
reqBody := ChatRequest{
|
|
Model: o.provider.Model,
|
|
Messages: fullMessages,
|
|
Stream: true,
|
|
Tools: o.tools,
|
|
}
|
|
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
|
|
provider := o.provider
|
|
baseURL := provider.BaseURL
|
|
if baseURL == "" {
|
|
baseURL = getProviderBaseURL(provider.Name)
|
|
}
|
|
|
|
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
|
|
|
resp, err := o.client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var fullContent strings.Builder
|
|
var accumulatedToolCalls []ToolCallMsg
|
|
var totalTokens int
|
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
var chatResp ChatResponse
|
|
if err := json.Unmarshal([]byte(data), &chatResp); err != nil {
|
|
continue
|
|
}
|
|
|
|
if len(chatResp.Choices) > 0 {
|
|
chunk := chatResp.Choices[0].Delta.Content
|
|
if chunk != "" {
|
|
fullContent.WriteString(chunk)
|
|
onChunk(chunk, nil)
|
|
}
|
|
|
|
// Handle delta tool calls
|
|
if len(chatResp.Choices[0].Delta.ToolCalls) > 0 {
|
|
for _, tc := range chatResp.Choices[0].Delta.ToolCalls {
|
|
// Find or create the tool call in accumulated list
|
|
found := false
|
|
for i := range accumulatedToolCalls {
|
|
if accumulatedToolCalls[i].ID == tc.ID {
|
|
// Append arguments
|
|
accumulatedToolCalls[i].Function.Arguments += tc.Function.Arguments
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
accumulatedToolCalls = append(accumulatedToolCalls, tc)
|
|
}
|
|
}
|
|
onChunk("", accumulatedToolCalls)
|
|
}
|
|
|
|
// Capture usage from final chunk
|
|
if chatResp.Usage.TotalTokens > 0 {
|
|
totalTokens = chatResp.Usage.TotalTokens
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, fmt.Errorf("read stream: %w", err)
|
|
}
|
|
|
|
// Build final response
|
|
finalResp := &ChatResponse{
|
|
Usage: struct {
|
|
TotalTokens int `json:"total_tokens"`
|
|
}{TotalTokens: totalTokens},
|
|
Choices: []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCallMsg `json:"tool_calls"`
|
|
} `json:"message"`
|
|
Delta struct {
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCallMsg `json:"tool_calls"`
|
|
} `json:"delta"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
}{},
|
|
}
|
|
|
|
finalContent := cleanAIResponse(fullContent.String())
|
|
finalResp.Choices[0].Message.Content = finalContent
|
|
finalResp.Choices[0].Message.ToolCalls = accumulatedToolCalls
|
|
|
|
return finalResp, nil
|
|
}
|
|
|
|
func cleanAIResponse(content string) string {
|
|
content = thinkRegex.ReplaceAllString(content, "")
|
|
lines := strings.Split(content, "\n")
|
|
var clean []string
|
|
inBlock := false
|
|
for _, line := range lines {
|
|
trimmed := strings.TrimSpace(line)
|
|
if trimmed == "<<" || trimmed == "<<<" {
|
|
inBlock = true
|
|
continue
|
|
}
|
|
if trimmed == ">>" || trimmed == ">>>" {
|
|
inBlock = false
|
|
continue
|
|
}
|
|
if inBlock {
|
|
continue
|
|
}
|
|
clean = append(clean, line)
|
|
}
|
|
result := strings.TrimSpace(strings.Join(clean, "\n"))
|
|
return result
|
|
}
|
|
|
|
// defaultProviderBaseURLs maps a provider name to the base URL used when the
// provider's configuration does not specify one.
var defaultProviderBaseURLs = map[string]string{
	"minimax":   "https://api.minimax.io/v1",
	"anthropic": "https://api.anthropic.com/v1",
	"openai":    "https://api.openai.com/v1",
	"zai":       "https://api.z.ai/v1",
	"mimo":      "https://token-plan-ams.xiaomimimo.com/v1",
}

// getProviderBaseURL returns the built-in base URL for the named provider.
// The lookup is case-sensitive; unknown names yield the empty string.
func getProviderBaseURL(name string) string {
	// A missing key yields the zero value "", which is the fallback result.
	return defaultProviderBaseURLs[name]
}
|
|
|
|
func (o *Orchestrator) getAvailableProviders() []*config.AIProvider {
|
|
var providers []*config.AIProvider
|
|
for i := range o.config.AI.Providers {
|
|
prov := &o.config.AI.Providers[i]
|
|
if prov.APIKey != "" {
|
|
providers = append(providers, prov)
|
|
}
|
|
}
|
|
return providers
|
|
}
|
|
|
|
func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride string) (*ChatResponse, string, error) {
|
|
providers := o.getAvailableProviders()
|
|
|
|
if len(providers) == 0 {
|
|
return nil, "", fmt.Errorf("no providers available")
|
|
}
|
|
|
|
providerOrder := make([]*config.AIProvider, 0, len(providers))
|
|
if o.provider != nil {
|
|
providerOrder = append(providerOrder, o.provider)
|
|
}
|
|
var zaiProvider *config.AIProvider
|
|
for _, p := range providers {
|
|
if o.provider == nil || p.Name != o.provider.Name {
|
|
if p.Name == "zai" {
|
|
zaiProvider = p
|
|
} else {
|
|
providerOrder = append(providerOrder, p)
|
|
}
|
|
}
|
|
}
|
|
if zaiProvider != nil {
|
|
providerOrder = append(providerOrder, zaiProvider)
|
|
}
|
|
|
|
var lastErr error
|
|
var triedProviders []string
|
|
for _, prov := range providerOrder {
|
|
triedProviders = append(triedProviders, prov.Name)
|
|
baseURL := baseURLOverride
|
|
if baseURL == "" {
|
|
baseURL = prov.BaseURL
|
|
if baseURL == "" {
|
|
baseURL = getProviderBaseURL(prov.Name)
|
|
}
|
|
}
|
|
|
|
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
|
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("marshal request: %w", err)
|
|
continue
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("create request: %w", err)
|
|
continue
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
// Provider-specific headers
|
|
if prov.Name == "anthropic" {
|
|
req.Header.Set("x-api-key", prov.APIKey)
|
|
req.Header.Set("anthropic-version", "2023-06-01")
|
|
} else {
|
|
req.Header.Set("Authorization", "Bearer "+prov.APIKey)
|
|
}
|
|
|
|
resp, err := o.client.Do(req)
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("send request to %s: %w", prov.Name, err)
|
|
continue
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("read response: %w", err)
|
|
continue
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
lastErr = fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
|
|
continue
|
|
}
|
|
|
|
var chatResp ChatResponse
|
|
if err := json.Unmarshal(respBody, &chatResp); err != nil {
|
|
lastErr = fmt.Errorf("parse response: %w", err)
|
|
continue
|
|
}
|
|
|
|
if len(chatResp.Choices) == 0 {
|
|
lastErr = fmt.Errorf("no response from AI")
|
|
continue
|
|
}
|
|
|
|
o.provider = prov
|
|
return &chatResp, prov.Name, nil
|
|
}
|
|
|
|
log.Printf("[orchestrator] fallback from %v to next provider", triedProviders)
|
|
return nil, "", lastErr
|
|
}
|