Files
MuyueWorkspace/internal/orchestrator/orchestrator.go
Augustin 61da8039bc feat(agent): refactor AI chat with streaming, agent registry, and tool execution
- Replace old tool-call regex with proper agent registry
- Add streaming chat via SSE (handleStreamChat / handleNonStreamChat)
- Add internal/agent package with tool definitions and execution
- Add orchestrator with system prompt and tool scaffolding
- Add internal/agent/ directory
- Studio.jsx: streaming chat with thinking indicator and tool result rendering
- global.css: chat bubble styles, streaming animation, thinking dots
- handlers_chat.go: full rewrite using new agent/orchestrator architecture

💘 Generated with Crush

Assisted-by: MiniMax-M2.7 via Crush <crush@charm.land>
2026-04-23 19:47:00 +02:00

414 lines
9.5 KiB
Go

package orchestrator
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/muyue/muyue/internal/config"
)
// maxHistorySize is the maximum number of messages retained in an
// Orchestrator's conversation history; the oldest entries are dropped first.
const maxHistorySize = 100

// thinkRegex matches a <think>…</think> (or <Think>…</Think>) section,
// tolerating attributes on the opening tag and newlines inside the section,
// so cleanAIResponse can remove it from replies.
var thinkRegex = regexp.MustCompile(`(?s)<[Tt]hink[^>]*>.*?</[Tt]hink>`)
// Wire types exchanged with the provider's /chat/completions endpoint.
type (
	// Message is one conversation entry. Role is e.g. "system", "user" or
	// "assistant"; the tool-related fields are omitted from JSON when empty.
	Message struct {
		Role       string        `json:"role"`
		Content    string        `json:"content,omitempty"`
		ToolCalls  []ToolCallMsg `json:"tool_calls,omitempty"`
		ToolCallID string        `json:"tool_call_id,omitempty"`
		Name       string        `json:"name,omitempty"`
	}

	// ToolCallMsg is a single tool invocation requested by the model.
	ToolCallMsg struct {
		ID       string          `json:"id"`
		Type     string          `json:"type"`
		Function ToolCallFuncMsg `json:"function"`
	}

	// ToolCallFuncMsg names the function to call; Arguments is kept as the
	// raw string the API sends.
	ToolCallFuncMsg struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	}

	// ChatRequest is the request body. Tools is passed through verbatim and
	// omitted when empty.
	ChatRequest struct {
		Model    string          `json:"model"`
		Messages []Message       `json:"messages"`
		Stream   bool            `json:"stream"`
		Tools    json.RawMessage `json:"tools,omitempty"`
	}

	// ChatResponse decodes both full responses (read via Choices[i].Message)
	// and streamed chunks (read via Choices[i].Delta). FinishReason stays nil
	// when the field is absent or null.
	ChatResponse struct {
		Choices []struct {
			Message struct {
				Content   string        `json:"content"`
				ToolCalls []ToolCallMsg `json:"tool_calls"`
			} `json:"message"`
			Delta struct {
				Content   string        `json:"content"`
				ToolCalls []ToolCallMsg `json:"tool_calls"`
			} `json:"delta"`
			FinishReason *string `json:"finish_reason"`
		} `json:"choices"`
		Usage struct {
			TotalTokens int `json:"total_tokens"`
		} `json:"usage"`
	}
)
// Orchestrator talks to the active AI provider and keeps the conversation
// history that is sent along with each chat request.
type Orchestrator struct {
// config is the configuration passed to New.
config *config.MuyueConfig
// provider is the first entry of config.AI.Providers marked Active (chosen by New).
provider *config.AIProvider
// client performs the HTTP calls; New sets it to sharedHTTPClient.
client *http.Client
// history holds the conversation turns; the methods in this file access it
// under histMu and trim it to maxHistorySize.
history []Message
// histMu guards history. Send/SendStream also read systemPrompt and tools
// while holding it, but SetSystemPrompt/SetTools do not lock.
histMu sync.Mutex
// systemPrompt, when non-empty, is prepended as a "system" message.
systemPrompt string
// tools is forwarded verbatim as ChatRequest.Tools.
tools json.RawMessage
}
// sharedHTTPClient is the single HTTP client handed to every Orchestrator
// created by New. Its Timeout bounds the whole exchange, including reading
// the response body.
var sharedHTTPClient = &http.Client{Timeout: 2 * time.Minute}
// New builds an Orchestrator bound to the first provider in cfg.AI.Providers
// that is marked Active.
//
// It fails when no provider is active or when the selected provider has no
// API key. The returned Orchestrator uses sharedHTTPClient and starts with an
// empty (non-nil) history.
func New(cfg *config.MuyueConfig) (*Orchestrator, error) {
	var active *config.AIProvider
	for i := 0; i < len(cfg.AI.Providers) && active == nil; i++ {
		if cfg.AI.Providers[i].Active {
			// Point into the config slice rather than copying the provider.
			active = &cfg.AI.Providers[i]
		}
	}

	switch {
	case active == nil:
		return nil, fmt.Errorf("no active AI provider configured")
	case active.APIKey == "":
		return nil, fmt.Errorf("API key not set for %s", active.Name)
	}

	o := &Orchestrator{
		config:   cfg,
		provider: active,
		client:   sharedHTTPClient,
		history:  []Message{},
	}
	return o, nil
}
// SetSystemPrompt sets the prompt prepended as a "system" message to every
// request; an empty prompt disables it. The write is not guarded by histMu,
// so call it before the Orchestrator is used concurrently.
func (o *Orchestrator) SetSystemPrompt(prompt string) {
	o.systemPrompt = prompt
}

// SetTools sets the raw JSON tool definitions forwarded as ChatRequest.Tools.
// Like SetSystemPrompt, it does not take histMu.
func (o *Orchestrator) SetTools(tools json.RawMessage) {
	o.tools = tools
}

// ProviderName reports the name of the active provider, or "" when none is set.
func (o *Orchestrator) ProviderName() string {
	if p := o.provider; p != nil {
		return p.Name
	}
	return ""
}
// AppendHistory records msg in the conversation history. When the history
// grows beyond maxHistorySize, the oldest entries are discarded.
func (o *Orchestrator) AppendHistory(msg Message) {
	o.histMu.Lock()
	defer o.histMu.Unlock()

	grown := append(o.history, msg)
	if excess := len(grown) - maxHistorySize; excess > 0 {
		grown = grown[excess:]
	}
	o.history = grown
}

// GetHistory returns a copy of the conversation history. The result is never
// nil, and mutating it does not affect the Orchestrator.
func (o *Orchestrator) GetHistory() []Message {
	o.histMu.Lock()
	defer o.histMu.Unlock()

	return append(make([]Message, 0, len(o.history)), o.history...)
}
// Send appends userMessage to the conversation history and performs a
// non-streaming chat-completions call with the system prompt plus the
// (trimmed) history. The cleaned assistant reply (see cleanAIResponse) is
// recorded in the history and returned.
//
// On error the user message stays in the history and no reply is recorded.
// Only the reply's text content is used; tool calls in the response are not
// surfaced here (see SendWithTools).
func (o *Orchestrator) Send(userMessage string) (string, error) {
	chatResp, err := o.postChatCompletion(o.beginChatTurn(userMessage, false))
	if err != nil {
		return "", err
	}
	content := cleanAIResponse(chatResp.Choices[0].Message.Content)
	o.recordAssistantReply(content)
	return content, nil
}

// SendStream is the streaming variant of Send. It requests an SSE stream and
// calls onChunk (which may be nil) with each non-empty content delta as it
// arrives. Once the stream ends, the cleaned full reply is recorded in the
// history and returned.
//
// If reading the stream fails, it returns the raw (uncleaned) text received
// so far together with the error, and records nothing in the history.
//
// http.Client.Timeout covers the entire exchange, including reading the
// body. Applying it here would abort any generation that streams for longer
// than the timeout. Instead, the request uses a copy of the client without an
// overall deadline and is cancelled only when the server stays silent for
// that long.
func (o *Orchestrator) SendStream(userMessage string, onChunk func(string)) (string, error) {
	reqBody := o.beginChatTurn(userMessage, true)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	idleTimeout := o.client.Timeout
	var idle *time.Timer
	if idleTimeout > 0 {
		// Cancels the request if neither the headers nor the next line
		// arrive within idleTimeout; reset after every line received.
		idle = time.AfterFunc(idleTimeout, cancel)
		defer idle.Stop()
	}

	req, err := o.newChatHTTPRequest(ctx, reqBody)
	if err != nil {
		return "", err
	}
	streamClient := *o.client
	streamClient.Timeout = 0
	resp, err := streamClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("send request: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Best effort: the status code is still reported if the body is unreadable.
		respBody, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
	}

	var fullContent strings.Builder
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		if idle != nil {
			idle.Reset(idleTimeout)
		}
		line := scanner.Text()
		// SSE permits both "data:value" and "data: value"; one leading
		// space after the colon is not part of the value.
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
		if data == "[DONE]" {
			break
		}
		var chunkResp ChatResponse
		if err := json.Unmarshal([]byte(data), &chunkResp); err != nil {
			// Skip payloads that are not chat-completion chunks.
			continue
		}
		if len(chunkResp.Choices) == 0 {
			continue
		}
		if chunk := chunkResp.Choices[0].Delta.Content; chunk != "" {
			fullContent.WriteString(chunk)
			if onChunk != nil {
				onChunk(chunk)
			}
		}
	}
	if err := scanner.Err(); err != nil {
		return fullContent.String(), fmt.Errorf("read stream: %w", err)
	}

	content := cleanAIResponse(fullContent.String())
	o.recordAssistantReply(content)
	return content, nil
}

// SendWithTools performs one non-streaming call with messages (prefixed by
// the system prompt, if any) and returns the raw response. The caller can
// then inspect tool calls. It neither reads nor modifies the stored history.
func (o *Orchestrator) SendWithTools(messages []Message) (*ChatResponse, error) {
	return o.postChatCompletion(o.buildChatRequest(messages, false))
}

// beginChatTurn appends userMessage to the history, trims it, and builds a
// request from the result while holding histMu. The request carries its own
// copy of the message slice.
func (o *Orchestrator) beginChatTurn(userMessage string, stream bool) ChatRequest {
	o.histMu.Lock()
	defer o.histMu.Unlock()
	o.history = append(o.history, Message{Role: "user", Content: userMessage})
	o.trimHistoryLocked()
	return o.buildChatRequest(o.history, stream)
}

// recordAssistantReply appends the assistant's reply and trims the history so
// it stays within maxHistorySize.
func (o *Orchestrator) recordAssistantReply(content string) {
	o.histMu.Lock()
	defer o.histMu.Unlock()
	o.history = append(o.history, Message{Role: "assistant", Content: content})
	o.trimHistoryLocked()
}

// trimHistoryLocked drops the oldest entries beyond maxHistorySize.
// The caller must hold histMu.
func (o *Orchestrator) trimHistoryLocked() {
	if excess := len(o.history) - maxHistorySize; excess > 0 {
		o.history = o.history[excess:]
	}
}

// buildChatRequest copies messages into a new slice, prepends the system
// prompt when set, and wraps them with the active model and tool definitions.
func (o *Orchestrator) buildChatRequest(messages []Message, stream bool) ChatRequest {
	full := make([]Message, 0, len(messages)+1)
	if o.systemPrompt != "" {
		full = append(full, Message{Role: "system", Content: o.systemPrompt})
	}
	full = append(full, messages...)
	return ChatRequest{
		Model:    o.provider.Model,
		Messages: full,
		Stream:   stream,
		Tools:    o.tools,
	}
}

// chatCompletionsURL resolves the provider's chat-completions endpoint. It
// falls back to getProviderBaseURL when the provider sets no BaseURL.
func (o *Orchestrator) chatCompletionsURL() (string, error) {
	baseURL := o.provider.BaseURL
	if baseURL == "" {
		baseURL = getProviderBaseURL(o.provider.Name)
	}
	if baseURL == "" {
		// Otherwise the request would target the relative URL
		// "/chat/completions" and fail inside the HTTP client with an
		// unhelpful error.
		return "", fmt.Errorf("no base URL configured for provider %q", o.provider.Name)
	}
	return strings.TrimRight(baseURL, "/") + "/chat/completions", nil
}

// newChatHTTPRequest marshals reqBody and builds the authenticated POST to the
// chat-completions endpoint, bound to ctx.
func (o *Orchestrator) newChatHTTPRequest(ctx context.Context, reqBody ChatRequest) (*http.Request, error) {
	body, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("marshal request: %w", err)
	}
	url, err := o.chatCompletionsURL()
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+o.provider.APIKey)
	return req, nil
}

// postChatCompletion performs a non-streaming call with o.client and decodes
// the response. It fails on a non-200 status or a response without choices.
func (o *Orchestrator) postChatCompletion(reqBody ChatRequest) (*ChatResponse, error) {
	req, err := o.newChatHTTPRequest(context.Background(), reqBody)
	if err != nil {
		return nil, err
	}
	resp, err := o.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send request: %w", err)
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
	}
	var chatResp ChatResponse
	if err := json.Unmarshal(respBody, &chatResp); err != nil {
		return nil, fmt.Errorf("parse response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return nil, fmt.Errorf("no response from AI")
	}
	return &chatResp, nil
}
// cleanAIResponse prepares a model reply for display and storage.
//
// It removes every section matched by thinkRegex. It also drops lines enclosed
// between a line consisting only of "<<" or "<<<" and a line consisting only
// of ">>" or ">>>"; the marker lines themselves are always dropped. An opening
// marker with no closing marker hides everything after it. The result is
// trimmed of surrounding whitespace.
func cleanAIResponse(content string) string {
	stripped := thinkRegex.ReplaceAllString(content, "")

	kept := make([]string, 0)
	skipping := false
	for _, raw := range strings.Split(stripped, "\n") {
		switch strings.TrimSpace(raw) {
		case "<<", "<<<":
			skipping = true
		case ">>", ">>>":
			skipping = false
		default:
			if !skipping {
				kept = append(kept, raw)
			}
		}
	}
	return strings.TrimSpace(strings.Join(kept, "\n"))
}
// knownProviderBaseURLs maps a provider name to its default API base URL. It
// is used when the provider configuration leaves BaseURL empty.
// NOTE(review): verify these endpoints against each provider's current docs.
var knownProviderBaseURLs = map[string]string{
	"minimax":   "https://api.minimax.io/v1",
	"anthropic": "https://api.anthropic.com/v1",
	"openai":    "https://api.openai.com/v1",
	"zai":       "https://api.z.ai/v1",
}

// getProviderBaseURL returns the default base URL for a known provider name
// (matched exactly, case-sensitively), or "" when the name is not recognised.
func getProviderBaseURL(name string) string {
	return knownProviderBaseURLs[name]
}