refactor(chat): deduplicate streaming code, add multi-conv, and XSS protection
All checks were successful
Beta Release / beta (push) Successful in 2m23s
All checks were successful
Beta Release / beta (push) Successful in 2m23s
- Add ChatEngine for deduplicated chat logic (handlers_chat/shell_chat) - Add SendWithToolsStream for real-time streaming responses - Add /help, /plan, /export, /model commands in Studio - Fix XSS: sanitize HTML after markdown rendering - Add ConversationStoreMulti for multi-conversation support - Add Anthropic headers (x-api-key, anthropic-version) - Add fallback logging when provider switch occurs - Add API handler tests (handlers_test.go) - Polish Studio: max-height 200px, word-break on tool args - Update CLI version to show full info (version, go, platform) 🤖 Generated with Crush Assisted-by: MiniMax-M2.5 via Crush <crush@charm.land>
This commit is contained in:
249
internal/api/chat_engine.go
Normal file
249
internal/api/chat_engine.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/muyue/muyue/internal/agent"
|
||||
"github.com/muyue/muyue/internal/orchestrator"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxToolIterations = 15
|
||||
)
|
||||
|
||||
// ChatEngine handles chat interactions with tool execution.
|
||||
// This deduplicates chat logic previously repeated in handlers_chat.go and handlers_shell_chat.go.
|
||||
type ChatEngine struct {
|
||||
orchestrator *orchestrator.Orchestrator
|
||||
registry *agent.Registry
|
||||
tools json.RawMessage
|
||||
onChunk func(map[string]interface{})
|
||||
stream bool
|
||||
}
|
||||
|
||||
// NewChatEngine creates a new ChatEngine instance.
|
||||
func NewChatEngine(orb *orchestrator.Orchestrator, registry *agent.Registry, tools json.RawMessage) *ChatEngine {
|
||||
return &ChatEngine{
|
||||
orchestrator: orb,
|
||||
registry: registry,
|
||||
tools: tools,
|
||||
stream: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SetStream enables streaming mode for the chat engine.
|
||||
func (ce *ChatEngine) SetStream(enabled bool) {
|
||||
ce.stream = enabled
|
||||
}
|
||||
|
||||
// OnChunk sets the callback for SSE chunk writing.
|
||||
func (ce *ChatEngine) OnChunk(fn func(map[string]interface{})) {
|
||||
ce.onChunk = fn
|
||||
}
|
||||
|
||||
// RunWithTools executes the chat loop with tool calls.
|
||||
// Returns final content, tool calls, tool results, and error.
|
||||
func (ce *ChatEngine) RunWithTools(ctx context.Context, messages []orchestrator.Message) (string, []map[string]interface{}, []map[string]interface{}, error) {
|
||||
var finalContent string
|
||||
var allToolCalls []map[string]interface{}
|
||||
var allToolResults []map[string]interface{}
|
||||
|
||||
for i := 0; i < MaxToolIterations; i++ {
|
||||
var resp *orchestrator.ChatResponse
|
||||
var err error
|
||||
|
||||
if ce.stream {
|
||||
// Use streaming version
|
||||
resp, err = ce.orchestrator.SendWithToolsStream(messages, func(content string, toolCalls []orchestrator.ToolCallMsg) {
|
||||
if ce.onChunk != nil && content != "" {
|
||||
ce.onChunk(map[string]interface{}{"content": content})
|
||||
}
|
||||
})
|
||||
} else {
|
||||
resp, err = ce.orchestrator.SendWithTools(messages)
|
||||
}
|
||||
if err != nil {
|
||||
if ce.onChunk != nil {
|
||||
ce.onChunk(map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
return finalContent, allToolCalls, allToolResults, err
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
content := cleanThinkingTags(choice.Message.Content)
|
||||
|
||||
if content != "" {
|
||||
words := strings.Fields(content)
|
||||
for _, w := range words {
|
||||
chunk := w
|
||||
if ce.onChunk != nil {
|
||||
ce.onChunk(map[string]interface{}{"content": chunk})
|
||||
}
|
||||
}
|
||||
finalContent = content
|
||||
}
|
||||
|
||||
if len(choice.Message.ToolCalls) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
assistantMsg := orchestrator.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
ToolCalls: choice.Message.ToolCalls,
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
toolCallData := map[string]interface{}{
|
||||
"tool_call_id": tc.ID,
|
||||
"name": tc.Function.Name,
|
||||
"args": tc.Function.Arguments,
|
||||
}
|
||||
allToolCalls = append(allToolCalls, toolCallData)
|
||||
|
||||
if ce.onChunk != nil {
|
||||
ce.onChunk(map[string]interface{}{"tool_call": toolCallData})
|
||||
}
|
||||
|
||||
call := agent.ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: json.RawMessage(tc.Function.Arguments),
|
||||
}
|
||||
|
||||
result, execErr := ce.registry.Execute(ctx, call)
|
||||
if execErr != nil {
|
||||
result = agent.ToolResponse{
|
||||
Content: execErr.Error(),
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
|
||||
resultData := map[string]interface{}{
|
||||
"tool_call_id": tc.ID,
|
||||
"content": result.Content,
|
||||
"is_error": result.IsError,
|
||||
}
|
||||
allToolResults = append(allToolResults, map[string]interface{}{
|
||||
"tool_call_id": tc.ID,
|
||||
"name": tc.Function.Name,
|
||||
"args": tc.Function.Arguments,
|
||||
"result": result.Content,
|
||||
"is_error": result.IsError,
|
||||
})
|
||||
|
||||
if ce.onChunk != nil {
|
||||
ce.onChunk(map[string]interface{}{"tool_result": resultData})
|
||||
}
|
||||
|
||||
messages = append(messages, orchestrator.Message{
|
||||
Role: "tool",
|
||||
Content: result.Content,
|
||||
ToolCallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
})
|
||||
}
|
||||
|
||||
finalContent = ""
|
||||
}
|
||||
|
||||
return finalContent, allToolCalls, allToolResults, nil
|
||||
}
|
||||
|
||||
// RunNonStream executes chat without streaming content to client.
|
||||
func (ce *ChatEngine) RunNonStream(ctx context.Context, messages []orchestrator.Message) (string, error) {
|
||||
var finalContent string
|
||||
|
||||
for i := 0; i < MaxToolIterations; i++ {
|
||||
resp, err := ce.orchestrator.SendWithTools(messages)
|
||||
if err != nil {
|
||||
return finalContent, err
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
content := cleanThinkingTags(choice.Message.Content)
|
||||
|
||||
if content != "" {
|
||||
finalContent = content
|
||||
}
|
||||
|
||||
if len(choice.Message.ToolCalls) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
assistantMsg := orchestrator.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
ToolCalls: choice.Message.ToolCalls,
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
call := agent.ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: json.RawMessage(tc.Function.Arguments),
|
||||
}
|
||||
|
||||
result, execErr := ce.registry.Execute(ctx, call)
|
||||
if execErr != nil {
|
||||
result = agent.ToolResponse{
|
||||
Content: execErr.Error(),
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, orchestrator.Message{
|
||||
Role: "tool",
|
||||
Content: result.Content,
|
||||
ToolCallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
})
|
||||
}
|
||||
|
||||
finalContent = ""
|
||||
}
|
||||
|
||||
if finalContent == "" {
|
||||
finalContent = "(tool calls completed, no text response)"
|
||||
}
|
||||
|
||||
return finalContent, nil
|
||||
}
|
||||
|
||||
// SSEWriter handles Server-Sent Events writing to HTTP response.
|
||||
type SSEWriter struct {
|
||||
w http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
}
|
||||
|
||||
// NewSSEWriter creates a new SSEWriter.
|
||||
func NewSSEWriter(w http.ResponseWriter) *SSEWriter {
|
||||
sse := &SSEWriter{w: w}
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
sse.flusher = f
|
||||
}
|
||||
return sse
|
||||
}
|
||||
|
||||
// Write sends an SSE message.
|
||||
func (s *SSEWriter) Write(data map[string]interface{}) {
|
||||
b, _ := json.Marshal(data)
|
||||
s.w.Write([]byte("data: " + string(b) + "\n\n"))
|
||||
if s.flusher != nil {
|
||||
s.flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// SetupSSEHeaders sets up SSE response headers.
|
||||
func SetupSSEHeaders(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
Reference in New Issue
Block a user