refactor(chat): deduplicate streaming code, add multi-conv, and XSS protection
All checks were successful
Beta Release / beta (push) Successful in 2m23s
All checks were successful
Beta Release / beta (push) Successful in 2m23s
- Add ChatEngine for deduplicated chat logic (handlers_chat/shell_chat) - Add SendWithToolsStream for real-time streaming responses - Add /help, /plan, /export, /model commands in Studio - Fix XSS: sanitize HTML after markdown rendering - Add ConversationStoreMulti for multi-conversation support - Add Anthropic headers (x-api-key, anthropic-version) - Add fallback logging when provider switch occurs - Add API handler tests (handlers_test.go) - Polish Studio: max-height 200px, word-break on tool args - Update CLI version to show full info (version, go, platform) 🤖 Generated with Crush Assisted-by: MiniMax-M2.5 via Crush <crush@charm.land>
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -76,6 +77,11 @@ var sharedHTTPClient = &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
// requestClient creates an HTTP client with the specified timeout.
|
||||
func requestClient(timeout time.Duration) *http.Client {
|
||||
return &http.Client{Timeout: timeout}
|
||||
}
|
||||
|
||||
func New(cfg *config.MuyueConfig) (*Orchestrator, error) {
|
||||
var provider *config.AIProvider
|
||||
for i := range cfg.AI.Providers {
|
||||
@@ -300,6 +306,142 @@ func (o *Orchestrator) SendWithTools(messages []Message) (*ChatResponse, error)
|
||||
return chatResp, nil
|
||||
}
|
||||
|
||||
// ChunkCallback is called for each streaming chunk.
|
||||
type ChunkCallback func(content string, toolCalls []ToolCallMsg)
|
||||
|
||||
// SendWithToolsStream sends messages with streaming responses.
|
||||
// The callback receives chunks of content and tool_calls as they arrive.
|
||||
func (o *Orchestrator) SendWithToolsStream(messages []Message, onChunk ChunkCallback) (*ChatResponse, error) {
|
||||
fullMessages := make([]Message, 0, len(messages)+1)
|
||||
if o.systemPrompt != "" {
|
||||
fullMessages = append(fullMessages, Message{Role: "system", Content: o.systemPrompt})
|
||||
}
|
||||
fullMessages = append(fullMessages, messages...)
|
||||
|
||||
reqBody := ChatRequest{
|
||||
Model: o.provider.Model,
|
||||
Messages: fullMessages,
|
||||
Stream: true,
|
||||
Tools: o.tools,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
provider := o.provider
|
||||
baseURL := provider.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = getProviderBaseURL(provider.Name)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(baseURL, "/") + "/chat/completions"
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+provider.APIKey)
|
||||
|
||||
resp, err := o.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var fullContent strings.Builder
|
||||
var accumulatedToolCalls []ToolCallMsg
|
||||
var totalTokens int
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var chatResp ChatResponse
|
||||
if err := json.Unmarshal([]byte(data), &chatResp); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(chatResp.Choices) > 0 {
|
||||
chunk := chatResp.Choices[0].Delta.Content
|
||||
if chunk != "" {
|
||||
fullContent.WriteString(chunk)
|
||||
onChunk(chunk, nil)
|
||||
}
|
||||
|
||||
// Handle delta tool calls
|
||||
if len(chatResp.Choices[0].Delta.ToolCalls) > 0 {
|
||||
for _, tc := range chatResp.Choices[0].Delta.ToolCalls {
|
||||
// Find or create the tool call in accumulated list
|
||||
found := false
|
||||
for i := range accumulatedToolCalls {
|
||||
if accumulatedToolCalls[i].ID == tc.ID {
|
||||
// Append arguments
|
||||
accumulatedToolCalls[i].Function.Arguments += tc.Function.Arguments
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
accumulatedToolCalls = append(accumulatedToolCalls, tc)
|
||||
}
|
||||
}
|
||||
onChunk("", accumulatedToolCalls)
|
||||
}
|
||||
|
||||
// Capture usage from final chunk
|
||||
if chatResp.Usage.TotalTokens > 0 {
|
||||
totalTokens = chatResp.Usage.TotalTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("read stream: %w", err)
|
||||
}
|
||||
|
||||
// Build final response
|
||||
finalResp := &ChatResponse{
|
||||
Usage: struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}{TotalTokens: totalTokens},
|
||||
Choices: []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCallMsg `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCallMsg `json:"tool_calls"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
}{},
|
||||
}
|
||||
|
||||
finalContent := cleanAIResponse(fullContent.String())
|
||||
finalResp.Choices[0].Message.Content = finalContent
|
||||
finalResp.Choices[0].Message.ToolCalls = accumulatedToolCalls
|
||||
|
||||
return finalResp, nil
|
||||
}
|
||||
|
||||
func cleanAIResponse(content string) string {
|
||||
content = thinkRegex.ReplaceAllString(content, "")
|
||||
lines := strings.Split(content, "\n")
|
||||
@@ -368,7 +510,9 @@ func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride str
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
var triedProviders []string
|
||||
for _, prov := range providerOrder {
|
||||
triedProviders = append(triedProviders, prov.Name)
|
||||
baseURL := baseURLOverride
|
||||
if baseURL == "" {
|
||||
baseURL = prov.BaseURL
|
||||
@@ -392,7 +536,14 @@ func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride str
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+prov.APIKey)
|
||||
|
||||
// Provider-specific headers
|
||||
if prov.Name == "anthropic" {
|
||||
req.Header.Set("x-api-key", prov.APIKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Bearer "+prov.APIKey)
|
||||
}
|
||||
|
||||
resp, err := o.client.Do(req)
|
||||
if err != nil {
|
||||
@@ -427,5 +578,6 @@ func (o *Orchestrator) sendWithFallback(reqBody ChatRequest, baseURLOverride str
|
||||
return &chatResp, prov.Name, nil
|
||||
}
|
||||
|
||||
log.Printf("[orchestrator] fallback from %v to next provider", triedProviders)
|
||||
return nil, "", lastErr
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user