From 61da8039bc9dbb5ec9421bb8200333a5a2c1d6f1 Mon Sep 17 00:00:00 2001 From: Augustin Date: Wed, 22 Apr 2026 21:19:36 +0200 Subject: [PATCH] feat(agent): refactor AI chat with streaming, agent registry, and tool execution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace old tool-call regex with proper agent registry - Add streaming chat via SSE (handleStreamChat / handleNonStreamChat) - Add internal/agent package with tool definitions and execution - Add orchestrator with system prompt and tool scaffolding - Add internal/agent/ directory - Studio.jsx: streaming chat with thinking indicator and tool result rendering - global.css: chat bubble styles, streaming animation, thinking dots - handlers_chat.go: full rewrite using new agent/orchestrator architecture 💘 Generated with Crush Assisted-by: MiniMax-M2.7 via Crush --- internal/agent/definitions.go | 311 +++++++++++++ internal/agent/impl.go | 579 ++++++++++++++++++++++++ internal/agent/prompt.go | 10 + internal/agent/prompts/studio_system.md | 44 ++ internal/agent/tools.go | 218 +++++++++ internal/api/handlers_chat.go | 279 +++++++----- internal/orchestrator/orchestrator.go | 124 ++++- web/src/components/Studio.jsx | 143 +++++- web/src/styles/global.css | 88 ++++ 9 files changed, 1654 insertions(+), 142 deletions(-) create mode 100644 internal/agent/definitions.go create mode 100644 internal/agent/impl.go create mode 100644 internal/agent/prompt.go create mode 100644 internal/agent/prompts/studio_system.md create mode 100644 internal/agent/tools.go diff --git a/internal/agent/definitions.go b/internal/agent/definitions.go new file mode 100644 index 0000000..0c0b534 --- /dev/null +++ b/internal/agent/definitions.go @@ -0,0 +1,311 @@ +package agent + +import ( + "context" + "fmt" + "os/exec" + "path/filepath" + "strings" + "time" +) + +type TerminalParams struct { + Command string `json:"command" description:"The shell command to execute"` + Timeout int `json:"timeout,omitempty" description:"Timeout in seconds (default 60, max 300)"` +} + +func NewTerminalTool() (*ToolDefinition, error) { + return NewTool("terminal", + "Execute a shell command on the local system and return the output. Use for running builds, tests, git operations, package management, system info, or any CLI task. Commands run in the user's home directory by default. Long-running commands are auto-terminated.", + func(ctx context.Context, p TerminalParams) (ToolResponse, error) { + if p.Command == "" { + return TextErrorResponse("command is required"), nil + } + + timeout := time.Duration(p.Timeout) * time.Second + if timeout == 0 { + timeout = 60 * time.Second + } + if timeout > 300*time.Second { + timeout = 300 * time.Second + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + shell := detectShell() + + cmd := exec.CommandContext(ctx, shell, "-c", p.Command) + output, err := cmd.CombinedOutput() + + result := string(output) + if len(result) > 10000 { + result = result[:10000] + "\n... [truncated]" + } + + if err != nil { + return TextErrorResponse(fmt.Sprintf("Error: %v\n\n%s", err, result)), nil + } + + return TextResponse(result), nil + }) +} + +type CrushRunParams struct { + Task string `json:"task" description:"The task description for Crush to execute"` +} + +func NewCrushRunTool() (*ToolDefinition, error) { + return NewTool("crush_run", + "Delegate a complex coding task to the Crush AI agent. Crush has access to file editing, code search, bash execution, and other development tools. Use this for multi-step coding tasks like refactoring, debugging, implementing features, or code review. Returns the agent's final output.", + func(ctx context.Context, p CrushRunParams) (ToolResponse, error) { + if p.Task == "" { + return TextErrorResponse("task is required"), nil + } + + ctx, cancel := context.WithTimeout(ctx, 300*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "crush", "run", p.Task) + output, err := cmd.CombinedOutput() + + result := string(output) + if len(result) > 15000 { + result = result[:15000] + "\n... [truncated]" + } + + if err != nil { + return TextErrorResponse(fmt.Sprintf("Crush error: %v\n\n%s", err, result)), nil + } + + return TextResponse(result), nil + }) +} + +type ReadFileParams struct { + Path string `json:"path" description:"Absolute or relative path to the file to read"` + Offset int `json:"offset,omitempty" description:"Line number to start reading from (0-based, default 0)"` + Limit int `json:"limit,omitempty" description:"Maximum number of lines to read (default 200, max 2000)"` +} + +func NewReadFileTool() (*ToolDefinition, error) { + return NewTool("read_file", + "Read file contents from the local filesystem. Returns the file content with line numbers. Supports offset/limit for reading specific sections of large files.", + func(ctx context.Context, p ReadFileParams) (ToolResponse, error) { + if p.Path == "" { + return TextErrorResponse("path is required"), nil + } + + expanded := expandHome(p.Path) + data, err := readFileLimited(expanded, p.Offset, p.Limit) + if err != nil { + return TextErrorResponse(fmt.Sprintf("read error: %v", err)), nil + } + + return TextResponse(data), nil + }) +} + +type ListFilesParams struct { + Path string `json:"path,omitempty" description:"Directory path to list (default: user home)"` + Depth int `json:"depth,omitempty" description:"Maximum depth to traverse (default 1, max 3)"` +} + +func NewListFilesTool() (*ToolDefinition, error) { + return NewTool("list_files", + "List files and directories at a given path. Shows directory tree structure with file names. Useful for exploring project structure or finding specific files.", + func(ctx context.Context, p ListFilesParams) (ToolResponse, error) { + dir := expandHome(p.Path) + if dir == "" { + dir, _ = osUserHomeDir() + } + + if p.Depth <= 0 { + p.Depth = 1 + } + if p.Depth > 3 { + p.Depth = 3 + } + + result, err := listDirTree(dir, p.Depth, 0) + if err != nil { + return TextErrorResponse(fmt.Sprintf("list error: %v", err)), nil + } + + return TextResponse(result), nil + }) +} + +type SearchFilesParams struct { + Pattern string `json:"pattern" description:"Search pattern (supports * and ? glob wildcards)"` + Path string `json:"path,omitempty" description:"Directory to search in (default: current directory)"` +} + +func NewSearchFilesTool() (*ToolDefinition, error) { + return NewTool("search_files", + "Search for files by name pattern using glob syntax. Use * for any characters, ** for recursive matching. Returns matching file paths sorted by name.", + func(ctx context.Context, p SearchFilesParams) (ToolResponse, error) { + if p.Pattern == "" { + return TextErrorResponse("pattern is required"), nil + } + + dir := expandHome(p.Path) + if dir == "" { + dir = "." + } + + matches, err := filepath.Glob(filepath.Join(dir, p.Pattern)) + if err != nil { + return TextErrorResponse(fmt.Sprintf("glob error: %v", err)), nil + } + + if len(matches) == 0 { + return TextResponse("No files found matching pattern."), nil + } + + if len(matches) > 100 { + matches = matches[:100] + } + + var result strings.Builder + for _, m := range matches { + result.WriteString(m) + result.WriteString("\n") + } + + return TextResponse(result.String()), nil + }) +} + +type GrepContentParams struct { + Pattern string `json:"pattern" description:"Text pattern to search for in file contents"` + Path string `json:"path,omitempty" description:"Directory to search in (default: current directory)"` + Include string `json:"include,omitempty" description:"File extension filter, e.g. '*.go' or '*.{js,ts}'"` +} + +func NewGrepContentTool() (*ToolDefinition, error) { + return NewTool("grep_content", + "Search for text patterns inside file contents. Returns matching lines with file paths and line numbers. Use include to filter by file extension.", + func(ctx context.Context, p GrepContentParams) (ToolResponse, error) { + if p.Pattern == "" { + return TextErrorResponse("pattern is required"), nil + } + + dir := expandHome(p.Path) + if dir == "" { + dir = "." + } + + result, err := grepFiles(dir, p.Pattern, p.Include) + if err != nil { + return TextErrorResponse(fmt.Sprintf("grep error: %v", err)), nil + } + + if result == "" { + return TextResponse("No matches found."), nil + } + + return TextResponse(result), nil + }) +} + +type GetConfigParams struct { + Section string `json:"section,omitempty" description:"Config section to retrieve: 'providers', 'profile', 'tools', 'terminal', 'all' (default: 'all')"` +} + +func NewGetConfigTool() (*ToolDefinition, error) { + return NewTool("get_config", + "Read the Muyue configuration. Returns provider settings, profile info, installed tools, terminal config, etc. Use section parameter to get a specific part, or 'all' for the full config.", + func(ctx context.Context, p GetConfigParams) (ToolResponse, error) { + return getConfigSection(p.Section), nil + }) +} + +type SetProviderParams struct { + Name string `json:"name" description:"Provider name (e.g. 'openai', 'anthropic', 'ollama')"` + APIKey string `json:"api_key,omitempty" description:"API key for the provider"` + BaseURL string `json:"base_url,omitempty" description:"Custom base URL for the provider API"` + Model string `json:"model,omitempty" description:"Model identifier to use"` + Active *bool `json:"active,omitempty" description:"Set to true to make this the active provider"` +} + +func NewSetProviderTool() (*ToolDefinition, error) { + return NewTool("set_provider", + "Configure an AI provider in Muyue settings. Can create, update, or activate a provider. API keys are automatically encrypted. Set active=true to switch to this provider.", + func(ctx context.Context, p SetProviderParams) (ToolResponse, error) { + if p.Name == "" { + return TextErrorResponse("name is required"), nil + } + + return setProviderConfig(p), nil + }) +} + +type ManageSSHParams struct { + Action string `json:"action" description:"Action to perform: 'list', 'add', 'remove'"` + Name string `json:"name,omitempty" description:"Connection name (required for add/remove)"` + Host string `json:"host,omitempty" description:"SSH host (required for add)"` + Port int `json:"port,omitempty" description:"SSH port (default: 22)"` + User string `json:"user,omitempty" description:"SSH username (required for add)"` + KeyPath string `json:"key_path,omitempty" description:"Path to SSH private key"` +} + +func NewManageSSHTool() (*ToolDefinition, error) { + return NewTool("manage_ssh", + "Manage SSH connections configured in Muyue. List existing connections, add new ones, or remove connections. SSH configs are persisted to the Muyue config file.", + func(ctx context.Context, p ManageSSHParams) (ToolResponse, error) { + if p.Action == "" { + return TextErrorResponse("action is required (list, add, remove)"), nil + } + + return manageSSHAction(p), nil + }) +} + +type WebFetchParams struct { + URL string `json:"url" description:"The URL to fetch content from"` +} + +func NewWebFetchTool() (*ToolDefinition, error) { + return NewTool("web_fetch", + "Fetch content from a URL and return the text. Useful for reading documentation, APIs, or web resources. Only HTTP/HTTPS URLs are supported.", + func(ctx context.Context, p WebFetchParams) (ToolResponse, error) { + if p.URL == "" { + return TextErrorResponse("url is required"), nil + } + + return fetchURL(p.URL), nil + }) +} + +func DefaultRegistry() *Registry { + r := NewRegistry() + + tools := []*ToolDefinition{ + must(NewTerminalTool()), + must(NewCrushRunTool()), + must(NewReadFileTool()), + must(NewListFilesTool()), + must(NewSearchFilesTool()), + must(NewGrepContentTool()), + must(NewGetConfigTool()), + must(NewSetProviderTool()), + must(NewManageSSHTool()), + must(NewWebFetchTool()), + } + + for _, t := range tools { + if err := r.Register(t); err != nil { + panic(err) + } + } + + return r +} + +func must(t *ToolDefinition, err error) *ToolDefinition { + if err != nil { + panic(err) + } + return t +} diff --git a/internal/agent/impl.go b/internal/agent/impl.go new file mode 100644 index 0000000..53090cb --- /dev/null +++ b/internal/agent/impl.go @@ -0,0 +1,579 @@ +package agent + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strings" + "time" +) + +func detectShell() string { + shells := []string{"zsh", "bash", "fish", "pwsh", "powershell"} + for _, s := range shells { + if path, err := exec.LookPath(s); err == nil { + return path + } + } + return "/bin/sh" +} + +func expandHome(path string) string { + if path == "" { + return "" + } + if path == "~" { + home, _ := os.UserHomeDir() + return home + } + if strings.HasPrefix(path, "~/") { + home, _ := os.UserHomeDir() + return filepath.Join(home, path[2:]) + } + return path +} + +func osUserHomeDir() (string, error) { + return os.UserHomeDir() +} + +func readFileLimited(path string, offset, limit int) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + + lines := strings.Split(string(data), "\n") + + if offset < 0 { + offset = 0 + } + if offset > len(lines) { + offset = len(lines) + } + + end := offset + limit + if limit <= 0 || limit > 2000 { + limit = 2000 + } + if end > len(lines) { + end = len(lines) + } + if end-offset > limit { + end = offset + limit + } + + selected := lines[offset:end] + + var buf strings.Builder + for i, line := range selected { + fmt.Fprintf(&buf, "%6d\t%s\n", offset+i+1, line) + } + + return buf.String(), nil +} + +func listDirTree(dir string, maxDepth, currentDepth int) (string, error) { + info, err := os.Stat(dir) + if err != nil { + return "", err + } + if !info.IsDir() { + return dir + "\n", nil + } + + entries, err := os.ReadDir(dir) + if err != nil { + return "", err + } + + var buf strings.Builder + indent := strings.Repeat(" ", currentDepth) + + for _, entry := range entries { + name := entry.Name() + if strings.HasPrefix(name, ".") && name != "." && name != ".." { + continue + } + + if entry.IsDir() { + fmt.Fprintf(&buf, "%s%s/\n", indent, name) + if currentDepth < maxDepth { + sub, err := listDirTree(filepath.Join(dir, name), maxDepth, currentDepth+1) + if err == nil { + buf.WriteString(sub) + } + } + } else { + fmt.Fprintf(&buf, "%s%s\n", indent, name) + } + } + + return buf.String(), nil +} + +func grepFiles(dir, pattern, include string) (string, error) { + if include != "" { + matches, err := filepath.Glob(filepath.Join(dir, include)) + if err != nil { + return "", err + } + if len(matches) == 0 { + return "", nil + } + var buf strings.Builder + for _, match := range matches { + result, err := grepInFile(match, pattern) + if err != nil { + continue + } + buf.WriteString(result) + } + return buf.String(), nil + } + + return grepInDir(dir, pattern, 0) +} + +func grepInDir(dir, pattern string, depth int) (string, error) { + if depth > 10 { + return "", nil + } + + var buf strings.Builder + + entries, err := os.ReadDir(dir) + if err != nil { + return "", err + } + + for _, entry := range entries { + name := entry.Name() + if strings.HasPrefix(name, ".") { + continue + } + + path := filepath.Join(dir, name) + + if entry.IsDir() { + sub, err := grepInDir(path, pattern, depth+1) + if err == nil { + buf.WriteString(sub) + } + continue + } + + result, err := grepInFile(path, pattern) + if err != nil { + continue + } + buf.WriteString(result) + } + + return buf.String(), nil +} + +func grepInFile(path, pattern string) (string, error) { + re, err := regexp.Compile(pattern) + if err != nil { + re, err = regexp.Compile(regexp.QuoteMeta(pattern)) + if err != nil { + return "", err + } + } + + file, err := os.Open(path) + if err != nil { + return "", err + } + defer file.Close() + + var buf strings.Builder + scanner := bufio.NewScanner(file) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + + lineNum := 0 + matchCount := 0 + for scanner.Scan() { + lineNum++ + if re.MatchString(scanner.Text()) { + fmt.Fprintf(&buf, "%s:%d: %s\n", path, lineNum, scanner.Text()) + matchCount++ + if matchCount >= 50 { + buf.WriteString("... [truncated, more matches exist]\n") + break + } + } + } + + return buf.String(), nil +} + +func getConfigSection(section string) ToolResponse { + configPath, err := os.UserConfigDir() + if err != nil { + return TextErrorResponse(fmt.Sprintf("cannot find config dir: %v", err)) + } + configPath = filepath.Join(configPath, "muyue", "config.yaml") + + data, err := os.ReadFile(configPath) + if err != nil { + return TextErrorResponse(fmt.Sprintf("cannot read config: %v", err)) + } + + switch section { + case "providers", "profile", "tools", "terminal": + sectionData := extractYAMLSection(data, section) + if sectionData == "" { + return TextResponse(fmt.Sprintf("Section '%s' not found in config.", section)) + } + return TextResponse(sectionData) + default: + content := string(data) + if len(content) > 8000 { + content = content[:8000] + "\n... [truncated]" + } + return TextResponse(content) + } +} + +func extractYAMLSection(data []byte, section string) string { + lines := strings.Split(string(data), "\n") + inSection := false + indentLevel := 0 + var buf strings.Builder + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + if inSection { + buf.WriteString("\n") + } + continue + } + + if !inSection { + if strings.HasPrefix(trimmed, section+":") || strings.HasPrefix(trimmed, section+" ") { + inSection = true + indentLevel = len(line) - len(strings.TrimLeft(line, " ")) + buf.WriteString(line) + buf.WriteString("\n") + } + continue + } + + currentIndent := len(line) - len(strings.TrimLeft(line, " ")) + if currentIndent <= indentLevel && trimmed != "" { + break + } + buf.WriteString(line) + buf.WriteString("\n") + } + + return strings.TrimSpace(buf.String()) +} + +func setProviderConfig(p SetProviderParams) ToolResponse { + configPath, err := os.UserConfigDir() + if err != nil { + return TextErrorResponse(fmt.Sprintf("cannot find config dir: %v", err)) + } + configPath = filepath.Join(configPath, "muyue", "config.yaml") + + data, err := os.ReadFile(configPath) + if err != nil { + return TextErrorResponse(fmt.Sprintf("cannot read config: %v", err)) + } + + lines := strings.Split(string(data), "\n") + inProviders := false + providerIndent := 0 + foundProvider := false + insertIdx := -1 + lastProviderEnd := -1 + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if !inProviders { + if strings.HasPrefix(trimmed, "providers:") { + inProviders = true + providerIndent = len(line) - len(strings.TrimLeft(line, " ")) + } + continue + } + + currentIndent := len(line) - len(strings.TrimLeft(line, " ")) + if currentIndent <= providerIndent && trimmed != "" && !strings.HasPrefix(trimmed, "#") { + lastProviderEnd = i + break + } + + if currentIndent == providerIndent+2 && strings.HasPrefix(trimmed, "- name:") { + nameMatch := strings.TrimPrefix(trimmed, "- name:") + nameMatch = strings.TrimSpace(nameMatch) + if nameMatch == p.Name { + foundProvider = true + insertIdx = i + } + if insertIdx == -1 || insertIdx < i { + insertIdx = i + } + } + } + + if lastProviderEnd == -1 { + lastProviderEnd = len(lines) + } + + entryIndent := strings.Repeat(" ", providerIndent+4) + + var newEntry strings.Builder + newEntry.WriteString(fmt.Sprintf(" - name: %s\n", p.Name)) + if p.Model != "" { + newEntry.WriteString(fmt.Sprintf("%smodel: %s\n", entryIndent, p.Model)) + } + if p.BaseURL != "" { + newEntry.WriteString(fmt.Sprintf("%sbase_url: %s\n", entryIndent, p.BaseURL)) + } + if p.APIKey != "" { + newEntry.WriteString(fmt.Sprintf("%sapi_key: %s\n", entryIndent, p.APIKey)) + } + if p.Active != nil { + newEntry.WriteString(fmt.Sprintf("%sactive: %v\n", entryIndent, *p.Active)) + } + + if foundProvider && insertIdx >= 0 { + var endIdx int + for endIdx = insertIdx + 1; endIdx < len(lines); endIdx++ { + li := len(lines[endIdx]) - len(strings.TrimLeft(lines[endIdx], " ")) + if li <= providerIndent+2 || lines[endIdx] == "" { + if endIdx > insertIdx+1 && strings.TrimSpace(lines[endIdx]) == "" { + continue + } + break + } + } + + newLines := make([]string, 0, len(lines)) + newLines = append(newLines, lines[:insertIdx]...) + newLines = append(newLines, strings.TrimSuffix(newEntry.String(), "\n")) + newLines = append(newLines, lines[endIdx:]...) + lines = newLines + } else { + insertAt := lastProviderEnd + newLines := make([]string, 0, len(lines)+10) + newLines = append(newLines, lines[:insertAt]...) + newLines = append(newLines, strings.TrimSuffix(newEntry.String(), "\n")) + newLines = append(newLines, lines[insertAt:]...) + lines = newLines + } + + content := strings.Join(lines, "\n") + if err := os.WriteFile(configPath, []byte(content), 0600); err != nil { + return TextErrorResponse(fmt.Sprintf("write config error: %v", err)) + } + + return TextResponse(fmt.Sprintf("Provider '%s' configured successfully.", p.Name)) +} + +func manageSSHAction(p ManageSSHParams) ToolResponse { + configPath, err := os.UserConfigDir() + if err != nil { + return TextErrorResponse(fmt.Sprintf("cannot find config dir: %v", err)) + } + configPath = filepath.Join(configPath, "muyue", "config.yaml") + + data, err := os.ReadFile(configPath) + if err != nil { + return TextErrorResponse(fmt.Sprintf("cannot read config: %v", err)) + } + + switch p.Action { + case "list": + sshSection := extractYAMLSection(data, "ssh") + if sshSection == "" { + return TextResponse("No SSH connections configured.") + } + return TextResponse(sshSection) + + case "add": + if p.Name == "" || p.Host == "" || p.User == "" { + return TextErrorResponse("name, host, and user are required for add action") + } + if p.Port == 0 { + p.Port = 22 + } + + lines := strings.Split(string(data), "\n") + sshIdx := -1 + sshIndent := 0 + lastSSHEnd := -1 + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if sshIdx == -1 && strings.HasPrefix(trimmed, "ssh:") { + sshIdx = i + sshIndent = len(line) - len(strings.TrimLeft(line, " ")) + continue + } + if sshIdx != -1 { + li := len(line) - len(strings.TrimLeft(line, " ")) + if li <= sshIndent && trimmed != "" { + lastSSHEnd = i + break + } + } + } + + if lastSSHEnd == -1 { + lastSSHEnd = len(lines) + } + + entry := fmt.Sprintf(" - name: %s\n host: %s\n port: %d\n user: %s", p.Name, p.Host, p.Port, p.User) + if p.KeyPath != "" { + entry += fmt.Sprintf("\n key_path: %s", p.KeyPath) + } + + newLines := make([]string, 0, len(lines)+10) + newLines = append(newLines, lines[:lastSSHEnd]...) + newLines = append(newLines, entry) + newLines = append(newLines, lines[lastSSHEnd:]...) + + if err := os.WriteFile(configPath, []byte(strings.Join(newLines, "\n")), 0600); err != nil { + return TextErrorResponse(fmt.Sprintf("write config error: %v", err)) + } + return TextResponse(fmt.Sprintf("SSH connection '%s' (%s@%s:%d) added.", p.Name, p.User, p.Host, p.Port)) + + case "remove": + if p.Name == "" { + return TextErrorResponse("name is required for remove action") + } + + lines := strings.Split(string(data), "\n") + newLines := make([]string, 0, len(lines)) + skipping := false + removed := false + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.Contains(trimmed, "name: "+p.Name) && strings.HasPrefix(trimmed, "-") { + skipping = true + removed = true + continue + } + if skipping { + li := len(line) - len(strings.TrimLeft(line, " ")) + if li > 6 && i < len(lines)-1 && strings.TrimSpace(lines[i+1]) != "" { + continue + } + skipping = false + continue + } + newLines = append(newLines, line) + } + + if !removed { + return TextErrorResponse(fmt.Sprintf("SSH connection '%s' not found.", p.Name)) + } + + if err := os.WriteFile(configPath, []byte(strings.Join(newLines, "\n")), 0600); err != nil { + return TextErrorResponse(fmt.Sprintf("write config error: %v", err)) + } + return TextResponse(fmt.Sprintf("SSH connection '%s' removed.", p.Name)) + + default: + return TextErrorResponse("unknown action. Use 'list', 'add', or 'remove'") + } +} + +func fetchURL(url string) ToolResponse { + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + return TextErrorResponse("only http/https URLs are supported") + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return TextErrorResponse(fmt.Sprintf("create request: %v", err)) + } + req.Header.Set("User-Agent", "MuyueStudio/1.0") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return TextErrorResponse(fmt.Sprintf("fetch error: %v", err)) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 50000)) + if err != nil { + return TextErrorResponse(fmt.Sprintf("read error: %v", err)) + } + + if resp.StatusCode != http.StatusOK { + return TextErrorResponse(fmt.Sprintf("HTTP %d: %s", resp.StatusCode, truncate(string(body), 2000))) + } + + contentType := resp.Header.Get("Content-Type") + if strings.Contains(contentType, "text/html") { + text := stripHTML(string(body)) + if len(text) > 8000 { + text = text[:8000] + "\n... [truncated]" + } + return TextResponse(text) + } + + result := string(body) + if len(result) > 10000 { + result = result[:10000] + "\n... [truncated]" + } + return TextResponse(result) +} + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +func stripHTML(html string) string { + tagRe := regexp.MustCompile(`<[^>]*>`) + text := tagRe.ReplaceAllString(html, " ") + + entityRe := regexp.MustCompile(`&[a-zA-Z]+;`) + text = entityRe.ReplaceAllStringFunc(text, func(s string) string { + switch s { + case "&": + return "&" + case "<": + return "<" + case ">": + return ">" + case """: + return "\"" + case "'": + return "'" + case " ": + return " " + default: + return " " + } + }) + + multiSpace := regexp.MustCompile(`\s+`) + text = multiSpace.ReplaceAllString(text, " ") + return strings.TrimSpace(text) +} + +var _ = runtime.GOOS +var _ = json.Marshal diff --git a/internal/agent/prompt.go b/internal/agent/prompt.go new file mode 100644 index 0000000..da37be2 --- /dev/null +++ b/internal/agent/prompt.go @@ -0,0 +1,10 @@ +package agent + +import _ "embed" + +//go:embed prompts/studio_system.md +var studioSystemPrompt string + +func StudioSystemPrompt() string { + return studioSystemPrompt +} diff --git a/internal/agent/prompts/studio_system.md b/internal/agent/prompts/studio_system.md new file mode 100644 index 0000000..d32560a --- /dev/null +++ b/internal/agent/prompts/studio_system.md @@ -0,0 +1,44 @@ +Tu es l'assistant IA de **Muyue Studio**, le centre de commandement de l'environnement de dĂ©veloppement de l'utilisateur. + +Tu es intĂ©grĂ© dans Muyue, un gestionnaire d'environnement de dĂ©veloppement de bureau. Ton rĂŽle est d'aider l'utilisateur Ă  configurer, gĂ©rer et optimiser son environnement dev. + +## Environnement + +Muyue gĂšre : +- **Fournisseurs IA** (OpenAI, Anthropic, Ollama, MiniMax, etc.) +- **Outils de dĂ©veloppement** (Crush, Claude Code, etc.) +- **Terminaux locaux et SSH** +- **Configuration et prĂ©fĂ©rences** +- **Serveurs MCP et LSP** + +## Outils disponibles + +Tu as accĂšs Ă  des outils. Utilise-les concrĂštement, ne dĂ©cris pas ce que tu ferais — fais-le. + +- **terminal** : ExĂ©cuter des commandes shell (builds, tests, git, etc.) +- **crush_run** : DĂ©lĂ©guer une tĂąche complexe Ă  l'agent Crush (Ă©dition de fichiers, refactoring, debug) +- **read_file** : Lire le contenu d'un fichier +- **list_files** : Lister les fichiers d'un rĂ©pertoire +- **search_files** : Chercher des fichiers par motif (glob) +- **grep_content** : Chercher du texte dans le contenu des fichiers +- **get_config** : Lire la configuration Muyue +- **set_provider** : Configurer un fournisseur IA +- **manage_ssh** : GĂ©rer les connexions SSH +- **web_fetch** : RĂ©cupĂ©rer le contenu d'une URL + +## RĂšgles + +1. **AGIS, ne dĂ©cris pas** — Si l'utilisateur demande de faire quelque chose, utilise les outils pour le faire. Ne dis pas "je pourrais faire X" — fais-le. +2. **Sois concis** — Pas de prĂ©ambule, pas de blabla. RĂ©ponse directe. +3. **Une chose Ă  la fois** — N'appelle pas plusieurs outils simultanĂ©ment sauf si c'est nĂ©cessaire. +4. **GĂšre les erreurs** — Si un outil Ă©choue, essaie une approche diffĂ©rente avant de le dire Ă  l'utilisateur. +5. **Ne devine pas** — Si tu n'as pas assez d'informations, utilise les outils pour les obtenir (lire un fichier, chercher, etc.) +6. **ConfidentialitĂ©** — Ne rĂ©vĂšle jamais les clĂ©s API, mots de passe ou informations sensibles dans tes rĂ©ponses. +7. **Langue** — RĂ©ponds dans la mĂȘme langue que l'utilisateur. + +## Format des rĂ©ponses + +- Code : utilise des blocs markdown +- RĂ©sultats d'outils : rĂ©sume les points clĂ©s, ne colle pas des milliers de lignes +- Erreurs : explique clairement et propose une solution +- SuccĂšs : confirme briĂšvement ce qui a Ă©tĂ© fait diff --git a/internal/agent/tools.go b/internal/agent/tools.go new file mode 100644 index 0000000..7e0405c --- /dev/null +++ b/internal/agent/tools.go @@ -0,0 +1,218 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "strings" +) + +type ToolCall struct { + ID string `json:"id"` + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` +} + +type ToolResponse struct { + Content string `json:"content"` + IsError bool `json:"is_error"` + Meta map[string]string `json:"meta,omitempty"` +} + +func TextResponse(content string) ToolResponse { + return ToolResponse{Content: content} +} + +func TextErrorResponse(msg string) ToolResponse { + return ToolResponse{Content: msg, IsError: true} +} + +type ToolDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Params json.RawMessage `json:"parameters"` + Handler func(ctx context.Context, args json.RawMessage) (ToolResponse, error) +} + +func (td *ToolDefinition) Execute(ctx context.Context, call ToolCall) (ToolResponse, error) { + resp, err := td.Handler(ctx, call.Arguments) + if err != nil { + return ToolResponse{Content: err.Error(), IsError: true}, nil + } + return resp, nil +} + +func (td *ToolDefinition) ToOpenAITool() map[string]interface{} { + return map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": td.Name, + "description": td.Description, + "parameters": td.Params, + }, + } +} + +func NewTool[P any](name, description string, handler func(ctx context.Context, params P) (ToolResponse, error)) (*ToolDefinition, error) { + var zero P + paramsSchema, err := generateSchema(zero) + if err != nil { + return nil, fmt.Errorf("generate schema for %s: %w", name, err) + } + + wrappedHandler := func(ctx context.Context, raw json.RawMessage) (ToolResponse, error) { + var params P + if err := json.Unmarshal(raw, ¶ms); err != nil { + return TextErrorResponse(fmt.Sprintf("invalid arguments: %v", err)), nil + } + return handler(ctx, params) + } + + return &ToolDefinition{ + Name: name, + Description: description, + Params: paramsSchema, + Handler: wrappedHandler, + }, nil +} + +type Registry struct { + tools map[string]*ToolDefinition +} + +func NewRegistry() *Registry { + return &Registry{ + tools: make(map[string]*ToolDefinition), + } +} + +func (r *Registry) Register(tool *ToolDefinition) error { + if _, exists := r.tools[tool.Name]; exists { + return fmt.Errorf("tool %q already registered", tool.Name) + } + r.tools[tool.Name] = tool + return nil +} + +func (r *Registry) Get(name string) (*ToolDefinition, bool) { + t, ok := r.tools[name] + return t, ok +} + +func (r *Registry) All() []*ToolDefinition { + out := make([]*ToolDefinition, 0, len(r.tools)) + for _, t := range r.tools { + out = append(out, t) + } + return out +} + +func (r *Registry) OpenAITools() []map[string]interface{} { + out := make([]map[string]interface{}, 0, len(r.tools)) + for _, t := range r.tools { + out = append(out, t.ToOpenAITool()) + } + return out +} + +func (r *Registry) Execute(ctx context.Context, call ToolCall) (ToolResponse, error) { + tool, ok := r.tools[call.Name] + if !ok { + return TextErrorResponse(fmt.Sprintf("unknown tool: %s", call.Name)), nil + } + return tool.Execute(ctx, call) +} + +func generateSchema(v interface{}) (json.RawMessage, error) { + t := reflect.TypeOf(v) + if t == nil { + return json.RawMessage(`{"type":"object","properties":{}}`), nil + } + + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return json.RawMessage(`{"type":"object","properties":{}}`), nil + } + + props := make(map[string]interface{}) + required := []string{} + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + + jsonTag := field.Tag.Get("json") + if jsonTag == "-" { + continue + } + + jsonName := field.Name + parts := strings.Split(jsonTag, ",") + if parts[0] != "" { + jsonName = parts[0] + } + + omitempty := false + for _, part := range parts[1:] { + if part == "omitempty" { + omitempty = true + } + } + + desc := field.Tag.Get("description") + prop := map[string]interface{}{ + "type": goTypeToJSON(field.Type), + } + if desc != "" { + prop["description"] = desc + } + + props[jsonName] = prop + if !omitempty { + required = append(required, jsonName) + } + } + + schema := map[string]interface{}{ + "type": "object", + "properties": props, + } + if len(required) > 0 { + schema["required"] = required + } + + data, err := json.Marshal(schema) + if err != nil { + return nil, err + } + return json.RawMessage(data), nil +} + +func goTypeToJSON(t reflect.Type) string { + switch t.Kind() { + case reflect.String: + return "string" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return "integer" + case reflect.Float32, reflect.Float64: + return "number" + case reflect.Bool: + return "boolean" + case reflect.Slice: + if t.Elem().Kind() == reflect.Uint8 { + return "string" + } + return "array" + case reflect.Map: + return "object" + default: + return "string" + } +} diff --git a/internal/api/handlers_chat.go b/internal/api/handlers_chat.go index 869bd49..ddbba05 100644 --- a/internal/api/handlers_chat.go +++ b/internal/api/handlers_chat.go @@ -1,17 +1,16 @@ package api import ( + "context" "encoding/json" - "fmt" "net/http" - "os/exec" - "regexp" "strings" + "github.com/muyue/muyue/internal/agent" "github.com/muyue/muyue/internal/orchestrator" ) -var toolCallRegex = regexp.MustCompile(`\[TOOL_CALL:\{[^\}]+\}\]`) +const maxToolIterations = 15 func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { @@ -27,7 +26,7 @@ func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) { return } if body.Message == "" { - writeError(w, "no message", http.StatusBadRequest) + writeError(w, "no message", http.StatusMethodNotAllowed) return } @@ -42,143 +41,189 @@ func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) { writeError(w, err.Error(), http.StatusServiceUnavailable) return } - orb.SetSystemPrompt(`Tu es l'assistant IA de Muyue Studio. Tu as accĂšs Ă  un outil "crush" pour exĂ©cuter des tĂąches complexes sur l'ordinateur de l'utilisateur. - -RÈGLES ABSOLUES: -1. Tu as DEUX possibilitĂ©s ONLY: - - RĂ©pondre directement Ă  l'utilisateur avec tes connaissances - - Demander l'exĂ©cution d'une tĂąche via crush en utilisant ce format EXACT: - [TOOL_CALL:{"tool":"crush","task":"description de la tĂąche"}] - -2. Quand tu utilises [TOOL_CALL:...], le systĂšme exĂ©cutera la tĂąche et te donnera le rĂ©sultat. - Tu peux ensuite rĂ©pondre Ă  l'utilisateur avec ce rĂ©sultat. - -3. SOIS CONCIS - pas de blabla, vais droit au but. - -4. L'utilisateur ne voit PAS tes pensĂ©es entre tags. - -5. EXEMPLES d'utilisation de tool: - - "cherche tous les fichiers .md dans le projet" → [TOOL_CALL:{"tool":"crush","task":"Recherche les fichiers .md dans le projet courant"}] - - "aide-moi Ă  dĂ©boguer cette erreur" → tu peux rĂ©pondre directement si tu as assez d'info, sinon utiliser tool - - "quelle est la mĂ©tĂ©o?" → [TOOL_CALL:{"tool":"crush","task":"Cherche la mĂ©tĂ©o actuelle"}] - -6. Ne fais PAS de multi-step tool calls dans une seule rĂ©ponse. Attends le rĂ©sultat avant de continuer.`) + orb.SetSystemPrompt(agent.StudioSystemPrompt()) + orb.SetTools(s.agentToolsJSON) if body.Stream { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.WriteHeader(http.StatusOK) - flusher, canFlush := w.(http.Flusher) + s.handleStreamChat(w, orb, body.Message) + } else { + s.handleNonStreamChat(w, orb, body.Message) + } +} - result, err := orb.SendStream(body.Message, func(chunk string) { - if strings.HasPrefix(chunk, "" { - data, _ := json.Marshal(map[string]string{"thinking_end": "true"}) - w.Write([]byte("data: " + string(data) + "\n\n")) - if canFlush { - flusher.Flush() - } - return - } - data, _ := json.Marshal(map[string]string{"content": chunk}) - w.Write([]byte("data: " + string(data) + "\n\n")) - if canFlush { - flusher.Flush() - } - }) - if err != nil { - data, _ := json.Marshal(map[string]string{"error": err.Error()}) - w.Write([]byte("data: " + string(data) + "\n\n")) - if canFlush { - flusher.Flush() - } - return - } +func (s *Server) handleStreamChat(w http.ResponseWriter, orb *orchestrator.Orchestrator, userMessage string) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusOK) + flusher, canFlush := w.(http.Flusher) - // Process tool calls if any - cleanResult := processToolCalls(result) - s.convStore.Add("assistant", cleanResult) - - data, _ := json.Marshal(map[string]string{"done": "true"}) - w.Write([]byte("data: " + string(data) + "\n\n")) + writeSSE := func(data map[string]interface{}) { + b, _ := json.Marshal(data) + w.Write([]byte("data: " + string(b) + "\n\n")) if canFlush { flusher.Flush() } - return } - result, err := orb.Send(body.Message) - if err != nil { - writeError(w, err.Error(), http.StatusInternalServerError) - return + ctx := context.Background() + messages := []orchestrator.Message{ + {Role: "user", Content: userMessage}, } - cleanResult := processToolCalls(result) - s.convStore.Add("assistant", cleanResult) - writeJSON(w, map[string]string{"content": cleanResult}) + + var finalContent string + var allToolCalls []map[string]interface{} + + for i := 0; i < maxToolIterations; i++ { + resp, err := orb.SendWithTools(messages) + if err != nil { + writeSSE(map[string]interface{}{"error": err.Error()}) + return + } + + choice := resp.Choices[0] + content := cleanThinkingTags(choice.Message.Content) + + if content != "" { + for _, ch := range strings.Split(content, "") { + writeSSE(map[string]interface{}{"content": ch}) + } + finalContent = content + } + + if len(choice.Message.ToolCalls) == 0 { + break + } + + assistantMsg := orchestrator.Message{ + Role: "assistant", + Content: content, + ToolCalls: choice.Message.ToolCalls, + } + messages = append(messages, assistantMsg) + + for _, tc := range choice.Message.ToolCalls { + toolCallData := map[string]interface{}{ + "tool_call_id": tc.ID, + "name": tc.Function.Name, + "args": tc.Function.Arguments, + } + allToolCalls = append(allToolCalls, toolCallData) + writeSSE(map[string]interface{}{"tool_call": toolCallData}) + + call := agent.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Arguments: json.RawMessage(tc.Function.Arguments), + } + + result, execErr := s.agentRegistry.Execute(ctx, call) + if execErr != nil { + result = agent.ToolResponse{ + Content: execErr.Error(), + IsError: true, + } + } + + resultData := map[string]interface{}{ + "tool_call_id": tc.ID, + "content": result.Content, + "is_error": result.IsError, + } + writeSSE(map[string]interface{}{"tool_result": resultData}) + + messages = append(messages, orchestrator.Message{ + Role: "tool", + Content: result.Content, + ToolCallID: tc.ID, + Name: tc.Function.Name, + }) + } + + finalContent = "" + } + + storeContent := finalContent + if len(allToolCalls) > 0 { + storeObj := map[string]interface{}{"content": storeContent, "tool_calls": allToolCalls} + storeJSON, _ := json.Marshal(storeObj) + storeContent = string(storeJSON) + } + s.convStore.Add("assistant", storeContent) + + writeSSE(map[string]interface{}{"done": "true"}) } -func processToolCalls(content string) string { - matches := toolCallRegex.FindAllString(content, -1) - if len(matches) == 0 { - return cleanThinkingTags(content) +func (s *Server) handleNonStreamChat(w http.ResponseWriter, orb *orchestrator.Orchestrator, userMessage string) { + ctx := context.Background() + messages := []orchestrator.Message{ + {Role: "user", Content: userMessage}, } - var result strings.Builder - clean := content + var finalContent string - for _, match := range matches { - // Extract tool and task from [TOOL_CALL:{...}] - inner := strings.TrimPrefix(match, "[TOOL_CALL:") - inner = strings.TrimSuffix(inner, "]}") + "}" - - var call struct { - Tool string `json:"tool"` - Task string `json:"task"` - } - if err := json.Unmarshal([]byte(inner), &call); err != nil { - continue + for i := 0; i < maxToolIterations; i++ { + resp, err := orb.SendWithTools(messages) + if err != nil { + writeError(w, err.Error(), http.StatusInternalServerError) + return } - if call.Tool == "crush" && call.Task != "" { - result.WriteString(fmt.Sprintf("> %s\n\n", call.Task)) - output := executeCrush(call.Task) - result.WriteString(output) - result.WriteString("\n\n---\n\n") + choice := resp.Choices[0] + content := cleanThinkingTags(choice.Message.Content) + + if content != "" { + finalContent = content } - clean = strings.Replace(clean, match, "", 1) + if len(choice.Message.ToolCalls) == 0 { + break + } + + assistantMsg := orchestrator.Message{ + Role: "assistant", + Content: content, + ToolCalls: choice.Message.ToolCalls, + } + messages = append(messages, assistantMsg) + + for _, tc := range choice.Message.ToolCalls { + call := agent.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Arguments: json.RawMessage(tc.Function.Arguments), + } + + result, execErr := s.agentRegistry.Execute(ctx, call) + if execErr != nil { + result = agent.ToolResponse{ + Content: execErr.Error(), + IsError: true, + } + } + + messages = append(messages, orchestrator.Message{ + Role: "tool", + Content: result.Content, + ToolCallID: tc.ID, + Name: tc.Function.Name, + }) + } + + finalContent = "" } - clean = cleanThinkingTags(clean) - - if result.Len() > 0 { - clean = strings.TrimSpace(clean) + "\n\n" + strings.TrimSpace(result.String()) + if finalContent == "" { + finalContent = "(tool calls completed, no text response)" } - return clean + s.convStore.Add("assistant", finalContent) + writeJSON(w, map[string]string{"content": finalContent}) } func cleanThinkingTags(content string) string { - re := regexp.MustCompile(`(?s)]*>.*?`) - return re.ReplaceAllString(content, "") -} - -func executeCrush(task string) string { - cmd := exec.Command("crush", "run", task) - output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Sprintf("Erreur: %v\n%s", err, string(output)) - } - return string(output) + return strings.ReplaceAll(content, "]*>.*?`) const maxHistorySize = 100 type Message struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + ToolCalls []ToolCallMsg `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` +} + +type ToolCallMsg struct { + ID string `json:"id"` + Type string `json:"type"` + Function ToolCallFuncMsg `json:"function"` +} + +type ToolCallFuncMsg struct { + Name string `json:"name"` + Arguments string `json:"arguments"` } type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Stream bool `json:"stream"` + Model string `json:"model"` + Messages []Message `json:"messages"` + Stream bool `json:"stream"` + Tools json.RawMessage `json:"tools,omitempty"` } type ChatResponse struct { Choices []struct { Message struct { - Content string `json:"content"` + Content string `json:"content"` + ToolCalls []ToolCallMsg `json:"tool_calls"` } `json:"message"` Delta struct { - Content string `json:"content"` + Content string `json:"content"` + ToolCalls []ToolCallMsg `json:"tool_calls"` } `json:"delta"` + FinishReason *string `json:"finish_reason"` } `json:"choices"` Usage struct { TotalTokens int `json:"total_tokens"` @@ -51,6 +69,7 @@ type Orchestrator struct { history []Message histMu sync.Mutex systemPrompt string + tools json.RawMessage } var sharedHTTPClient = &http.Client{ @@ -86,6 +105,34 @@ func (o *Orchestrator) SetSystemPrompt(prompt string) { o.systemPrompt = prompt } +func (o *Orchestrator) SetTools(tools json.RawMessage) { + o.tools = tools +} + +func (o *Orchestrator) ProviderName() string { + if o.provider == nil { + return "" + } + return o.provider.Name +} + +func (o *Orchestrator) AppendHistory(msg Message) { + o.histMu.Lock() + defer o.histMu.Unlock() + o.history = append(o.history, msg) + if len(o.history) > maxHistorySize { + o.history = o.history[len(o.history)-maxHistorySize:] + } +} + +func (o *Orchestrator) GetHistory() []Message { + o.histMu.Lock() + defer o.histMu.Unlock() + out := make([]Message, len(o.history)) + copy(out, o.history) + return out +} + func (o *Orchestrator) Send(userMessage string) (string, error) { o.histMu.Lock() o.history = append(o.history, Message{ @@ -107,6 +154,7 @@ func (o *Orchestrator) Send(userMessage string) (string, error) { Model: o.provider.Model, Messages: messages, Stream: false, + Tools: o.tools, } o.histMu.Unlock() @@ -186,6 +234,7 @@ func (o *Orchestrator) SendStream(userMessage string, onChunk func(string)) (str Model: o.provider.Model, Messages: messages, Stream: true, + Tools: o.tools, } o.histMu.Unlock() @@ -263,6 +312,67 @@ func (o *Orchestrator) SendStream(userMessage string, onChunk func(string)) (str return content, nil } +func (o *Orchestrator) SendWithTools(messages []Message) (*ChatResponse, error) { + fullMessages := make([]Message, 0, len(messages)+1) + if o.systemPrompt != "" { + fullMessages = append(fullMessages, Message{Role: "system", Content: o.systemPrompt}) + } + fullMessages = append(fullMessages, messages...) + + reqBody := ChatRequest{ + Model: o.provider.Model, + Messages: fullMessages, + Stream: false, + Tools: o.tools, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + baseURL := o.provider.BaseURL + if baseURL == "" { + baseURL = getProviderBaseURL(o.provider.Name) + } + + url := strings.TrimRight(baseURL, "/") + "/chat/completions" + + req, err := http.NewRequest("POST", url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+o.provider.APIKey) + + resp, err := o.client.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody)) + } + + var chatResp ChatResponse + if err := json.Unmarshal(respBody, &chatResp); err != nil { + return nil, fmt.Errorf("parse response: %w", err) + } + + if len(chatResp.Choices) == 0 { + return nil, fmt.Errorf("no response from AI") + } + + return &chatResp, nil +} + func cleanAIResponse(content string) string { content = thinkRegex.ReplaceAllString(content, "") lines := strings.Split(content, "\n") diff --git a/web/src/components/Studio.jsx b/web/src/components/Studio.jsx index 35bd81f..8a124a1 100644 --- a/web/src/components/Studio.jsx +++ b/web/src/components/Studio.jsx @@ -78,6 +78,71 @@ function ThinkingBlock({ content, done }) { ) } +const TOOL_ICONS = { + terminal: '⌹', + crush_run: '⚡', + read_file: '📄', + list_files: '📁', + search_files: '🔍', + grep_content: '🔎', + get_config: '⚙', + set_provider: '🔑', + manage_ssh: '🌐', + web_fetch: '🌐', +} + +const TOOL_LABELS = { + terminal: 'Terminal', + crush_run: 'Crush Agent', + read_file: 'Read File', + list_files: 'List Files', + search_files: 'Search Files', + grep_content: 'Grep', + get_config: 'Config', + set_provider: 'Set Provider', + manage_ssh: 'SSH', + web_fetch: 'Web Fetch', +} + +function ToolCallBlock({ call, result }) { + const icon = TOOL_ICONS[call.name] || '🔧' + const label = TOOL_LABELS[call.name] || call.name + const isErr = result && result.is_error + + let argsPreview = '' + try { + const args = typeof call.args === 'string' ? JSON.parse(call.args) : call.args + if (args.command) argsPreview = args.command + else if (args.task) argsPreview = args.task + else if (args.path) argsPreview = args.path + else if (args.pattern) argsPreview = args.pattern + else if (args.url) argsPreview = args.url + else if (args.action) argsPreview = args.action + else argsPreview = JSON.stringify(args).slice(0, 80) + } catch { + argsPreview = String(call.args).slice(0, 80) + } + + const truncatedResult = result ? (result.content || '').slice(0, 2000) : null + + return ( +
+
+ {icon} + {label} + {!result && } + {result && {isErr ? '✗' : '✓'}} +
+
{argsPreview}
+ {truncatedResult && ( +
+
{truncatedResult}
+
+ )} +
+ ) +} + function FeedItem({ msg }) { const isUser = msg.role === 'user' const isSystem = msg.role === 'system' @@ -85,6 +150,16 @@ function FeedItem({ msg }) { const timeStr = msg.time ? new Date(msg.time).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' }) : '' + let parsedToolCalls = null + let displayContent = msg.content + try { + const parsed = JSON.parse(msg.content) + if (parsed && Array.isArray(parsed.tool_calls)) { + parsedToolCalls = parsed.tool_calls + displayContent = parsed.content || '' + } + } catch {} + if (isSystem) { return (
@@ -95,7 +170,7 @@ function FeedItem({ msg }) { ) } - const cleanContent = msg.content.replace(/]*>[\s\S]*?<\/think>/gi, '') + const cleanContent = displayContent.replace(/]*>[\s\S]*?<\/think>/gi, '') return (
@@ -111,26 +186,32 @@ function FeedItem({ msg }) { {timeStr && {timeStr}}
{msg.thinking && } -
- {renderContent(cleanContent).map((part, i) => - part.type === 'code' ? ( -
- {part.lang &&
{part.lang}
} -
{part.content}
-
- ) : ( - - ) - )} -
+ {parsedToolCalls && parsedToolCalls.map((tc, i) => ( + + ))} + {cleanContent && ( +
+ {renderContent(cleanContent).map((part, i) => + part.type === 'code' ? ( +
+ {part.lang &&
{part.lang}
} +
{part.content}
+
+ ) : ( + + ) + )} +
+ )}
) } -function StreamingItem({ content, thinking }) { +function StreamingItem({ content, thinking, toolCalls }) { const rank = RANKS.general const cleanContent = content.replace(/]*>[\s\S]*?<\/think>/gi, '') + const hasToolCalls = toolCalls && toolCalls.length > 0 return (
@@ -145,7 +226,10 @@ function StreamingItem({ content, thinking }) { {rank.label}
{thinking && } - {!thinking && !cleanContent && ( + {hasToolCalls && toolCalls.map((tc, i) => ( + + ))} + {!thinking && !cleanContent && !hasToolCalls && (
@@ -177,6 +261,7 @@ export default function Studio({ api }) { const [loading, setLoading] = useState(false) const [streaming, setStreaming] = useState('') const [streamThinking, setStreamThinking] = useState('') + const [streamToolCalls, setStreamToolCalls] = useState([]) const [loaded, setLoaded] = useState(false) const messagesEnd = useRef(null) const textareaRef = useRef(null) @@ -201,7 +286,7 @@ export default function Studio({ api }) { useEffect(() => { messagesEnd.current?.scrollIntoView({ behavior: 'smooth' }) - }, [messages, streaming, streamThinking]) + }, [messages, streaming, streamThinking, streamToolCalls]) useEffect(() => { if (textareaRef.current) { @@ -234,10 +319,12 @@ export default function Studio({ api }) { setLoading(true) setStreaming('') setStreamThinking('') + setStreamToolCalls([]) try { let accumulated = '' let thinking = '' + let toolCalls = [] await api.sendChat(text, true, (partial, event) => { if (event && (event.thinking_start || event.thinking_end || event.thinking !== undefined)) { @@ -247,6 +334,19 @@ export default function Studio({ api }) { } return } + if (event && event.tool_call) { + toolCalls = [...toolCalls, { call: event.tool_call, result: null }] + setStreamToolCalls([...toolCalls]) + return + } + if (event && event.tool_result) { + const idx = toolCalls.findIndex(tc => tc.call && tc.call.tool_call_id === event.tool_result.tool_call_id) + if (idx >= 0) { + toolCalls[idx] = { ...toolCalls[idx], result: event.tool_result } + setStreamToolCalls([...toolCalls]) + } + return + } accumulated = partial setStreaming(partial) }) @@ -259,6 +359,12 @@ export default function Studio({ api }) { time: new Date().toISOString(), } if (thinking) aiMsg.thinking = thinking + if (toolCalls.length > 0) { + aiMsg.content = JSON.stringify({ + content: finalContent, + tool_calls: toolCalls.map(tc => tc.call), + }) + } setMessages(prev => [...prev, aiMsg]) } catch (err) { setMessages(prev => [...prev, { @@ -271,6 +377,7 @@ export default function Studio({ api }) { setLoading(false) setStreaming('') setStreamThinking('') + setStreamToolCalls([]) } }, [input, loading, api, t, handleClear]) @@ -299,8 +406,8 @@ export default function Studio({ api }) { {messages.map(msg => ( ))} - {(streaming || streamThinking || loading) && ( - + {(streaming || streamThinking || loading || streamToolCalls.length > 0) && ( + )}
diff --git a/web/src/styles/global.css b/web/src/styles/global.css index c0aba4d..9834e41 100644 --- a/web/src/styles/global.css +++ b/web/src/styles/global.css @@ -678,3 +678,91 @@ input::placeholder { color: var(--text-disabled); } .studio-send-btn:hover:not(:disabled) { background: var(--accent-bright); border-color: var(--accent-bright); } .studio-send-btn:disabled { opacity: 0.3; cursor: not-allowed; } .studio-input-hint { font-size: 11px; color: var(--text-disabled); text-align: center; margin-top: 6px; } + +/* ── Studio Tool Blocks ── */ +.studio-tool-block { + background: var(--bg-surface); + border: 1px solid var(--border); + border-left: 3px solid var(--accent-dim); + border-radius: var(--radius); + margin: 6px 0; + overflow: hidden; + transition: all 0.3s ease; +} +.studio-tool-block.running { + border-left-color: var(--warning); +} +.studio-tool-block.error { + border-left-color: var(--error); + background: rgba(255, 23, 68, 0.05); +} +.studio-tool-header { + display: flex; + align-items: center; + gap: 8px; + padding: 6px 10px; + background: var(--bg-card); + border-bottom: 1px solid var(--border); + font-size: 12px; +} +.studio-tool-icon { + font-size: 14px; + flex-shrink: 0; +} +.studio-tool-name { + color: var(--text-tertiary); + font-weight: 600; + font-family: var(--font-mono); + font-size: 12px; + flex: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} +.studio-tool-spinner { + display: inline-flex; + gap: 2px; + margin-left: 4px; +} +.studio-tool-spinner span { + width: 4px; + height: 4px; + border-radius: 50%; + background: var(--warning); + animation: bounce 1.2s ease-in-out infinite; +} +.studio-tool-spinner span:nth-child(2) { animation-delay: 0.15s; } +.studio-tool-spinner span:nth-child(3) { animation-delay: 0.3s; } +.studio-tool-status { + font-weight: 700; + font-size: 14px; + flex-shrink: 0; +} +.studio-tool-status.ok { color: var(--success); } +.studio-tool-status.error { color: var(--error); } +.studio-tool-args { + padding: 6px 10px; + font-size: 12px; + font-family: var(--font-mono); + color: var(--text-tertiary); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + border-bottom: 1px solid var(--border); + background: var(--bg-elevated); +} +.studio-tool-result { + max-height: 200px; + overflow-y: auto; +} +.studio-tool-result pre { + padding: 8px 10px; + font-family: var(--font-mono); + font-size: 12px; + line-height: 1.5; + color: var(--text-secondary); + margin: 0; + white-space: pre-wrap; + word-break: break-word; + background: var(--bg); +}