diff --git a/cmd/muyue/commands/version.go b/cmd/muyue/commands/version.go
index 696b612..4d3baf3 100644
--- a/cmd/muyue/commands/version.go
+++ b/cmd/muyue/commands/version.go
@@ -9,7 +9,7 @@ import (
var versionCmd = &cobra.Command{
Use: "version",
- Short: "Print version",
+ Short: "Print version info",
RunE: runVersion,
}
@@ -18,6 +18,6 @@ func init() {
}
func runVersion(cmd *cobra.Command, args []string) error {
- fmt.Printf("Muyue version %s\n", version.Version)
+ fmt.Print(version.FullInfo())
return nil
}
\ No newline at end of file
diff --git a/go.mod b/go.mod
index 7dc8717..b1ad938 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ toolchain go1.24.3
require (
github.com/charmbracelet/huh v1.0.0
github.com/creack/pty/v2 v2.0.1
+ github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/spf13/cobra v1.10.2
gopkg.in/yaml.v3 v3.0.1
diff --git a/go.sum b/go.sum
index 3ab332e..19f94e1 100644
--- a/go.sum
+++ b/go.sum
@@ -51,6 +51,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
diff --git a/internal/api/chat_engine.go b/internal/api/chat_engine.go
new file mode 100644
index 0000000..28feab8
--- /dev/null
+++ b/internal/api/chat_engine.go
@@ -0,0 +1,249 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "strings"
+
+ "github.com/muyue/muyue/internal/agent"
+ "github.com/muyue/muyue/internal/orchestrator"
+)
+
+const (
+ MaxToolIterations = 15
+)
+
+// ChatEngine handles chat interactions with tool execution.
+// This deduplicates chat logic previously repeated in handlers_chat.go and handlers_shell_chat.go.
+type ChatEngine struct {
+ orchestrator *orchestrator.Orchestrator
+ registry *agent.Registry
+ tools json.RawMessage
+ onChunk func(map[string]interface{})
+ stream bool
+}
+
+// NewChatEngine creates a new ChatEngine instance.
+func NewChatEngine(orb *orchestrator.Orchestrator, registry *agent.Registry, tools json.RawMessage) *ChatEngine {
+ return &ChatEngine{
+ orchestrator: orb,
+ registry: registry,
+ tools: tools,
+ stream: false,
+ }
+}
+
+// SetStream enables streaming mode for the chat engine.
+func (ce *ChatEngine) SetStream(enabled bool) {
+ ce.stream = enabled
+}
+
+// OnChunk sets the callback for SSE chunk writing.
+func (ce *ChatEngine) OnChunk(fn func(map[string]interface{})) {
+ ce.onChunk = fn
+}
+
+// RunWithTools executes the chat loop with tool calls.
+// Returns final content, tool calls, tool results, and error.
+func (ce *ChatEngine) RunWithTools(ctx context.Context, messages []orchestrator.Message) (string, []map[string]interface{}, []map[string]interface{}, error) {
+ var finalContent string
+ var allToolCalls []map[string]interface{}
+ var allToolResults []map[string]interface{}
+
+ for i := 0; i < MaxToolIterations; i++ {
+ var resp *orchestrator.ChatResponse
+ var err error
+
+ if ce.stream {
+ // Use streaming version
+ resp, err = ce.orchestrator.SendWithToolsStream(messages, func(content string, toolCalls []orchestrator.ToolCallMsg) {
+ if ce.onChunk != nil && content != "" {
+ ce.onChunk(map[string]interface{}{"content": content})
+ }
+ })
+ } else {
+ resp, err = ce.orchestrator.SendWithTools(messages)
+ }
+ if err != nil {
+ if ce.onChunk != nil {
+ ce.onChunk(map[string]interface{}{"error": err.Error()})
+ }
+ return finalContent, allToolCalls, allToolResults, err
+ }
+
+ choice := resp.Choices[0]
+ content := cleanThinkingTags(choice.Message.Content)
+
+ if content != "" {
+ words := strings.Fields(content)
+ for _, w := range words {
+ chunk := w
+ if ce.onChunk != nil {
+ ce.onChunk(map[string]interface{}{"content": chunk})
+ }
+ }
+ finalContent = content
+ }
+
+ if len(choice.Message.ToolCalls) == 0 {
+ break
+ }
+
+ assistantMsg := orchestrator.Message{
+ Role: "assistant",
+ Content: content,
+ ToolCalls: choice.Message.ToolCalls,
+ }
+ messages = append(messages, assistantMsg)
+
+ for _, tc := range choice.Message.ToolCalls {
+ toolCallData := map[string]interface{}{
+ "tool_call_id": tc.ID,
+ "name": tc.Function.Name,
+ "args": tc.Function.Arguments,
+ }
+ allToolCalls = append(allToolCalls, toolCallData)
+
+ if ce.onChunk != nil {
+ ce.onChunk(map[string]interface{}{"tool_call": toolCallData})
+ }
+
+ call := agent.ToolCall{
+ ID: tc.ID,
+ Name: tc.Function.Name,
+ Arguments: json.RawMessage(tc.Function.Arguments),
+ }
+
+ result, execErr := ce.registry.Execute(ctx, call)
+ if execErr != nil {
+ result = agent.ToolResponse{
+ Content: execErr.Error(),
+ IsError: true,
+ }
+ }
+
+ resultData := map[string]interface{}{
+ "tool_call_id": tc.ID,
+ "content": result.Content,
+ "is_error": result.IsError,
+ }
+ allToolResults = append(allToolResults, map[string]interface{}{
+ "tool_call_id": tc.ID,
+ "name": tc.Function.Name,
+ "args": tc.Function.Arguments,
+ "result": result.Content,
+ "is_error": result.IsError,
+ })
+
+ if ce.onChunk != nil {
+ ce.onChunk(map[string]interface{}{"tool_result": resultData})
+ }
+
+ messages = append(messages, orchestrator.Message{
+ Role: "tool",
+ Content: result.Content,
+ ToolCallID: tc.ID,
+ Name: tc.Function.Name,
+ })
+ }
+
+ finalContent = ""
+ }
+
+ return finalContent, allToolCalls, allToolResults, nil
+}
+
+// RunNonStream executes chat without streaming content to client.
+func (ce *ChatEngine) RunNonStream(ctx context.Context, messages []orchestrator.Message) (string, error) {
+ var finalContent string
+
+ for i := 0; i < MaxToolIterations; i++ {
+ resp, err := ce.orchestrator.SendWithTools(messages)
+ if err != nil {
+ return finalContent, err
+ }
+
+ choice := resp.Choices[0]
+ content := cleanThinkingTags(choice.Message.Content)
+
+ if content != "" {
+ finalContent = content
+ }
+
+ if len(choice.Message.ToolCalls) == 0 {
+ break
+ }
+
+ assistantMsg := orchestrator.Message{
+ Role: "assistant",
+ Content: content,
+ ToolCalls: choice.Message.ToolCalls,
+ }
+ messages = append(messages, assistantMsg)
+
+ for _, tc := range choice.Message.ToolCalls {
+ call := agent.ToolCall{
+ ID: tc.ID,
+ Name: tc.Function.Name,
+ Arguments: json.RawMessage(tc.Function.Arguments),
+ }
+
+ result, execErr := ce.registry.Execute(ctx, call)
+ if execErr != nil {
+ result = agent.ToolResponse{
+ Content: execErr.Error(),
+ IsError: true,
+ }
+ }
+
+ messages = append(messages, orchestrator.Message{
+ Role: "tool",
+ Content: result.Content,
+ ToolCallID: tc.ID,
+ Name: tc.Function.Name,
+ })
+ }
+
+ finalContent = ""
+ }
+
+ if finalContent == "" {
+ finalContent = "(tool calls completed, no text response)"
+ }
+
+ return finalContent, nil
+}
+
+// SSEWriter handles Server-Sent Events writing to HTTP response.
+type SSEWriter struct {
+ w http.ResponseWriter
+ flusher http.Flusher
+}
+
+// NewSSEWriter creates a new SSEWriter.
+func NewSSEWriter(w http.ResponseWriter) *SSEWriter {
+ sse := &SSEWriter{w: w}
+ if f, ok := w.(http.Flusher); ok {
+ sse.flusher = f
+ }
+ return sse
+}
+
+// Write sends an SSE message.
+func (s *SSEWriter) Write(data map[string]interface{}) {
+ b, _ := json.Marshal(data)
+ s.w.Write([]byte("data: " + string(b) + "\n\n"))
+ if s.flusher != nil {
+ s.flusher.Flush()
+ }
+}
+
+// SetupSSEHeaders sets up SSE response headers.
+func SetupSSEHeaders(w http.ResponseWriter) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.Header().Set("Connection", "keep-alive")
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.WriteHeader(http.StatusOK)
+}
\ No newline at end of file
diff --git a/internal/api/conversation_multi.go b/internal/api/conversation_multi.go
new file mode 100644
index 0000000..bb488e6
--- /dev/null
+++ b/internal/api/conversation_multi.go
@@ -0,0 +1,370 @@
+package api
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/muyue/muyue/internal/config"
+)
+
+// ConversationMeta represents metadata for a conversation (used for listing).
+type ConversationMeta struct {
+ ID string `json:"id"`
+ Title string `json:"title"`
+ CreatedAt string `json:"created_at"`
+ UpdatedAt string `json:"updated_at"`
+ MessageCount int `json:"message_count"`
+}
+
+// ConversationStoreMulti manages multiple conversations.
+type ConversationStoreMulti struct {
+ mu sync.RWMutex
+ dir string
+ currentID string
+ conversations map[string]*Conversation
+}
+
+func NewConversationStoreMulti() *ConversationStoreMulti {
+ dir, err := config.ConfigDir()
+ if err != nil {
+ dir = "/tmp/muyue"
+ }
+ dir = filepath.Join(dir, "conversations")
+
+ cs := &ConversationStoreMulti{
+ dir: dir,
+ conversations: make(map[string]*Conversation),
+ }
+ cs.loadIndex()
+ return cs
+}
+
+func (cs *ConversationStoreMulti) loadIndex() {
+ os.MkdirAll(cs.dir, 0755)
+
+ // Load index file if exists
+ indexPath := filepath.Join(cs.dir, "index.json")
+ data, err := os.ReadFile(indexPath)
+ if err != nil {
+ // Create default conversation
+ cs.createDefault()
+ return
+ }
+
+ var index struct {
+ CurrentID string `json:"current_id"`
+ Conversations []ConversationMeta `json:"conversations"`
+ }
+ if err := json.Unmarshal(data, &index); err != nil {
+ cs.createDefault()
+ return
+ }
+
+ cs.currentID = index.CurrentID
+ if cs.currentID == "" {
+ cs.createDefault()
+ return
+ }
+
+ // Load all conversations
+ for _, meta := range index.Conversations {
+ convPath := filepath.Join(cs.dir, fmt.Sprintf("conv_%s.json", meta.ID))
+ data, err := os.ReadFile(convPath)
+ if err != nil {
+ continue
+ }
+ var conv Conversation
+ if err := json.Unmarshal(data, &conv); err != nil {
+ continue
+ }
+ cs.conversations[meta.ID] = &conv
+ }
+
+ // Ensure current conversation exists
+ if _, ok := cs.conversations[cs.currentID]; !ok {
+ cs.createDefault()
+ }
+}
+
+func (cs *ConversationStoreMulti) createDefault() {
+ cs.currentID = uuid.New().String()
+ cs.conversations[cs.currentID] = &Conversation{
+ Messages: []FeedMessage{},
+ CreatedAt: time.Now().Format(time.RFC3339),
+ UpdatedAt: time.Now().Format(time.RFC3339),
+ }
+ cs.saveIndex()
+}
+
+func (cs *ConversationStoreMulti) saveIndex() error {
+ var metas []ConversationMeta
+ for id, conv := range cs.conversations {
+ title := "Nouvelle conversation"
+ if len(conv.Messages) > 0 {
+ // Use first user message as title
+ for _, m := range conv.Messages {
+ if m.Role == "user" {
+ if len(m.Content) > 50 {
+ title = m.Content[:50] + "..."
+ } else {
+ title = m.Content
+ }
+ break
+ }
+ }
+ }
+ metas = append(metas, ConversationMeta{
+ ID: id,
+ Title: title,
+ CreatedAt: conv.CreatedAt,
+ UpdatedAt: conv.UpdatedAt,
+ MessageCount: len(conv.Messages),
+ })
+ }
+
+ index := struct {
+ CurrentID string `json:"current_id"`
+ Conversations []ConversationMeta `json:"conversations"`
+ }{
+ CurrentID: cs.currentID,
+ Conversations: metas,
+ }
+
+ data, err := json.MarshalIndent(index, "", " ")
+ if err != nil {
+ return err
+ }
+
+ return os.WriteFile(filepath.Join(cs.dir, "index.json"), data, 0600)
+}
+
+func (cs *ConversationStoreMulti) saveCurrent() error {
+ conv, ok := cs.conversations[cs.currentID]
+ if !ok {
+ return fmt.Errorf("no current conversation")
+ }
+
+ conv.UpdatedAt = time.Now().Format(time.RFC3339)
+ data, err := json.MarshalIndent(conv, "", " ")
+ if err != nil {
+ return err
+ }
+
+ convPath := filepath.Join(cs.dir, fmt.Sprintf("conv_%s.json", cs.currentID))
+ if err := os.WriteFile(convPath, data, 0600); err != nil {
+ return err
+ }
+
+ return cs.saveIndex()
+}
+
+// Current returns the current conversation store.
+func (cs *ConversationStoreMulti) Current() *ConversationStore {
+ cs.mu.RLock()
+ defer cs.mu.RUnlock()
+
+ conv, ok := cs.conversations[cs.currentID]
+ if !ok {
+ return &ConversationStore{
+ conv: &Conversation{
+ Messages: []FeedMessage{},
+ CreatedAt: time.Now().Format(time.RFC3339),
+ UpdatedAt: time.Now().Format(time.RFC3339),
+ },
+ }
+ }
+
+ return &ConversationStore{
+ conv: conv,
+ }
+}
+
+// Get returns the current conversation messages.
+func (cs *ConversationStoreMulti) Get() []FeedMessage {
+ cs.mu.RLock()
+ defer cs.mu.RUnlock()
+
+ conv, ok := cs.conversations[cs.currentID]
+ if !ok {
+ return []FeedMessage{}
+ }
+
+ out := make([]FeedMessage, len(conv.Messages))
+ copy(out, conv.Messages)
+ return out
+}
+
+// Add adds a message to the current conversation.
+func (cs *ConversationStoreMulti) Add(role, content string) FeedMessage {
+ cs.mu.Lock()
+ defer cs.mu.Unlock()
+
+ conv, ok := cs.conversations[cs.currentID]
+ if !ok {
+ cs.currentID = uuid.New().String()
+ conv = &Conversation{
+ Messages: []FeedMessage{},
+ CreatedAt: time.Now().Format(time.RFC3339),
+ UpdatedAt: time.Now().Format(time.RFC3339),
+ }
+ cs.conversations[cs.currentID] = conv
+ }
+
+ msg := FeedMessage{
+ ID: generateMsgID(),
+ Role: role,
+ Content: content,
+ Time: time.Now().Format(time.RFC3339),
+ }
+ conv.Messages = append(conv.Messages, msg)
+
+ go cs.saveCurrent() // Fire and forget
+
+ return msg
+}
+
+// Clear clears the current conversation.
+func (cs *ConversationStoreMulti) Clear() {
+ cs.mu.Lock()
+ defer cs.mu.Unlock()
+
+ conv, ok := cs.conversations[cs.currentID]
+ if !ok {
+ return
+ }
+
+ conv.Messages = []FeedMessage{}
+ conv.Summary = ""
+ conv.CreatedAt = time.Now().Format(time.RFC3339)
+ conv.UpdatedAt = time.Now().Format(time.RFC3339)
+
+ cs.saveCurrent()
+}
+
+// List returns all conversations.
+func (cs *ConversationStoreMulti) List() []ConversationMeta {
+ cs.mu.RLock()
+ defer cs.mu.RUnlock()
+
+ var metas []ConversationMeta
+ for id, conv := range cs.conversations {
+ title := "Nouvelle conversation"
+ if len(conv.Messages) > 0 {
+ for _, m := range conv.Messages {
+ if m.Role == "user" {
+ if len(m.Content) > 50 {
+ title = m.Content[:50] + "..."
+ } else {
+ title = m.Content
+ }
+ break
+ }
+ }
+ }
+ metas = append(metas, ConversationMeta{
+ ID: id,
+ Title: title,
+ CreatedAt: conv.CreatedAt,
+ UpdatedAt: conv.UpdatedAt,
+ MessageCount: len(conv.Messages),
+ })
+ }
+
+ return metas
+}
+
+// Create creates a new conversation and switches to it.
+func (cs *ConversationStoreMulti) Create() string {
+ cs.mu.Lock()
+ defer cs.mu.Unlock()
+
+ id := uuid.New().String()
+ cs.conversations[id] = &Conversation{
+ Messages: []FeedMessage{},
+ CreatedAt: time.Now().Format(time.RFC3339),
+ UpdatedAt: time.Now().Format(time.RFC3339),
+ }
+ cs.currentID = id
+ cs.saveIndex()
+
+ return id
+}
+
+// Switch switches to a different conversation.
+func (cs *ConversationStoreMulti) Switch(id string) error {
+ cs.mu.Lock()
+ defer cs.mu.Unlock()
+
+ if _, ok := cs.conversations[id]; !ok {
+ return fmt.Errorf("conversation not found: %s", id)
+ }
+
+ cs.currentID = id
+ cs.saveIndex()
+
+ return nil
+}
+
+// GetByID returns a conversation by ID.
+func (cs *ConversationStoreMulti) GetByID(id string) (*Conversation, error) {
+ cs.mu.RLock()
+ defer cs.mu.RUnlock()
+
+ conv, ok := cs.conversations[id]
+ if !ok {
+ return nil, fmt.Errorf("conversation not found: %s", id)
+ }
+
+ return conv, nil
+}
+
+// Delete deletes a conversation.
+func (cs *ConversationStoreMulti) Delete(id string) error {
+ cs.mu.Lock()
+ defer cs.mu.Unlock()
+
+ if _, ok := cs.conversations[id]; !ok {
+ return fmt.Errorf("conversation not found: %s", id)
+ }
+
+ delete(cs.conversations, id)
+
+ // Delete file
+ convPath := filepath.Join(cs.dir, fmt.Sprintf("conv_%s.json", id))
+ os.Remove(convPath)
+
+ // If deleted current, switch to another
+ if cs.currentID == id {
+ if len(cs.conversations) > 0 {
+ for newID := range cs.conversations {
+ cs.currentID = newID
+ break
+ }
+ } else {
+ // Create new default
+ cs.currentID = uuid.New().String()
+ cs.conversations[cs.currentID] = &Conversation{
+ Messages: []FeedMessage{},
+ CreatedAt: time.Now().Format(time.RFC3339),
+ UpdatedAt: time.Now().Format(time.RFC3339),
+ }
+ }
+ }
+
+ cs.saveIndex()
+
+ return nil
+}
+
+// CurrentID returns the current conversation ID.
+func (cs *ConversationStoreMulti) CurrentID() string {
+ cs.mu.RLock()
+ defer cs.mu.RUnlock()
+
+ return cs.currentID
+}
\ No newline at end of file
diff --git a/internal/api/handlers_chat.go b/internal/api/handlers_chat.go
index 6378c6a..aef47ba 100644
--- a/internal/api/handlers_chat.go
+++ b/internal/api/handlers_chat.go
@@ -13,8 +13,6 @@ import (
var thinkingTagRegex = regexp.MustCompile(`(?s)<[Tt]hink[^>]*>.*?[Tt]hink>`)
-const maxToolIterations = 15
-
func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
writeError(w, "POST only", http.StatusMethodNotAllowed)
@@ -55,108 +53,31 @@ func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleStreamChat(w http.ResponseWriter, orb *orchestrator.Orchestrator, userMessage string) {
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("Access-Control-Allow-Origin", "*")
- w.WriteHeader(http.StatusOK)
+ SetupSSEHeaders(w)
flusher, canFlush := w.(http.Flusher)
- writeSSE := func(data map[string]interface{}) {
- b, _ := json.Marshal(data)
- w.Write([]byte("data: " + string(b) + "\n\n"))
- if canFlush {
- flusher.Flush()
- }
- }
+
+ sseWriter := NewSSEWriter(w)
+
ctx := context.Background()
messages := s.buildContextMessages(userMessage)
- var finalContent string
- var allToolCalls []map[string]interface{}
- var allToolResults []map[string]interface{}
-
- for i := 0; i < maxToolIterations; i++ {
- resp, err := orb.SendWithTools(messages)
- if err != nil {
- writeSSE(map[string]interface{}{"error": err.Error()})
+ engine := NewChatEngine(orb, s.agentRegistry, s.agentToolsJSON)
+ engine.OnChunk(func(data map[string]interface{}) {
+ if data == nil {
return
}
-
- choice := resp.Choices[0]
- content := cleanThinkingTags(choice.Message.Content)
-
- if content != "" {
- words := strings.Fields(content)
- for i, w := range words {
- chunk := w
- if i < len(words)-1 {
- chunk += " "
- }
- writeSSE(map[string]interface{}{"content": chunk})
- }
- finalContent = content
+ sseWriter.Write(data)
+ if canFlush {
+ flusher.Flush()
}
+ })
- if len(choice.Message.ToolCalls) == 0 {
- break
- }
-
- assistantMsg := orchestrator.Message{
- Role: "assistant",
- Content: content,
- ToolCalls: choice.Message.ToolCalls,
- }
- messages = append(messages, assistantMsg)
-
- for _, tc := range choice.Message.ToolCalls {
- toolCallData := map[string]interface{}{
- "tool_call_id": tc.ID,
- "name": tc.Function.Name,
- "args": tc.Function.Arguments,
- }
- allToolCalls = append(allToolCalls, toolCallData)
- writeSSE(map[string]interface{}{"tool_call": toolCallData})
-
- call := agent.ToolCall{
- ID: tc.ID,
- Name: tc.Function.Name,
- Arguments: json.RawMessage(tc.Function.Arguments),
- }
-
- result, execErr := s.agentRegistry.Execute(ctx, call)
- if execErr != nil {
- result = agent.ToolResponse{
- Content: execErr.Error(),
- IsError: true,
- }
- }
-
- resultData := map[string]interface{}{
- "tool_call_id": tc.ID,
- "content": result.Content,
- "is_error": result.IsError,
- }
- writeSSE(map[string]interface{}{"tool_result": resultData})
-
- allToolResults = append(allToolResults, map[string]interface{}{
- "tool_call_id": tc.ID,
- "name": tc.Function.Name,
- "args": tc.Function.Arguments,
- "result": result.Content,
- "is_error": result.IsError,
- })
-
- messages = append(messages, orchestrator.Message{
- Role: "tool",
- Content: result.Content,
- ToolCallID: tc.ID,
- Name: tc.Function.Name,
- })
- }
-
- finalContent = ""
+ finalContent, allToolCalls, allToolResults, err := engine.RunWithTools(ctx, messages)
+ if err != nil {
+ sseWriter.Write(map[string]interface{}{"error": err.Error()})
+ return
}
storeContent := finalContent
@@ -171,68 +92,18 @@ func (s *Server) handleStreamChat(w http.ResponseWriter, orb *orchestrator.Orche
}
s.convStore.Add("assistant", storeContent)
- writeSSE(map[string]interface{}{"done": "true"})
+ sseWriter.Write(map[string]interface{}{"done": "true"})
}
func (s *Server) handleNonStreamChat(w http.ResponseWriter, orb *orchestrator.Orchestrator, userMessage string) {
ctx := context.Background()
messages := s.buildContextMessages(userMessage)
- var finalContent string
-
- for i := 0; i < maxToolIterations; i++ {
- resp, err := orb.SendWithTools(messages)
- if err != nil {
- writeError(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- choice := resp.Choices[0]
- content := cleanThinkingTags(choice.Message.Content)
-
- if content != "" {
- finalContent = content
- }
-
- if len(choice.Message.ToolCalls) == 0 {
- break
- }
-
- assistantMsg := orchestrator.Message{
- Role: "assistant",
- Content: content,
- ToolCalls: choice.Message.ToolCalls,
- }
- messages = append(messages, assistantMsg)
-
- for _, tc := range choice.Message.ToolCalls {
- call := agent.ToolCall{
- ID: tc.ID,
- Name: tc.Function.Name,
- Arguments: json.RawMessage(tc.Function.Arguments),
- }
-
- result, execErr := s.agentRegistry.Execute(ctx, call)
- if execErr != nil {
- result = agent.ToolResponse{
- Content: execErr.Error(),
- IsError: true,
- }
- }
-
- messages = append(messages, orchestrator.Message{
- Role: "tool",
- Content: result.Content,
- ToolCallID: tc.ID,
- Name: tc.Function.Name,
- })
- }
-
- finalContent = ""
- }
-
- if finalContent == "" {
- finalContent = "(tool calls completed, no text response)"
+ engine := NewChatEngine(orb, s.agentRegistry, s.agentToolsJSON)
+ finalContent, err := engine.RunNonStream(ctx, messages)
+ if err != nil {
+ writeError(w, err.Error(), http.StatusInternalServerError)
+ return
}
s.convStore.Add("assistant", finalContent)
diff --git a/internal/api/handlers_shell_chat.go b/internal/api/handlers_shell_chat.go
index 60bb0d3..e3c633a 100644
--- a/internal/api/handlers_shell_chat.go
+++ b/internal/api/handlers_shell_chat.go
@@ -6,7 +6,6 @@ import (
"net/http"
"strings"
- "github.com/muyue/muyue/internal/agent"
"github.com/muyue/muyue/internal/orchestrator"
)
@@ -35,6 +34,22 @@ type ToolCallInfo struct {
Error string `json:"error,omitempty"`
}
+func toString(v interface{}) string {
+ if v == nil {
+ return ""
+ }
+ s, _ := v.(string)
+ return s
+}
+
+func toBool(v interface{}) bool {
+ if v == nil {
+ return false
+ }
+ b, _ := v.(bool)
+ return b
+}
+
func (s *Server) handleShellChat(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
writeError(w, "POST only", http.StatusMethodNotAllowed)
@@ -102,109 +117,59 @@ Tu peux appeler des outils pour exécuter des commandes, lire des fichiers, etc.
}
func (s *Server) handleShellChatStream(w http.ResponseWriter, orb *orchestrator.Orchestrator, req ShellChatRequest) {
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("Access-Control-Allow-Origin", "*")
- w.WriteHeader(http.StatusOK)
+ SetupSSEHeaders(w)
flusher, canFlush := w.(http.Flusher)
-
- writeSSE := func(data map[string]interface{}) {
- b, _ := json.Marshal(data)
- w.Write([]byte("data: " + string(b) + "\n\n"))
- if canFlush {
- flusher.Flush()
- }
- }
+ sseWriter := NewSSEWriter(w)
ctx := context.Background()
messages := []orchestrator.Message{
{Role: "user", Content: req.Message},
}
- var finalContent string
- var toolCalls []ToolCallInfo
+ engine := NewChatEngine(orb, s.agentRegistry, s.agentToolsJSON)
- for i := 0; i < maxShellToolIterations; i++ {
- resp, err := orb.SendWithTools(messages)
- if err != nil {
- writeSSE(map[string]interface{}{"error": err.Error()})
+ var toolCalls []ToolCallInfo
+ engine.OnChunk(func(data map[string]interface{}) {
+ if data == nil {
return
}
-
- choice := resp.Choices[0]
- content := cleanThinkingTags(choice.Message.Content)
-
- if content != "" {
- words := strings.Fields(content)
- for i, w := range words {
- chunk := w
- if i < len(words)-1 {
- chunk += " "
- }
- writeSSE(map[string]interface{}{"content": chunk})
- }
- finalContent = content
+ sseWriter.Write(data)
+ if canFlush {
+ flusher.Flush()
}
-
- if len(choice.Message.ToolCalls) == 0 {
- break
- }
-
- assistantMsg := orchestrator.Message{
- Role: "assistant",
- Content: content,
- ToolCalls: choice.Message.ToolCalls,
- }
- messages = append(messages, assistantMsg)
-
- for _, tc := range choice.Message.ToolCalls {
- toolCallData := map[string]interface{}{
- "tool_call_id": tc.ID,
- "name": tc.Function.Name,
- "args": tc.Function.Arguments,
- }
- writeSSE(map[string]interface{}{"tool_call": toolCallData})
-
+ if tc, ok := data["tool_call"].(map[string]interface{}); ok {
argsMap := make(map[string]interface{})
- json.Unmarshal([]byte(tc.Function.Arguments), &argsMap)
-
- tcInfo := ToolCallInfo{
- ID: tc.ID,
- Name: tc.Function.Name,
+ if args, ok := tc["args"].(string); ok {
+ json.Unmarshal([]byte(args), &argsMap)
+ }
+ toolCalls = append(toolCalls, ToolCallInfo{
+ ID: toString(tc["tool_call_id"]),
+ Name: toString(tc["name"]),
Args: argsMap,
- }
-
- call := agent.ToolCall{
- ID: tc.ID,
- Name: tc.Function.Name,
- Arguments: json.RawMessage(tc.Function.Arguments),
- }
-
- result, execErr := s.agentRegistry.Execute(ctx, call)
- if execErr != nil {
- tcInfo.Error = execErr.Error()
- writeSSE(map[string]interface{}{"tool_result": tcInfo})
- } else {
- tcInfo.Result = &toolResponseData{
- Content: result.Content,
- IsError: result.IsError,
- Meta: result.Meta,
- }
- writeSSE(map[string]interface{}{"tool_result": tcInfo})
- }
-
- toolCalls = append(toolCalls, tcInfo)
-
- messages = append(messages, orchestrator.Message{
- Role: "tool",
- Content: result.Content,
- ToolCallID: tc.ID,
- Name: tc.Function.Name,
})
}
+ if tr, ok := data["tool_result"].(map[string]interface{}); ok {
+ tcID := toString(tr["tool_call_id"])
+ for i := range toolCalls {
+ if toolCalls[i].ID == tcID {
+ if err, ok := tr["is_error"].(bool); ok && err {
+ toolCalls[i].Error = toString(tr["content"])
+ } else {
+ toolCalls[i].Result = &toolResponseData{
+ Content: toString(tr["content"]),
+ IsError: toBool(tr["is_error"]),
+ }
+ }
+ break
+ }
+ }
+ }
+ })
- finalContent = ""
+ finalContent, _, _, err := engine.RunWithTools(ctx, messages)
+ if err != nil {
+ sseWriter.Write(map[string]interface{}{"error": err.Error()})
+ return
}
if finalContent == "" && len(toolCalls) > 0 {
@@ -215,7 +180,7 @@ func (s *Server) handleShellChatStream(w http.ResponseWriter, orb *orchestrator.
Content: finalContent,
ToolCalls: toolCalls,
})
- writeSSE(map[string]interface{}{"done": true, "response": string(writeJSONResp)})
+ sseWriter.Write(map[string]interface{}{"done": true, "response": string(writeJSONResp)})
}
func (s *Server) handleShellChatNonStream(w http.ResponseWriter, orb *orchestrator.Orchestrator, req ShellChatRequest) {
@@ -224,80 +189,20 @@ func (s *Server) handleShellChatNonStream(w http.ResponseWriter, orb *orchestrat
{Role: "user", Content: req.Message},
}
- var finalContent string
- var toolCalls []ToolCallInfo
+ engine := NewChatEngine(orb, s.agentRegistry, s.agentToolsJSON)
- for i := 0; i < maxShellToolIterations; i++ {
- resp, err := orb.SendWithTools(messages)
- if err != nil {
- writeError(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- choice := resp.Choices[0]
- content := cleanThinkingTags(choice.Message.Content)
-
- if content != "" {
- finalContent = content
- }
-
- if len(choice.Message.ToolCalls) == 0 {
- break
- }
-
- assistantMsg := orchestrator.Message{
- Role: "assistant",
- Content: content,
- ToolCalls: choice.Message.ToolCalls,
- }
- messages = append(messages, assistantMsg)
-
- for _, tc := range choice.Message.ToolCalls {
- argsMap := make(map[string]interface{})
- json.Unmarshal([]byte(tc.Function.Arguments), &argsMap)
-
- tcInfo := ToolCallInfo{
- ID: tc.ID,
- Name: tc.Function.Name,
- Args: argsMap,
- }
-
- call := agent.ToolCall{
- ID: tc.ID,
- Name: tc.Function.Name,
- Arguments: json.RawMessage(tc.Function.Arguments),
- }
-
- result, execErr := s.agentRegistry.Execute(ctx, call)
- if execErr != nil {
- tcInfo.Error = execErr.Error()
- } else {
- tcInfo.Result = &toolResponseData{
- Content: result.Content,
- IsError: result.IsError,
- Meta: result.Meta,
- }
- }
-
- toolCalls = append(toolCalls, tcInfo)
-
- messages = append(messages, orchestrator.Message{
- Role: "tool",
- Content: result.Content,
- ToolCallID: tc.ID,
- Name: tc.Function.Name,
- })
- }
-
- finalContent = ""
+ finalContent, err := engine.RunNonStream(ctx, messages)
+ if err != nil {
+ writeError(w, err.Error(), http.StatusInternalServerError)
+ return
}
- if finalContent == "" && len(toolCalls) > 0 {
+ if finalContent == "" {
finalContent = "(tool calls completed, no text response)"
}
writeJSON(w, ShellChatResponse{
Content: finalContent,
- ToolCalls: toolCalls,
+ ToolCalls: nil,
})
}
\ No newline at end of file
diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go
new file mode 100644
index 0000000..6ee285c
--- /dev/null
+++ b/internal/api/handlers_test.go
@@ -0,0 +1,66 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/muyue/muyue/internal/agent"
+)
+
+func TestHandleToolCall(t *testing.T) {
+ // Test unknown tool returns error
+ registry := agent.NewRegistry()
+
+ // Register a test tool
+ testTool, _ := agent.NewTool[struct{ Command string }]("test_tool", "Test tool", func(ctx context.Context, params struct{ Command string }) (agent.ToolResponse, error) {
+ return agent.TextResponse("executed: " + params.Command), nil
+ })
+ registry.Register(testTool)
+
+ // Test executing known tool
+ resp, err := registry.Execute(context.Background(), agent.ToolCall{
+ ID: "test-id",
+ Name: "test_tool",
+ Arguments: json.RawMessage(`{"Command": "hello"}`),
+ })
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if resp.IsError {
+ t.Errorf("expected no error, got error response")
+ }
+
+ // Test executing unknown tool
+ resp, err = registry.Execute(context.Background(), agent.ToolCall{
+ ID: "test-id",
+ Name: "unknown_tool",
+ Arguments: json.RawMessage(`{}`),
+ })
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if !resp.IsError {
+ t.Errorf("expected error for unknown tool")
+ }
+}
+
+func TestCleanThinkingTags(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {"hello world", "hello world"},
+ {"$1')
.replace(/^### (.+)$/gm, '