package agent import ( "bufio" "context" "encoding/json" "fmt" "io" "net/http" "os" "os/exec" "path/filepath" "regexp" "runtime" "strings" "time" ) func detectShell() string { shells := []string{"zsh", "bash", "fish", "pwsh", "powershell"} for _, s := range shells { if path, err := exec.LookPath(s); err == nil { return path } } return "/bin/sh" } var validIdentifier = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) // buildAgentCommand assembles an agent execution command, optionally launching it // inside a WSL distribution (Windows host only) and applying a working directory. // On non-Windows hosts, wsl_* parameters are ignored. func buildAgentCommand(ctx context.Context, bin string, args []string, cwd, wslDistro, wslUser string) (*exec.Cmd, error) { if wslDistro != "" && runtime.GOOS == "windows" { if !validIdentifier.MatchString(wslDistro) { return nil, fmt.Errorf("invalid wsl_distro: %q", wslDistro) } if wslUser != "" && !validIdentifier.MatchString(wslUser) { return nil, fmt.Errorf("invalid wsl_user: %q", wslUser) } wslArgs := []string{"-d", wslDistro} if wslUser != "" { wslArgs = append(wslArgs, "-u", wslUser) } if cwd != "" { wslArgs = append(wslArgs, "--cd", cwd) } wslArgs = append(wslArgs, "--") wslArgs = append(wslArgs, bin) wslArgs = append(wslArgs, args...) return exec.CommandContext(ctx, "wsl", wslArgs...), nil } cmd := exec.CommandContext(ctx, bin, args...) if cwd != "" { dir := expandHome(cwd) if info, err := os.Stat(dir); err != nil || !info.IsDir() { return nil, fmt.Errorf("cwd does not exist or is not a directory: %s", cwd) } cmd.Dir = dir } return cmd, nil } func expandHome(path string) string { if path == "" { return "" } if path == "~" { home, _ := os.UserHomeDir() return home } if strings.HasPrefix(path, "~/") { home, _ := os.UserHomeDir() return filepath.Join(home, path[2:]) } return path } func osUserHomeDir() (string, error) { return os.UserHomeDir() } func readFileLimited(path string, offset, limit int) (string, error) { data, err := os.ReadFile(path) if err != nil { return "", err } lines := strings.Split(string(data), "\n") if offset < 0 { offset = 0 } if offset > len(lines) { offset = len(lines) } end := offset + limit if limit <= 0 || limit > 2000 { limit = 2000 } if end > len(lines) { end = len(lines) } if end-offset > limit { end = offset + limit } selected := lines[offset:end] var buf strings.Builder for i, line := range selected { fmt.Fprintf(&buf, "%6d\t%s\n", offset+i+1, line) } return buf.String(), nil } func listDirTree(dir string, maxDepth, currentDepth int) (string, error) { info, err := os.Stat(dir) if err != nil { return "", err } if !info.IsDir() { return dir + "\n", nil } entries, err := os.ReadDir(dir) if err != nil { return "", err } var buf strings.Builder indent := strings.Repeat(" ", currentDepth) for _, entry := range entries { name := entry.Name() if strings.HasPrefix(name, ".") && name != "." && name != ".." { continue } if entry.IsDir() { fmt.Fprintf(&buf, "%s%s/\n", indent, name) if currentDepth < maxDepth { sub, err := listDirTree(filepath.Join(dir, name), maxDepth, currentDepth+1) if err == nil { buf.WriteString(sub) } } } else { fmt.Fprintf(&buf, "%s%s\n", indent, name) } } return buf.String(), nil } func grepFiles(dir, pattern, include string) (string, error) { if include != "" { matches, err := filepath.Glob(filepath.Join(dir, include)) if err != nil { return "", err } if len(matches) == 0 { return "", nil } var buf strings.Builder for _, match := range matches { result, err := grepInFile(match, pattern) if err != nil { continue } buf.WriteString(result) } return buf.String(), nil } return grepInDir(dir, pattern, 0) } func grepInDir(dir, pattern string, depth int) (string, error) { if depth > 10 { return "", nil } var buf strings.Builder entries, err := os.ReadDir(dir) if err != nil { return "", err } for _, entry := range entries { name := entry.Name() if strings.HasPrefix(name, ".") { continue } path := filepath.Join(dir, name) if entry.IsDir() { sub, err := grepInDir(path, pattern, depth+1) if err == nil { buf.WriteString(sub) } continue } result, err := grepInFile(path, pattern) if err != nil { continue } buf.WriteString(result) } return buf.String(), nil } func grepInFile(path, pattern string) (string, error) { re, err := regexp.Compile(pattern) if err != nil { re, err = regexp.Compile(regexp.QuoteMeta(pattern)) if err != nil { return "", err } } file, err := os.Open(path) if err != nil { return "", err } defer file.Close() var buf strings.Builder scanner := bufio.NewScanner(file) scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) lineNum := 0 matchCount := 0 for scanner.Scan() { lineNum++ if re.MatchString(scanner.Text()) { fmt.Fprintf(&buf, "%s:%d: %s\n", path, lineNum, scanner.Text()) matchCount++ if matchCount >= 50 { buf.WriteString("... [truncated, more matches exist]\n") break } } } return buf.String(), nil } func getConfigSection(section string) ToolResponse { configPath, err := os.UserConfigDir() if err != nil { return TextErrorResponse(fmt.Sprintf("cannot find config dir: %v", err)) } configPath = filepath.Join(configPath, "muyue", "config.yaml") data, err := os.ReadFile(configPath) if err != nil { return TextErrorResponse(fmt.Sprintf("cannot read config: %v", err)) } switch section { case "providers", "profile", "tools", "terminal": sectionData := extractYAMLSection(data, section) if sectionData == "" { return TextResponse(fmt.Sprintf("Section '%s' not found in config.", section)) } return TextResponse(sectionData) default: content := string(data) if len(content) > 8000 { content = content[:8000] + "\n... [truncated]" } return TextResponse(content) } } func extractYAMLSection(data []byte, section string) string { lines := strings.Split(string(data), "\n") inSection := false indentLevel := 0 var buf strings.Builder for _, line := range lines { trimmed := strings.TrimSpace(line) if trimmed == "" || strings.HasPrefix(trimmed, "#") { if inSection { buf.WriteString("\n") } continue } if !inSection { if strings.HasPrefix(trimmed, section+":") || strings.HasPrefix(trimmed, section+" ") { inSection = true indentLevel = len(line) - len(strings.TrimLeft(line, " ")) buf.WriteString(line) buf.WriteString("\n") } continue } currentIndent := len(line) - len(strings.TrimLeft(line, " ")) if currentIndent <= indentLevel && trimmed != "" { break } buf.WriteString(line) buf.WriteString("\n") } return strings.TrimSpace(buf.String()) } func setProviderConfig(p SetProviderParams) ToolResponse { configPath, err := os.UserConfigDir() if err != nil { return TextErrorResponse(fmt.Sprintf("cannot find config dir: %v", err)) } configPath = filepath.Join(configPath, "muyue", "config.yaml") data, err := os.ReadFile(configPath) if err != nil { return TextErrorResponse(fmt.Sprintf("cannot read config: %v", err)) } lines := strings.Split(string(data), "\n") inProviders := false providerIndent := 0 foundProvider := false insertIdx := -1 lastProviderEnd := -1 for i, line := range lines { trimmed := strings.TrimSpace(line) if !inProviders { if strings.HasPrefix(trimmed, "providers:") { inProviders = true providerIndent = len(line) - len(strings.TrimLeft(line, " ")) } continue } currentIndent := len(line) - len(strings.TrimLeft(line, " ")) if currentIndent <= providerIndent && trimmed != "" && !strings.HasPrefix(trimmed, "#") { lastProviderEnd = i break } if currentIndent == providerIndent+2 && strings.HasPrefix(trimmed, "- name:") { nameMatch := strings.TrimPrefix(trimmed, "- name:") nameMatch = strings.TrimSpace(nameMatch) if nameMatch == p.Name { foundProvider = true insertIdx = i } if insertIdx == -1 || insertIdx < i { insertIdx = i } } } if lastProviderEnd == -1 { lastProviderEnd = len(lines) } entryIndent := strings.Repeat(" ", providerIndent+4) var newEntry strings.Builder newEntry.WriteString(fmt.Sprintf(" - name: %s\n", p.Name)) if p.Model != "" { newEntry.WriteString(fmt.Sprintf("%smodel: %s\n", entryIndent, p.Model)) } if p.BaseURL != "" { newEntry.WriteString(fmt.Sprintf("%sbase_url: %s\n", entryIndent, p.BaseURL)) } if p.APIKey != "" { newEntry.WriteString(fmt.Sprintf("%sapi_key: %s\n", entryIndent, p.APIKey)) } if p.Active != nil { newEntry.WriteString(fmt.Sprintf("%sactive: %v\n", entryIndent, *p.Active)) } if foundProvider && insertIdx >= 0 { var endIdx int for endIdx = insertIdx + 1; endIdx < len(lines); endIdx++ { li := len(lines[endIdx]) - len(strings.TrimLeft(lines[endIdx], " ")) if li <= providerIndent+2 || lines[endIdx] == "" { if endIdx > insertIdx+1 && strings.TrimSpace(lines[endIdx]) == "" { continue } break } } newLines := make([]string, 0, len(lines)) newLines = append(newLines, lines[:insertIdx]...) newLines = append(newLines, strings.TrimSuffix(newEntry.String(), "\n")) newLines = append(newLines, lines[endIdx:]...) lines = newLines } else { insertAt := lastProviderEnd newLines := make([]string, 0, len(lines)+10) newLines = append(newLines, lines[:insertAt]...) newLines = append(newLines, strings.TrimSuffix(newEntry.String(), "\n")) newLines = append(newLines, lines[insertAt:]...) lines = newLines } content := strings.Join(lines, "\n") if err := os.WriteFile(configPath, []byte(content), 0600); err != nil { return TextErrorResponse(fmt.Sprintf("write config error: %v", err)) } return TextResponse(fmt.Sprintf("Provider '%s' configured successfully.", p.Name)) } func manageSSHAction(p ManageSSHParams) ToolResponse { configPath, err := os.UserConfigDir() if err != nil { return TextErrorResponse(fmt.Sprintf("cannot find config dir: %v", err)) } configPath = filepath.Join(configPath, "muyue", "config.yaml") data, err := os.ReadFile(configPath) if err != nil { return TextErrorResponse(fmt.Sprintf("cannot read config: %v", err)) } switch p.Action { case "list": sshSection := extractYAMLSection(data, "ssh") if sshSection == "" { return TextResponse("No SSH connections configured.") } return TextResponse(sshSection) case "add": if p.Name == "" || p.Host == "" || p.User == "" { return TextErrorResponse("name, host, and user are required for add action") } if p.Port == 0 { p.Port = 22 } lines := strings.Split(string(data), "\n") sshIdx := -1 sshIndent := 0 lastSSHEnd := -1 for i, line := range lines { trimmed := strings.TrimSpace(line) if sshIdx == -1 && strings.HasPrefix(trimmed, "ssh:") { sshIdx = i sshIndent = len(line) - len(strings.TrimLeft(line, " ")) continue } if sshIdx != -1 { li := len(line) - len(strings.TrimLeft(line, " ")) if li <= sshIndent && trimmed != "" { lastSSHEnd = i break } } } if lastSSHEnd == -1 { lastSSHEnd = len(lines) } entry := fmt.Sprintf(" - name: %s\n host: %s\n port: %d\n user: %s", p.Name, p.Host, p.Port, p.User) if p.KeyPath != "" { entry += fmt.Sprintf("\n key_path: %s", p.KeyPath) } newLines := make([]string, 0, len(lines)+10) newLines = append(newLines, lines[:lastSSHEnd]...) newLines = append(newLines, entry) newLines = append(newLines, lines[lastSSHEnd:]...) if err := os.WriteFile(configPath, []byte(strings.Join(newLines, "\n")), 0600); err != nil { return TextErrorResponse(fmt.Sprintf("write config error: %v", err)) } return TextResponse(fmt.Sprintf("SSH connection '%s' (%s@%s:%d) added.", p.Name, p.User, p.Host, p.Port)) case "remove": if p.Name == "" { return TextErrorResponse("name is required for remove action") } lines := strings.Split(string(data), "\n") newLines := make([]string, 0, len(lines)) skipping := false removed := false for i, line := range lines { trimmed := strings.TrimSpace(line) if strings.Contains(trimmed, "name: "+p.Name) && strings.HasPrefix(trimmed, "-") { skipping = true removed = true continue } if skipping { li := len(line) - len(strings.TrimLeft(line, " ")) if li > 6 && i < len(lines)-1 && strings.TrimSpace(lines[i+1]) != "" { continue } skipping = false continue } newLines = append(newLines, line) } if !removed { return TextErrorResponse(fmt.Sprintf("SSH connection '%s' not found.", p.Name)) } if err := os.WriteFile(configPath, []byte(strings.Join(newLines, "\n")), 0600); err != nil { return TextErrorResponse(fmt.Sprintf("write config error: %v", err)) } return TextResponse(fmt.Sprintf("SSH connection '%s' removed.", p.Name)) default: return TextErrorResponse("unknown action. Use 'list', 'add', or 'remove'") } } func fetchURL(url string) ToolResponse { if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { return TextErrorResponse("only http/https URLs are supported") } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return TextErrorResponse(fmt.Sprintf("create request: %v", err)) } req.Header.Set("User-Agent", "MuyueStudio/1.0") resp, err := http.DefaultClient.Do(req) if err != nil { return TextErrorResponse(fmt.Sprintf("fetch error: %v", err)) } defer resp.Body.Close() body, err := io.ReadAll(io.LimitReader(resp.Body, 50000)) if err != nil { return TextErrorResponse(fmt.Sprintf("read error: %v", err)) } if resp.StatusCode != http.StatusOK { return TextErrorResponse(fmt.Sprintf("HTTP %d: %s", resp.StatusCode, truncate(string(body), 2000))) } contentType := resp.Header.Get("Content-Type") if strings.Contains(contentType, "text/html") { text := stripHTML(string(body)) if len(text) > 8000 { text = text[:8000] + "\n... [truncated]" } return TextResponse(text) } result := string(body) if len(result) > 10000 { result = result[:10000] + "\n... [truncated]" } return TextResponse(result) } func truncate(s string, maxLen int) string { if len(s) <= maxLen { return s } return s[:maxLen] + "..." } func stripHTML(html string) string { tagRe := regexp.MustCompile(`<[^>]*>`) text := tagRe.ReplaceAllString(html, " ") entityRe := regexp.MustCompile(`&[a-zA-Z]+;`) text = entityRe.ReplaceAllStringFunc(text, func(s string) string { switch s { case "&": return "&" case "<": return "<" case ">": return ">" case """: return "\"" case "'": return "'" case " ": return " " default: return " " } }) multiSpace := regexp.MustCompile(`\s+`) text = multiSpace.ReplaceAllString(text, " ") return strings.TrimSpace(text) } var _ = runtime.GOOS var _ = json.Marshal