refactor(chat): deduplicate streaming code, add multi-conv, and XSS protection
All checks were successful
Beta Release / beta (push) Successful in 2m23s

- Add ChatEngine for deduplicated chat logic (handlers_chat/shell_chat)
- Add SendWithToolsStream for real-time streaming responses
- Add /help, /plan, /export, /model commands in Studio
- Fix XSS: sanitize HTML after markdown rendering
- Add ConversationStoreMulti for multi-conversation support
- Add Anthropic headers (x-api-key, anthropic-version)
- Add fallback logging when provider switch occurs
- Add API handler tests (handlers_test.go)
- Polish Studio: max-height 200px, word-break on tool args
- Update CLI version to show full info (version, go, platform)

🤖 Generated with Crush

Assisted-by: MiniMax-M2.5 via Crush <crush@charm.land>
This commit is contained in:
Augustin
2026-04-22 22:58:05 +02:00
parent 65804aae4e
commit 3948a4c656
12 changed files with 1024 additions and 312 deletions

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"regexp"
"strings"
@@ -76,6 +77,11 @@ var sharedHTTPClient = &http.Client{
Timeout: 120 * time.Second,
}
// requestClient creates an HTTP client with the specified timeout.
func requestClient(timeout time.Duration) *http.Client {
return &http.Client{Timeout: timeout}
}
func New(cfg *config.MuyueConfig) (*Orchestrator, error) {
var provider *config.AIProvider
for i := range cfg.AI.Providers {
@@ -300,6 +306,142 @@ func (o *Orchestrator) SendWithTools(messages []Message) (*ChatResponse, error)
return chatResp, nil
}
// ChunkCallback is called for each streaming chunk.
type ChunkCallback func(content string, toolCalls []ToolCallMsg)
// SendWithToolsStream sends messages with streaming responses.
// The callback receives chunks of content and tool_calls as they arrive.
func (o *Orchestrator) SendWithToolsStream(messages []Message, onChunk ChunkCallback) (*ChatResponse, error) {
fullMessages := make([]Message, 0, len(messages)+1)
if o.systemPrompt != "" {
fullMessages = append(fullMessages, Message{Role: "system", Content: o.systemPrompt})
}
fullMessages = append(fullMessages, messages...)
reqBody := ChatRequest{
Model: o.provider.Model,
Messages: fullMessages,
Stream: true,
Tools: o.tools,
}
body, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
provider := o.provider
baseURL := provider.BaseURL
if baseURL == "" {
baseURL = getProviderBaseURL(provider.Name)
}
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
resp, err := o.client.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
}
var fullContent strings.Builder
var accumulatedToolCalls []ToolCallMsg
var totalTokens int
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
break
}
var chatResp ChatResponse
if err := json.Unmarshal([]byte(data), &chatResp); err != nil {
continue
}
if len(chatResp.Choices) > 0 {
chunk := chatResp.Choices[0].Delta.Content
if chunk != "" {
fullContent.WriteString(chunk)
onChunk(chunk, nil)
}
// Handle delta tool calls
if len(chatResp.Choices[0].Delta.ToolCalls) > 0 {
for _, tc := range chatResp.Choices[0].Delta.ToolCalls {
// Find or create the tool call in accumulated list
found := false
for i := range accumulatedToolCalls {
if accumulatedToolCalls[i].ID == tc.ID {
// Append arguments
accumulatedToolCalls[i].Function.Arguments += tc.Function.Arguments
found = true
break
}
}
if !found {
accumulatedToolCalls = append(accumulatedToolCalls, tc)
}
}
onChunk("", accumulatedToolCalls)
}
// Capture usage from final chunk
if chatResp.Usage.TotalTokens > 0 {
totalTokens = chatResp.Usage.TotalTokens
}
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("read stream: %w", err)
}
// Build final response
finalResp := &ChatResponse{
Usage: struct {
TotalTokens int `json:"total_tokens"`
}{TotalTokens: totalTokens},
Choices: []struct {
Message struct {
Content string `json:"content"`
ToolCalls []ToolCallMsg `json:"tool_calls"`
} `json:"message"`
Delta struct {
Content string `json:"content"`
ToolCalls []ToolCallMsg `json:"tool_calls"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
}{},
}
finalContent := cleanAIResponse(fullContent.String())
finalResp.Choices[0].Message.Content = finalContent
finalResp.Choices[0].Message.ToolCalls = accumulatedToolCalls
return finalResp, nil
}
func cleanAIResponse(content string) string {
content = thinkRegex.ReplaceAllString(content, "")
lines := strings.Split(content, "\n")
@@ -368,7 +510,9 @@ func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride str
}
var lastErr error
var triedProviders []string
for _, prov := range providerOrder {
triedProviders = append(triedProviders, prov.Name)
baseURL := baseURLOverride
if baseURL == "" {
baseURL = prov.BaseURL
@@ -392,7 +536,14 @@ func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride str
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+prov.APIKey)
// Provider-specific headers
if prov.Name == "anthropic" {
req.Header.Set("x-api-key", prov.APIKey)
req.Header.Set("anthropic-version", "2023-06-01")
} else {
req.Header.Set("Authorization", "Bearer "+prov.APIKey)
}
resp, err := o.client.Do(req)
if err != nil {
@@ -427,5 +578,6 @@ func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride str
return &chatResp, prov.Name, nil
}
log.Printf("[orchestrator] fallback from %v to next provider", triedProviders)
return nil, "", lastErr
}