All checks were successful
Beta Release / beta (push) Successful in 1m1s
- Fix token count reset on app restart: persist realTokens in conversation.json - Fix token/context window values: Studio 150K (summarize at 120K), Terminal 100K - Fix table rendering in terminal tab: correct thead/tbody display model - Fix copy button always top-right in Studio code blocks - Add markdown horizontal rule (---) support in Studio and Terminal - Fix bullet list double dot: remove CSS ::before duplicate bullet point - Add image attachments support (VLM description, file mentions @file.ext) - Add sudo detection with cache (sync.Once) - Fix message content serialization (TextContent wrapper) - Guide AI to use read_file instead of cat in studio prompt 💘 Generated with Crush Assisted-by: GLM-5.1 via Crush <crush@charm.land>
622 lines
15 KiB
Go
622 lines
15 KiB
Go
package orchestrator
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/muyue/muyue/internal/config"
|
|
)
|
|
|
|
// thinkRegex matches a complete <think ...>...</think> element (the tag may
// start with an upper- or lower-case "t"). The (?s) flag lets the body span
// multiple lines and the match is non-greedy, so each element is matched
// separately. cleanAIResponse uses it to strip these elements from replies.
var thinkRegex = regexp.MustCompile(`(?s)<[Tt]hink[^>]*>.*?</[Tt]hink>`)
|
|
|
|
// maxHistorySize is the number of most recent messages kept in an
// Orchestrator's history; older entries are dropped when it is exceeded.
const maxHistorySize = 100
|
|
|
|
type ContentPart struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text,omitempty"`
|
|
ImageURL *ImageURL `json:"image_url,omitempty"`
|
|
}
|
|
|
|
// ImageURL wraps the location of an image referenced from a ContentPart.
type ImageURL struct {
	// URL is serialized under the "url" key.
	URL string `json:"url"`
}
|
|
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content json.RawMessage `json:"content,omitempty"`
|
|
ToolCalls []ToolCallMsg `json:"tool_calls,omitempty"`
|
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
|
Name string `json:"name,omitempty"`
|
|
}
|
|
|
|
// TextContent encodes s as a JSON string suitable for Message.Content.
func TextContent(s string) json.RawMessage {
	encoded, err := json.Marshal(s)
	if err != nil {
		// json.Marshal yields a nil slice alongside an error; pass that
		// through unchanged.
		return encoded
	}
	return json.RawMessage(encoded)
}
|
|
|
|
func PartsContent(parts []ContentPart) json.RawMessage {
|
|
b, _ := json.Marshal(parts)
|
|
return b
|
|
}
|
|
|
|
func (m Message) ContentString() string {
|
|
var s string
|
|
if json.Unmarshal(m.Content, &s) == nil {
|
|
return s
|
|
}
|
|
return string(m.Content)
|
|
}
|
|
|
|
type ToolCallMsg struct {
|
|
ID string `json:"id"`
|
|
Type string `json:"type"`
|
|
Function ToolCallFuncMsg `json:"function"`
|
|
}
|
|
|
|
// ToolCallFuncMsg is the function part of a tool call.
type ToolCallFuncMsg struct {
	// Name is the function to invoke.
	Name string `json:"name"`

	// Arguments is kept as a plain string; when streaming, fragments are
	// concatenated to build the full value.
	Arguments string `json:"arguments"`
}
|
|
|
|
type ChatRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
Stream bool `json:"stream"`
|
|
Tools json.RawMessage `json:"tools,omitempty"`
|
|
}
|
|
|
|
type ChatResponse struct {
|
|
Choices []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCallMsg `json:"tool_calls"`
|
|
} `json:"message"`
|
|
Delta struct {
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCallMsg `json:"tool_calls"`
|
|
} `json:"delta"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
Usage struct {
|
|
TotalTokens int `json:"total_tokens"`
|
|
} `json:"usage"`
|
|
}
|
|
|
|
type Orchestrator struct {
|
|
config *config.MuyueConfig
|
|
provider *config.AIProvider
|
|
client *http.Client
|
|
history []Message
|
|
histMu sync.Mutex
|
|
systemPrompt string
|
|
tools json.RawMessage
|
|
}
|
|
|
|
// sharedHTTPClient is the HTTP client handed to every Orchestrator created by
// New. Its two-minute timeout bounds each whole exchange.
var sharedHTTPClient = &http.Client{Timeout: 2 * time.Minute}
|
|
|
|
// requestClient returns a fresh HTTP client whose Timeout is set to timeout.
func requestClient(timeout time.Duration) *http.Client {
	c := new(http.Client)
	c.Timeout = timeout
	return c
}
|
|
|
|
func New(cfg *config.MuyueConfig) (*Orchestrator, error) {
|
|
var provider *config.AIProvider
|
|
for i := range cfg.AI.Providers {
|
|
if cfg.AI.Providers[i].Active {
|
|
provider = &cfg.AI.Providers[i]
|
|
break
|
|
}
|
|
}
|
|
|
|
if provider == nil {
|
|
return nil, fmt.Errorf("no active AI provider configured")
|
|
}
|
|
|
|
if provider.APIKey == "" {
|
|
return nil, fmt.Errorf("API key not set for %s", provider.Name)
|
|
}
|
|
|
|
return &Orchestrator{
|
|
config: cfg,
|
|
provider: provider,
|
|
client: sharedHTTPClient,
|
|
history: []Message{},
|
|
}, nil
|
|
}
|
|
|
|
func (o *Orchestrator) SetSystemPrompt(prompt string) {
|
|
o.systemPrompt = prompt
|
|
}
|
|
|
|
func (o *Orchestrator) SetTools(tools json.RawMessage) {
|
|
o.tools = tools
|
|
}
|
|
|
|
func (o *Orchestrator) ProviderName() string {
|
|
if o.provider == nil {
|
|
return ""
|
|
}
|
|
return o.provider.Name
|
|
}
|
|
|
|
func (o *Orchestrator) AppendHistory(msg Message) {
|
|
o.histMu.Lock()
|
|
defer o.histMu.Unlock()
|
|
o.history = append(o.history, msg)
|
|
if len(o.history) > maxHistorySize {
|
|
o.history = o.history[len(o.history)-maxHistorySize:]
|
|
}
|
|
}
|
|
|
|
func (o *Orchestrator) GetHistory() []Message {
|
|
o.histMu.Lock()
|
|
defer o.histMu.Unlock()
|
|
out := make([]Message, len(o.history))
|
|
copy(out, o.history)
|
|
return out
|
|
}
|
|
|
|
func (o *Orchestrator) Send(userMessage string) (string, error) {
|
|
o.histMu.Lock()
|
|
o.history = append(o.history, Message{
|
|
Role: "user",
|
|
Content: TextContent(userMessage),
|
|
})
|
|
|
|
if len(o.history) > maxHistorySize {
|
|
o.history = o.history[len(o.history)-maxHistorySize:]
|
|
}
|
|
|
|
messages := make([]Message, 0, len(o.history)+1)
|
|
if o.systemPrompt != "" {
|
|
messages = append(messages, Message{Role: "system", Content: TextContent(o.systemPrompt)})
|
|
}
|
|
messages = append(messages, o.history...)
|
|
|
|
reqBody := ChatRequest{
|
|
Model: o.provider.Model,
|
|
Messages: messages,
|
|
Stream: false,
|
|
Tools: o.tools,
|
|
}
|
|
o.histMu.Unlock()
|
|
|
|
chatResp, providerName, err := o.sendWithFallback(reqBody, "")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
content := cleanAIResponse(chatResp.Choices[0].Message.Content)
|
|
o.histMu.Lock()
|
|
o.history = append(o.history, Message{
|
|
Role: "assistant",
|
|
Content: TextContent(content),
|
|
})
|
|
_ = providerName
|
|
o.histMu.Unlock()
|
|
|
|
return content, nil
|
|
}
|
|
|
|
func (o *Orchestrator) SendStream(userMessage string, onChunk func(string)) (string, error) {
|
|
o.histMu.Lock()
|
|
o.history = append(o.history, Message{
|
|
Role: "user",
|
|
Content: TextContent(userMessage),
|
|
})
|
|
|
|
if len(o.history) > maxHistorySize {
|
|
o.history = o.history[len(o.history)-maxHistorySize:]
|
|
}
|
|
|
|
messages := make([]Message, 0, len(o.history)+1)
|
|
if o.systemPrompt != "" {
|
|
messages = append(messages, Message{Role: "system", Content: TextContent(o.systemPrompt)})
|
|
}
|
|
messages = append(messages, o.history...)
|
|
|
|
reqBody := ChatRequest{
|
|
Model: o.provider.Model,
|
|
Messages: messages,
|
|
Stream: true,
|
|
Tools: o.tools,
|
|
}
|
|
o.histMu.Unlock()
|
|
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
|
|
baseURL := o.provider.BaseURL
|
|
if baseURL == "" {
|
|
baseURL = getProviderBaseURL(o.provider.Name)
|
|
}
|
|
|
|
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return "", fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+o.provider.APIKey)
|
|
|
|
resp, err := o.client.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var fullContent strings.Builder
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
var chatResp ChatResponse
|
|
if err := json.Unmarshal([]byte(data), &chatResp); err != nil {
|
|
continue
|
|
}
|
|
|
|
if len(chatResp.Choices) > 0 {
|
|
chunk := chatResp.Choices[0].Delta.Content
|
|
if chunk != "" {
|
|
fullContent.WriteString(chunk)
|
|
onChunk(chunk)
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return fullContent.String(), fmt.Errorf("read stream: %w", err)
|
|
}
|
|
|
|
content := cleanAIResponse(fullContent.String())
|
|
o.histMu.Lock()
|
|
o.history = append(o.history, Message{
|
|
Role: "assistant",
|
|
Content: TextContent(content),
|
|
})
|
|
o.histMu.Unlock()
|
|
|
|
return content, nil
|
|
}
|
|
|
|
func (o *Orchestrator) SendWithTools(messages []Message) (*ChatResponse, error) {
|
|
fullMessages := make([]Message, 0, len(messages)+1)
|
|
if o.systemPrompt != "" {
|
|
fullMessages = append(fullMessages, Message{Role: "system", Content: TextContent(o.systemPrompt)})
|
|
}
|
|
fullMessages = append(fullMessages, messages...)
|
|
|
|
reqBody := ChatRequest{
|
|
Model: o.provider.Model,
|
|
Messages: fullMessages,
|
|
Stream: false,
|
|
Tools: o.tools,
|
|
}
|
|
|
|
chatResp, _, err := o.sendWithFallback(reqBody, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(chatResp.Choices) == 0 {
|
|
return nil, fmt.Errorf("no response from AI")
|
|
}
|
|
|
|
return chatResp, nil
|
|
}
|
|
|
|
// ChunkCallback is called for each streaming chunk.
|
|
type ChunkCallback func(content string, toolCalls []ToolCallMsg)
|
|
|
|
// SendWithToolsStream sends messages with streaming responses.
|
|
// The callback receives chunks of content and tool_calls as they arrive.
|
|
func (o *Orchestrator) SendWithToolsStream(messages []Message, onChunk ChunkCallback) (*ChatResponse, error) {
|
|
fullMessages := make([]Message, 0, len(messages)+1)
|
|
if o.systemPrompt != "" {
|
|
fullMessages = append(fullMessages, Message{Role: "system", Content: TextContent(o.systemPrompt)})
|
|
}
|
|
fullMessages = append(fullMessages, messages...)
|
|
|
|
reqBody := ChatRequest{
|
|
Model: o.provider.Model,
|
|
Messages: fullMessages,
|
|
Stream: true,
|
|
Tools: o.tools,
|
|
}
|
|
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
|
|
provider := o.provider
|
|
baseURL := provider.BaseURL
|
|
if baseURL == "" {
|
|
baseURL = getProviderBaseURL(provider.Name)
|
|
}
|
|
|
|
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
|
|
|
resp, err := o.client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var fullContent strings.Builder
|
|
var accumulatedToolCalls []ToolCallMsg
|
|
var totalTokens int
|
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
var chatResp ChatResponse
|
|
if err := json.Unmarshal([]byte(data), &chatResp); err != nil {
|
|
continue
|
|
}
|
|
|
|
if len(chatResp.Choices) > 0 {
|
|
chunk := chatResp.Choices[0].Delta.Content
|
|
if chunk != "" {
|
|
fullContent.WriteString(chunk)
|
|
onChunk(chunk, nil)
|
|
}
|
|
|
|
// Handle delta tool calls
|
|
if len(chatResp.Choices[0].Delta.ToolCalls) > 0 {
|
|
for _, tc := range chatResp.Choices[0].Delta.ToolCalls {
|
|
// Find or create the tool call in accumulated list
|
|
found := false
|
|
for i := range accumulatedToolCalls {
|
|
if accumulatedToolCalls[i].ID == tc.ID {
|
|
// Append arguments
|
|
accumulatedToolCalls[i].Function.Arguments += tc.Function.Arguments
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
accumulatedToolCalls = append(accumulatedToolCalls, tc)
|
|
}
|
|
}
|
|
onChunk("", accumulatedToolCalls)
|
|
}
|
|
|
|
// Capture usage from final chunk
|
|
if chatResp.Usage.TotalTokens > 0 {
|
|
totalTokens = chatResp.Usage.TotalTokens
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return nil, fmt.Errorf("read stream: %w", err)
|
|
}
|
|
|
|
// Build final response
|
|
finalResp := &ChatResponse{
|
|
Usage: struct {
|
|
TotalTokens int `json:"total_tokens"`
|
|
}{TotalTokens: totalTokens},
|
|
Choices: []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCallMsg `json:"tool_calls"`
|
|
} `json:"message"`
|
|
Delta struct {
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCallMsg `json:"tool_calls"`
|
|
} `json:"delta"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
}{},
|
|
}
|
|
|
|
finalContent := cleanAIResponse(fullContent.String())
|
|
finalResp.Choices[0].Message.Content = finalContent
|
|
finalResp.Choices[0].Message.ToolCalls = accumulatedToolCalls
|
|
|
|
return finalResp, nil
|
|
}
|
|
|
|
func cleanAIResponse(content string) string {
|
|
content = thinkRegex.ReplaceAllString(content, "")
|
|
lines := strings.Split(content, "\n")
|
|
var clean []string
|
|
inBlock := false
|
|
for _, line := range lines {
|
|
trimmed := strings.TrimSpace(line)
|
|
if trimmed == "<<" || trimmed == "<<<" {
|
|
inBlock = true
|
|
continue
|
|
}
|
|
if trimmed == ">>" || trimmed == ">>>" {
|
|
inBlock = false
|
|
continue
|
|
}
|
|
if inBlock {
|
|
continue
|
|
}
|
|
clean = append(clean, line)
|
|
}
|
|
result := strings.TrimSpace(strings.Join(clean, "\n"))
|
|
return result
|
|
}
|
|
|
|
// getProviderBaseURL returns the built-in API base URL for the provider with
// the given name, or the empty string when the name is not known. The lookup
// is case-sensitive.
func getProviderBaseURL(name string) string {
	defaults := map[string]string{
		"minimax":   "https://api.minimax.io/v1",
		"anthropic": "https://api.anthropic.com/v1",
		"openai":    "https://api.openai.com/v1",
		"zai":       "https://api.z.ai/v1",
		"mimo":      "https://token-plan-ams.xiaomimimo.com/v1",
	}
	// A missing key yields the zero value, i.e. "".
	return defaults[name]
}
|
|
|
|
func (o *Orchestrator) getAvailableProviders() []*config.AIProvider {
|
|
var providers []*config.AIProvider
|
|
for i := range o.config.AI.Providers {
|
|
prov := &o.config.AI.Providers[i]
|
|
if prov.APIKey != "" {
|
|
providers = append(providers, prov)
|
|
}
|
|
}
|
|
return providers
|
|
}
|
|
|
|
func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride string) (*ChatResponse, string, error) {
|
|
providers := o.getAvailableProviders()
|
|
|
|
if len(providers) == 0 {
|
|
return nil, "", fmt.Errorf("no providers available")
|
|
}
|
|
|
|
providerOrder := make([]*config.AIProvider, 0, len(providers))
|
|
if o.provider != nil {
|
|
providerOrder = append(providerOrder, o.provider)
|
|
}
|
|
var zaiProvider *config.AIProvider
|
|
for _, p := range providers {
|
|
if o.provider == nil || p.Name != o.provider.Name {
|
|
if p.Name == "zai" {
|
|
zaiProvider = p
|
|
} else {
|
|
providerOrder = append(providerOrder, p)
|
|
}
|
|
}
|
|
}
|
|
if zaiProvider != nil {
|
|
providerOrder = append(providerOrder, zaiProvider)
|
|
}
|
|
|
|
var lastErr error
|
|
var triedProviders []string
|
|
for _, prov := range providerOrder {
|
|
triedProviders = append(triedProviders, prov.Name)
|
|
baseURL := baseURLOverride
|
|
if baseURL == "" {
|
|
baseURL = prov.BaseURL
|
|
if baseURL == "" {
|
|
baseURL = getProviderBaseURL(prov.Name)
|
|
}
|
|
}
|
|
|
|
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
|
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("marshal request: %w", err)
|
|
continue
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("create request: %w", err)
|
|
continue
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
// Provider-specific headers
|
|
if prov.Name == "anthropic" {
|
|
req.Header.Set("x-api-key", prov.APIKey)
|
|
req.Header.Set("anthropic-version", "2023-06-01")
|
|
} else {
|
|
req.Header.Set("Authorization", "Bearer "+prov.APIKey)
|
|
}
|
|
|
|
resp, err := o.client.Do(req)
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("send request to %s: %w", prov.Name, err)
|
|
continue
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("read response: %w", err)
|
|
continue
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
lastErr = fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
|
|
continue
|
|
}
|
|
|
|
var chatResp ChatResponse
|
|
if err := json.Unmarshal(respBody, &chatResp); err != nil {
|
|
lastErr = fmt.Errorf("parse response: %w", err)
|
|
continue
|
|
}
|
|
|
|
if len(chatResp.Choices) == 0 {
|
|
lastErr = fmt.Errorf("no response from AI")
|
|
continue
|
|
}
|
|
|
|
o.provider = prov
|
|
return &chatResp, prov.Name, nil
|
|
}
|
|
|
|
log.Printf("[orchestrator] fallback from %v to next provider", triedProviders)
|
|
return nil, "", lastErr
|
|
}
|