feat(agent): refactor AI chat with streaming, agent registry, and tool execution
- Replace old tool-call regex with proper agent registry - Add streaming chat via SSE (handleStreamChat / handleNonStreamChat) - Add internal/agent package with tool definitions and execution - Add orchestrator with system prompt and tool scaffolding - Add internal/agent/ directory - Studio.jsx: streaming chat with thinking indicator and tool result rendering - global.css: chat bubble styles, streaming animation, thinking dots - handlers_chat.go: full rewrite using new agent/orchestrator architecture 💘 Generated with Crush Assisted-by: MiniMax-M2.7 via Crush <crush@charm.land>
This commit is contained in:
579
internal/agent/impl.go
Normal file
579
internal/agent/impl.go
Normal file
@@ -0,0 +1,579 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func detectShell() string {
|
||||
shells := []string{"zsh", "bash", "fish", "pwsh", "powershell"}
|
||||
for _, s := range shells {
|
||||
if path, err := exec.LookPath(s); err == nil {
|
||||
return path
|
||||
}
|
||||
}
|
||||
return "/bin/sh"
|
||||
}
|
||||
|
||||
func expandHome(path string) string {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if path == "~" {
|
||||
home, _ := os.UserHomeDir()
|
||||
return home
|
||||
}
|
||||
if strings.HasPrefix(path, "~/") {
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, path[2:])
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func osUserHomeDir() (string, error) {
|
||||
return os.UserHomeDir()
|
||||
}
|
||||
|
||||
func readFileLimited(path string, offset, limit int) (string, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if offset > len(lines) {
|
||||
offset = len(lines)
|
||||
}
|
||||
|
||||
end := offset + limit
|
||||
if limit <= 0 || limit > 2000 {
|
||||
limit = 2000
|
||||
}
|
||||
if end > len(lines) {
|
||||
end = len(lines)
|
||||
}
|
||||
if end-offset > limit {
|
||||
end = offset + limit
|
||||
}
|
||||
|
||||
selected := lines[offset:end]
|
||||
|
||||
var buf strings.Builder
|
||||
for i, line := range selected {
|
||||
fmt.Fprintf(&buf, "%6d\t%s\n", offset+i+1, line)
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func listDirTree(dir string, maxDepth, currentDepth int) (string, error) {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return dir + "\n", nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
indent := strings.Repeat(" ", currentDepth)
|
||||
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasPrefix(name, ".") && name != "." && name != ".." {
|
||||
continue
|
||||
}
|
||||
|
||||
if entry.IsDir() {
|
||||
fmt.Fprintf(&buf, "%s%s/\n", indent, name)
|
||||
if currentDepth < maxDepth {
|
||||
sub, err := listDirTree(filepath.Join(dir, name), maxDepth, currentDepth+1)
|
||||
if err == nil {
|
||||
buf.WriteString(sub)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(&buf, "%s%s\n", indent, name)
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func grepFiles(dir, pattern, include string) (string, error) {
|
||||
if include != "" {
|
||||
matches, err := filepath.Glob(filepath.Join(dir, include))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
var buf strings.Builder
|
||||
for _, match := range matches {
|
||||
result, err := grepInFile(match, pattern)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
buf.WriteString(result)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
return grepInDir(dir, pattern, 0)
|
||||
}
|
||||
|
||||
func grepInDir(dir, pattern string, depth int) (string, error) {
|
||||
if depth > 10 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasPrefix(name, ".") {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, name)
|
||||
|
||||
if entry.IsDir() {
|
||||
sub, err := grepInDir(path, pattern, depth+1)
|
||||
if err == nil {
|
||||
buf.WriteString(sub)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
result, err := grepInFile(path, pattern)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
buf.WriteString(result)
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func grepInFile(path, pattern string) (string, error) {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
re, err = regexp.Compile(regexp.QuoteMeta(pattern))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var buf strings.Builder
|
||||
scanner := bufio.NewScanner(file)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
|
||||
lineNum := 0
|
||||
matchCount := 0
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
if re.MatchString(scanner.Text()) {
|
||||
fmt.Fprintf(&buf, "%s:%d: %s\n", path, lineNum, scanner.Text())
|
||||
matchCount++
|
||||
if matchCount >= 50 {
|
||||
buf.WriteString("... [truncated, more matches exist]\n")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func getConfigSection(section string) ToolResponse {
|
||||
configPath, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("cannot find config dir: %v", err))
|
||||
}
|
||||
configPath = filepath.Join(configPath, "muyue", "config.yaml")
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("cannot read config: %v", err))
|
||||
}
|
||||
|
||||
switch section {
|
||||
case "providers", "profile", "tools", "terminal":
|
||||
sectionData := extractYAMLSection(data, section)
|
||||
if sectionData == "" {
|
||||
return TextResponse(fmt.Sprintf("Section '%s' not found in config.", section))
|
||||
}
|
||||
return TextResponse(sectionData)
|
||||
default:
|
||||
content := string(data)
|
||||
if len(content) > 8000 {
|
||||
content = content[:8000] + "\n... [truncated]"
|
||||
}
|
||||
return TextResponse(content)
|
||||
}
|
||||
}
|
||||
|
||||
func extractYAMLSection(data []byte, section string) string {
|
||||
lines := strings.Split(string(data), "\n")
|
||||
inSection := false
|
||||
indentLevel := 0
|
||||
var buf strings.Builder
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
if inSection {
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !inSection {
|
||||
if strings.HasPrefix(trimmed, section+":") || strings.HasPrefix(trimmed, section+" ") {
|
||||
inSection = true
|
||||
indentLevel = len(line) - len(strings.TrimLeft(line, " "))
|
||||
buf.WriteString(line)
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
currentIndent := len(line) - len(strings.TrimLeft(line, " "))
|
||||
if currentIndent <= indentLevel && trimmed != "" {
|
||||
break
|
||||
}
|
||||
buf.WriteString(line)
|
||||
buf.WriteString("\n")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(buf.String())
|
||||
}
|
||||
|
||||
func setProviderConfig(p SetProviderParams) ToolResponse {
|
||||
configPath, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("cannot find config dir: %v", err))
|
||||
}
|
||||
configPath = filepath.Join(configPath, "muyue", "config.yaml")
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("cannot read config: %v", err))
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
inProviders := false
|
||||
providerIndent := 0
|
||||
foundProvider := false
|
||||
insertIdx := -1
|
||||
lastProviderEnd := -1
|
||||
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if !inProviders {
|
||||
if strings.HasPrefix(trimmed, "providers:") {
|
||||
inProviders = true
|
||||
providerIndent = len(line) - len(strings.TrimLeft(line, " "))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
currentIndent := len(line) - len(strings.TrimLeft(line, " "))
|
||||
if currentIndent <= providerIndent && trimmed != "" && !strings.HasPrefix(trimmed, "#") {
|
||||
lastProviderEnd = i
|
||||
break
|
||||
}
|
||||
|
||||
if currentIndent == providerIndent+2 && strings.HasPrefix(trimmed, "- name:") {
|
||||
nameMatch := strings.TrimPrefix(trimmed, "- name:")
|
||||
nameMatch = strings.TrimSpace(nameMatch)
|
||||
if nameMatch == p.Name {
|
||||
foundProvider = true
|
||||
insertIdx = i
|
||||
}
|
||||
if insertIdx == -1 || insertIdx < i {
|
||||
insertIdx = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastProviderEnd == -1 {
|
||||
lastProviderEnd = len(lines)
|
||||
}
|
||||
|
||||
entryIndent := strings.Repeat(" ", providerIndent+4)
|
||||
|
||||
var newEntry strings.Builder
|
||||
newEntry.WriteString(fmt.Sprintf(" - name: %s\n", p.Name))
|
||||
if p.Model != "" {
|
||||
newEntry.WriteString(fmt.Sprintf("%smodel: %s\n", entryIndent, p.Model))
|
||||
}
|
||||
if p.BaseURL != "" {
|
||||
newEntry.WriteString(fmt.Sprintf("%sbase_url: %s\n", entryIndent, p.BaseURL))
|
||||
}
|
||||
if p.APIKey != "" {
|
||||
newEntry.WriteString(fmt.Sprintf("%sapi_key: %s\n", entryIndent, p.APIKey))
|
||||
}
|
||||
if p.Active != nil {
|
||||
newEntry.WriteString(fmt.Sprintf("%sactive: %v\n", entryIndent, *p.Active))
|
||||
}
|
||||
|
||||
if foundProvider && insertIdx >= 0 {
|
||||
var endIdx int
|
||||
for endIdx = insertIdx + 1; endIdx < len(lines); endIdx++ {
|
||||
li := len(lines[endIdx]) - len(strings.TrimLeft(lines[endIdx], " "))
|
||||
if li <= providerIndent+2 || lines[endIdx] == "" {
|
||||
if endIdx > insertIdx+1 && strings.TrimSpace(lines[endIdx]) == "" {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
newLines := make([]string, 0, len(lines))
|
||||
newLines = append(newLines, lines[:insertIdx]...)
|
||||
newLines = append(newLines, strings.TrimSuffix(newEntry.String(), "\n"))
|
||||
newLines = append(newLines, lines[endIdx:]...)
|
||||
lines = newLines
|
||||
} else {
|
||||
insertAt := lastProviderEnd
|
||||
newLines := make([]string, 0, len(lines)+10)
|
||||
newLines = append(newLines, lines[:insertAt]...)
|
||||
newLines = append(newLines, strings.TrimSuffix(newEntry.String(), "\n"))
|
||||
newLines = append(newLines, lines[insertAt:]...)
|
||||
lines = newLines
|
||||
}
|
||||
|
||||
content := strings.Join(lines, "\n")
|
||||
if err := os.WriteFile(configPath, []byte(content), 0600); err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("write config error: %v", err))
|
||||
}
|
||||
|
||||
return TextResponse(fmt.Sprintf("Provider '%s' configured successfully.", p.Name))
|
||||
}
|
||||
|
||||
func manageSSHAction(p ManageSSHParams) ToolResponse {
|
||||
configPath, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("cannot find config dir: %v", err))
|
||||
}
|
||||
configPath = filepath.Join(configPath, "muyue", "config.yaml")
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("cannot read config: %v", err))
|
||||
}
|
||||
|
||||
switch p.Action {
|
||||
case "list":
|
||||
sshSection := extractYAMLSection(data, "ssh")
|
||||
if sshSection == "" {
|
||||
return TextResponse("No SSH connections configured.")
|
||||
}
|
||||
return TextResponse(sshSection)
|
||||
|
||||
case "add":
|
||||
if p.Name == "" || p.Host == "" || p.User == "" {
|
||||
return TextErrorResponse("name, host, and user are required for add action")
|
||||
}
|
||||
if p.Port == 0 {
|
||||
p.Port = 22
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
sshIdx := -1
|
||||
sshIndent := 0
|
||||
lastSSHEnd := -1
|
||||
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if sshIdx == -1 && strings.HasPrefix(trimmed, "ssh:") {
|
||||
sshIdx = i
|
||||
sshIndent = len(line) - len(strings.TrimLeft(line, " "))
|
||||
continue
|
||||
}
|
||||
if sshIdx != -1 {
|
||||
li := len(line) - len(strings.TrimLeft(line, " "))
|
||||
if li <= sshIndent && trimmed != "" {
|
||||
lastSSHEnd = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastSSHEnd == -1 {
|
||||
lastSSHEnd = len(lines)
|
||||
}
|
||||
|
||||
entry := fmt.Sprintf(" - name: %s\n host: %s\n port: %d\n user: %s", p.Name, p.Host, p.Port, p.User)
|
||||
if p.KeyPath != "" {
|
||||
entry += fmt.Sprintf("\n key_path: %s", p.KeyPath)
|
||||
}
|
||||
|
||||
newLines := make([]string, 0, len(lines)+10)
|
||||
newLines = append(newLines, lines[:lastSSHEnd]...)
|
||||
newLines = append(newLines, entry)
|
||||
newLines = append(newLines, lines[lastSSHEnd:]...)
|
||||
|
||||
if err := os.WriteFile(configPath, []byte(strings.Join(newLines, "\n")), 0600); err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("write config error: %v", err))
|
||||
}
|
||||
return TextResponse(fmt.Sprintf("SSH connection '%s' (%s@%s:%d) added.", p.Name, p.User, p.Host, p.Port))
|
||||
|
||||
case "remove":
|
||||
if p.Name == "" {
|
||||
return TextErrorResponse("name is required for remove action")
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
newLines := make([]string, 0, len(lines))
|
||||
skipping := false
|
||||
removed := false
|
||||
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.Contains(trimmed, "name: "+p.Name) && strings.HasPrefix(trimmed, "-") {
|
||||
skipping = true
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
if skipping {
|
||||
li := len(line) - len(strings.TrimLeft(line, " "))
|
||||
if li > 6 && i < len(lines)-1 && strings.TrimSpace(lines[i+1]) != "" {
|
||||
continue
|
||||
}
|
||||
skipping = false
|
||||
continue
|
||||
}
|
||||
newLines = append(newLines, line)
|
||||
}
|
||||
|
||||
if !removed {
|
||||
return TextErrorResponse(fmt.Sprintf("SSH connection '%s' not found.", p.Name))
|
||||
}
|
||||
|
||||
if err := os.WriteFile(configPath, []byte(strings.Join(newLines, "\n")), 0600); err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("write config error: %v", err))
|
||||
}
|
||||
return TextResponse(fmt.Sprintf("SSH connection '%s' removed.", p.Name))
|
||||
|
||||
default:
|
||||
return TextErrorResponse("unknown action. Use 'list', 'add', or 'remove'")
|
||||
}
|
||||
}
|
||||
|
||||
func fetchURL(url string) ToolResponse {
|
||||
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
|
||||
return TextErrorResponse("only http/https URLs are supported")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("create request: %v", err))
|
||||
}
|
||||
req.Header.Set("User-Agent", "MuyueStudio/1.0")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("fetch error: %v", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 50000))
|
||||
if err != nil {
|
||||
return TextErrorResponse(fmt.Sprintf("read error: %v", err))
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return TextErrorResponse(fmt.Sprintf("HTTP %d: %s", resp.StatusCode, truncate(string(body), 2000)))
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "text/html") {
|
||||
text := stripHTML(string(body))
|
||||
if len(text) > 8000 {
|
||||
text = text[:8000] + "\n... [truncated]"
|
||||
}
|
||||
return TextResponse(text)
|
||||
}
|
||||
|
||||
result := string(body)
|
||||
if len(result) > 10000 {
|
||||
result = result[:10000] + "\n... [truncated]"
|
||||
}
|
||||
return TextResponse(result)
|
||||
}
|
||||
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
func stripHTML(html string) string {
|
||||
tagRe := regexp.MustCompile(`<[^>]*>`)
|
||||
text := tagRe.ReplaceAllString(html, " ")
|
||||
|
||||
entityRe := regexp.MustCompile(`&[a-zA-Z]+;`)
|
||||
text = entityRe.ReplaceAllStringFunc(text, func(s string) string {
|
||||
switch s {
|
||||
case "&":
|
||||
return "&"
|
||||
case "<":
|
||||
return "<"
|
||||
case ">":
|
||||
return ">"
|
||||
case """:
|
||||
return "\""
|
||||
case "'":
|
||||
return "'"
|
||||
case " ":
|
||||
return " "
|
||||
default:
|
||||
return " "
|
||||
}
|
||||
})
|
||||
|
||||
multiSpace := regexp.MustCompile(`\s+`)
|
||||
text = multiSpace.ReplaceAllString(text, " ")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
var _ = runtime.GOOS
|
||||
var _ = json.Marshal
|
||||
Reference in New Issue
Block a user