package api

import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"github.com/gorilla/websocket"
	"github.com/muyue/muyue/internal/config"
)

// upgrader upgrades terminal HTTP requests to WebSocket connections. It
// accepts requests without an Origin header (non-browser clients) and
// requests whose Origin is an http(s) loopback origin; everything else is
// rejected so arbitrary web pages cannot open a terminal via the browser.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		origin := r.Header.Get("Origin")
		if origin == "" {
			return true
		}
		return isLoopbackOrigin(origin)
	},
}

// isLoopbackOrigin reports whether origin is an http or https origin whose
// host is exactly 127.0.0.1, localhost or ::1 (any port). The host is
// compared exactly rather than by string prefix: a prefix check would also
// accept attacker-controlled origins such as "http://localhost.evil.example".
func isLoopbackOrigin(origin string) bool {
	u, err := url.Parse(origin)
	if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
		return false
	}
	switch u.Hostname() {
	case "127.0.0.1", "localhost", "::1":
		return true
	default:
		return false
	}
}

// wsMessage is the JSON envelope exchanged over the terminal WebSocket.
//
// Client to server: the first message selects the session (Type "ssh" with a
// JSON SSH config in Data, otherwise an optional shell in Data), followed by
// "input" messages (Data is written to the terminal) and "resize" messages
// (Rows/Cols). Server to client: "output" (terminal bytes in Data) and
// "error" (a message in Data).
type wsMessage struct {
	Type string `json:"type"`
	Data string `json:"data"`
	Rows uint16 `json:"rows,omitempty"`
	Cols uint16 `json:"cols,omitempty"`
}

// handleTerminalWS serves an interactive terminal over a WebSocket.
//
// The first client message chooses what to run: Type "ssh" with a non-empty
// Data starts an ssh client (see terminalSSHCommand); anything else starts a
// local shell (see terminalShellCommand). Afterwards "input" messages are
// written to the terminal and "resize" messages resize it, while terminal
// output is streamed back as "output" messages. When the terminal process
// ends the WebSocket is closed so the client learns the session is over.
func (s *Server) handleTerminalWS(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		return
	}
	defer conn.Close()

	// Apply the frame size limit before reading anything, including the
	// init message, so a single huge frame cannot exhaust memory.
	conn.SetReadLimit(1 << 20)
	conn.SetReadDeadline(time.Time{})

	_, raw, err := conn.ReadMessage()
	if err != nil {
		conn.WriteJSON(wsMessage{Type: "error", Data: "failed to read init message"})
		return
	}
	var initMsg wsMessage
	if err := json.Unmarshal(raw, &initMsg); err != nil {
		conn.WriteJSON(wsMessage{Type: "error", Data: "invalid init message"})
		return
	}

	var cmd *exec.Cmd
	if initMsg.Type == "ssh" && initMsg.Data != "" {
		cmd, err = terminalSSHCommand(initMsg.Data)
	} else {
		cmd, err = terminalShellCommand(initMsg.Data)
	}
	if err != nil {
		conn.WriteJSON(wsMessage{Type: "error", Data: err.Error()})
		return
	}
	if cmd.Env == nil {
		cmd.Env = os.Environ()
	}
	cmd.Env = append(cmd.Env, "TERM=xterm-256color")

	session, err := startTermSession(cmd)
	if err != nil {
		conn.WriteJSON(wsMessage{Type: "error", Data: err.Error()})
		return
	}

	// cleanup may be reached from both goroutines; sync.Once makes it run once.
	var once sync.Once
	cleanup := func() {
		once.Do(func() {
			session.Close()
			session.Wait()
		})
	}
	defer cleanup()

	// Stream terminal output to the client. This goroutine is the only
	// writer of data frames; the loop below is the only reader.
	go func() {
		defer func() {
			cleanup()
			// Tell the client the session ended and unblock the read loop
			// below, which would otherwise wait until the client sends input
			// or disconnects. WriteControl and Close may be called
			// concurrently with ReadMessage. The deadline is computed here,
			// at exit time, not when the goroutine starts.
			_ = conn.WriteControl(websocket.CloseMessage,
				websocket.FormatCloseMessage(websocket.CloseNormalClosure, "terminal session ended"),
				time.Now().Add(time.Second))
			conn.Close()
		}()

		buf := make([]byte, 4096)
		// pending holds the leading bytes of a UTF-8 sequence that was split
		// across reads. Sending them early would make JSON encoding replace
		// them with U+FFFD and corrupt the character. pending never aliases
		// buf, so it survives the next Read.
		var pending []byte
		for {
			n, err := session.Read(buf)
			if n > 0 {
				data := append(pending, buf[:n]...)
				cut := utf8CompleteLen(data)
				pending = data[cut:]
				if cut > 0 {
					if werr := conn.WriteJSON(wsMessage{
						Type: "output",
						Data: string(data[:cut]),
					}); werr != nil {
						return
					}
				}
			}
			if err != nil {
				if len(pending) > 0 {
					// Best effort: flush whatever is left before closing.
					_ = conn.WriteJSON(wsMessage{Type: "output", Data: string(pending)})
				}
				return
			}
		}
	}()

	for {
		_, raw, err := conn.ReadMessage()
		if err != nil {
			return
		}
		var msg wsMessage
		if err := json.Unmarshal(raw, &msg); err != nil {
			// Ignore malformed frames instead of dropping the session.
			continue
		}
		switch msg.Type {
		case "input":
			if _, err := session.Write([]byte(msg.Data)); err != nil {
				return
			}
		case "resize":
			if msg.Rows > 0 && msg.Cols > 0 {
				session.Resize(msg.Rows, msg.Cols)
			}
		}
	}
}

// utf8CompleteLen returns the length of the longest prefix of b that does not
// end in the middle of a multi-byte UTF-8 sequence. Invalid bytes count as
// complete, so they are never held back.
func utf8CompleteLen(b []byte) int {
	// A truncated sequence can only start within the last UTFMax-1 bytes.
	for i := len(b) - 1; i >= 0 && i >= len(b)-(utf8.UTFMax-1); i-- {
		if utf8.RuneStart(b[i]) {
			if utf8.FullRune(b[i:]) {
				return len(b)
			}
			return i
		}
	}
	return len(b)
}

// terminalSSHCommand builds the ssh client command for an "ssh" init message.
// data is a JSON object with host, port (default 22), user, key_path and
// password fields. When a password is given and sshpass is on PATH, ssh runs
// under `sshpass -e` with the password in the SSHPASS environment variable.
// Without sshpass, ssh is started plainly and will prompt in the terminal.
func terminalSSHCommand(data string) (*exec.Cmd, error) {
	var conf struct {
		Host     string `json:"host"`
		Port     int    `json:"port"`
		User     string `json:"user"`
		KeyPath  string `json:"key_path"`
		Password string `json:"password"`
	}
	if err := json.Unmarshal([]byte(data), &conf); err != nil {
		return nil, errors.New("invalid ssh config")
	}
	if conf.Port == 0 {
		conf.Port = 22
	}
	if conf.Host == "" {
		return nil, errors.New("invalid ssh config: host required")
	}
	// The destination is the last argv entry; if it starts with "-", ssh
	// parses it as an option (e.g. -oProxyCommand=...), so reject it.
	if strings.HasPrefix(conf.Host, "-") || strings.HasPrefix(conf.User, "-") {
		return nil, errors.New("invalid ssh config: host and user must not start with '-'")
	}
	if conf.Port < 1 || conf.Port > 65535 {
		return nil, errors.New("invalid ssh config: port must be between 1 and 65535")
	}

	args := []string{
		// NOTE(review): with UserKnownHostsFile=/dev/null every host is
		// "new", so accept-new accepts any host key. Host-key verification
		// is therefore effectively disabled (no MITM protection). This is
		// kept as-is for compatibility; consider a persistent known_hosts file.
		"-o", "StrictHostKeyChecking=accept-new",
		"-o", "UserKnownHostsFile=/dev/null",
		"-o", "LogLevel=ERROR",
	}
	if conf.KeyPath != "" {
		args = append(args, "-i", conf.KeyPath)
	}
	if conf.Port != 22 {
		args = append(args, "-p", strconv.Itoa(conf.Port))
	}
	// With no user, let ssh choose its default; "@host" is rejected by ssh.
	dest := conf.Host
	if conf.User != "" {
		dest = conf.User + "@" + conf.Host
	}
	args = append(args, dest)

	if conf.Password != "" {
		if sshpassPath, err := exec.LookPath("sshpass"); err == nil {
			cmd := exec.Command(sshpassPath, append([]string{"-e", "ssh"}, args...)...)
			cmd.Env = append(os.Environ(), "SSHPASS="+conf.Password)
			return cmd, nil
		}
		// sshpass is unavailable: fall through and let ssh prompt for the
		// password interactively in the terminal.
	}
	return exec.Command("ssh", args...), nil
}

// terminalShellCommand builds the command for a local shell session. data is
// the requested shell: a name on PATH, a path, or a "wsl -d <distro>
// [-u <user>]" string from the UI quick-access menu. Empty means the shell
// chosen by detectShell.
func terminalShellCommand(data string) (*exec.Cmd, error) {
	shell := strings.TrimSpace(data)
	if shell == "" {
		// detectShell always returns a non-empty path (falls back to /bin/sh).
		shell = detectShell()
	}

	if wslArgs, ok := parseWSLShell(shell); ok {
		wslPath, err := exec.LookPath("wsl")
		if err != nil {
			return nil, errors.New("wsl not found on this host")
		}
		return exec.Command(wslPath, wslArgs...), nil
	}

	if path, err := exec.LookPath(shell); err == nil {
		shell = path
	}
	if _, err := os.Stat(shell); err != nil {
		return nil, fmt.Errorf("shell not found: %s (resolved from: %q)", shell, data)
	}

	// On Windows LookPath yields e.g. `C:\...\powershell.exe`, so strip the
	// extension (case-insensitively) before matching the shell name.
	name := strings.TrimSuffix(strings.ToLower(filepath.Base(shell)), ".exe")
	switch name {
	case "wsl":
		return exec.Command(shell, "--shell-type", "login"), nil
	case "powershell", "pwsh":
		return exec.Command(shell, "-NoLogo", "-NoProfile"), nil
	case "fish":
		return exec.Command(shell, "--login"), nil
	default:
		return exec.Command(shell), nil
	}
}

// handleTerminalSessions lists (GET) or creates/updates (POST) the saved SSH
// connections in the terminal config.
//
// GET responds with {"ssh": [...], "system": [...]}; stored passwords are
// masked as "***". POST upserts a connection by name. Posting a password of
// "***" for an existing entry keeps the stored password, so a masked entry
// from GET can be sent back unchanged. Changes are persisted via config.Save.
func (s *Server) handleTerminalSessions(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		masked := make([]config.SSHConnection, len(s.config.Terminal.SSH))
		for i, c := range s.config.Terminal.SSH {
			masked[i] = c
			if masked[i].Password != "" {
				masked[i].Password = "***"
			}
		}
		writeJSON(w, map[string]interface{}{
			"ssh":    masked,
			"system": detectSystemTerminals(),
		})
		return
	}
	if r.Method != http.MethodPost {
		writeError(w, "POST only", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		Name     string `json:"name"`
		Host     string `json:"host"`
		Port     int    `json:"port"`
		User     string `json:"user"`
		KeyPath  string `json:"key_path"`
		Password string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		writeError(w, err.Error(), http.StatusBadRequest)
		return
	}
	if body.Name == "" || body.Host == "" {
		writeError(w, "name and host required", http.StatusBadRequest)
		return
	}
	if body.Port == 0 {
		body.Port = 22
	}
	if body.Port < 1 || body.Port > 65535 {
		writeError(w, "port must be between 1 and 65535", http.StatusBadRequest)
		return
	}

	entry := config.SSHConnection{
		Name:     body.Name,
		Host:     body.Host,
		Port:     body.Port,
		User:     body.User,
		KeyPath:  body.KeyPath,
		Password: body.Password,
	}

	existing := -1
	for i, c := range s.config.Terminal.SSH {
		if c.Name == body.Name {
			existing = i
			break
		}
	}
	if existing >= 0 {
		if entry.Password == "***" {
			entry.Password = s.config.Terminal.SSH[existing].Password
		}
		s.config.Terminal.SSH[existing] = entry
	} else {
		s.config.Terminal.SSH = append(s.config.Terminal.SSH, entry)
	}

	if err := config.Save(s.config); err != nil {
		writeError(w, err.Error(), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "ok"})
}

// handleTerminalSessionsDelete removes the saved SSH connection whose name is
// the remainder of the URL path after "/api/terminal/sessions/", then
// persists the config. It responds 404 when no connection has that name.
func (s *Server) handleTerminalSessionsDelete(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodDelete {
		writeError(w, "DELETE only", http.StatusMethodNotAllowed)
		return
	}
	name := strings.TrimPrefix(r.URL.Path, "/api/terminal/sessions/")
	if name == "" {
		writeError(w, "name required", http.StatusBadRequest)
		return
	}

	conns := s.config.Terminal.SSH
	idx := -1
	for i := range conns {
		if conns[i].Name == name {
			idx = i
			break
		}
	}
	if idx < 0 {
		writeError(w, "not found", http.StatusNotFound)
		return
	}
	s.config.Terminal.SSH = append(conns[:idx], conns[idx+1:]...)
	// Clear the now-unused last slot of the backing array so a stale copy
	// of a connection (possibly holding a password) is not retained.
	conns[len(conns)-1] = config.SSHConnection{}

	if err := config.Save(s.config); err != nil {
		writeError(w, err.Error(), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "ok"})
}

// detectShell returns the path of the first of zsh, bash, fish, pwsh and
// powershell found on PATH, or "/bin/sh" when none is found. It never
// returns an empty string.
func detectShell() string {
	shells := []string{"zsh", "bash", "fish", "pwsh", "powershell"}
	for _, s := range shells {
		if path, err := exec.LookPath(s); err == nil {
			return path
		}
	}
	return "/bin/sh"
}

// listWSLDistros returns the list of installed WSL distribution names.
// Windows hosts only — returns nil on other platforms or if WSL is unavailable.
func listWSLDistros() []string {
	if runtime.GOOS != "windows" {
		return nil
	}
	out, err := exec.Command("wsl", "--list", "--quiet").Output()
	if err != nil {
		return nil
	}
	// `wsl --list --quiet` outputs UTF-16LE on Windows. Strip BOM and decode best-effort.
	raw := stripUTF16ToASCII(out)
	var distros []string
	seen := make(map[string]bool)
	for _, line := range strings.Split(raw, "\n") {
		// Drop surrounding whitespace and any default-marker "*" first, so
		// de-duplication compares the normalized names.
		name := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(line), "*"))
		if name == "" || seen[name] || !validWSLName.MatchString(name) {
			continue
		}
		seen[name] = true
		distros = append(distros, name)
	}
	return distros
}

// validWSLName restricts WSL distribution and user names to a safe character
// set before they are passed as arguments to wsl.
var validWSLName = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)

// parseWSLShell recognises strings of the form "wsl -d <distro>" (optionally
// with "-u <user>"; long forms --distribution/--user are also accepted), as
// emitted by the Shell tab quick-access menu. It returns the arguments to pass
// to the wsl binary, or ok=false for anything else, including names that fail
// validWSLName.
func parseWSLShell(shell string) ([]string, bool) {
	parts := strings.Fields(shell)
	if len(parts) < 3 || parts[0] != "wsl" {
		return nil, false
	}
	args := []string{}
	i := 1
	for i < len(parts) {
		switch parts[i] {
		case "-d", "--distribution":
			if i+1 >= len(parts) || !validWSLName.MatchString(parts[i+1]) {
				return nil, false
			}
			args = append(args, "-d", parts[i+1])
			i += 2
		case "-u", "--user":
			if i+1 >= len(parts) || !validWSLName.MatchString(parts[i+1]) {
				return nil, false
			}
			args = append(args, "-u", parts[i+1])
			i += 2
		default:
			return nil, false
		}
	}
	if len(args) == 0 {
		return nil, false
	}
	return args, true
}

// stripUTF16ToASCII converts wsl's UTF-16LE output to a string on a
// best-effort basis: it drops NUL bytes (the high bytes of ASCII code units)
// and keeps only printable ASCII plus '\n', '\r' and '\t'. Any non-ASCII
// character is lost.
func stripUTF16ToASCII(b []byte) string {
	var out []byte
	for i := 0; i < len(b); i++ {
		c := b[i]
		if c == 0 {
			continue
		}
		if c >= 32 && c < 127 || c == '\n' || c == '\r' || c == '\t' {
			out = append(out, c)
		}
	}
	return string(out)
}

// detectSystemTerminals lists the local terminals offered in the UI. Each
// entry has "type" ("local"), a display "name" and the "shell" string to send
// as the init message. It always includes the default shell. On Windows it
// also adds WSL (plus one entry per installed distro), PowerShell,
// PowerShell Core and Command Prompt when found on PATH.
func detectSystemTerminals() []map[string]string {
	var terminals []map[string]string
	terminals = append(terminals, map[string]string{
		"type":  "local",
		"name":  "Default Shell",
		"shell": detectShell(),
	})
	if runtime.GOOS == "windows" {
		if _, err := exec.LookPath("wsl"); err == nil {
			terminals = append(terminals, map[string]string{
				"type":  "local",
				"name":  "WSL (default)",
				"shell": "wsl",
			})
			for _, distro := range listWSLDistros() {
				terminals = append(terminals, map[string]string{
					"type":  "local",
					"name":  "WSL: " + distro,
					"shell": "wsl -d " + distro,
				})
			}
		}
		if _, err := exec.LookPath("powershell"); err == nil {
			terminals = append(terminals, map[string]string{
				"type":  "local",
				"name":  "PowerShell",
				"shell": "powershell",
			})
		}
		if _, err := exec.LookPath("pwsh"); err == nil {
			terminals = append(terminals, map[string]string{
				"type":  "local",
				"name":  "PowerShell Core",
				"shell": "pwsh",
			})
		}
		if _, err := exec.LookPath("cmd"); err == nil {
			terminals = append(terminals, map[string]string{
				"type":  "local",
				"name":  "Command Prompt",
				"shell": "cmd",
			})
		}
	}
	return terminals
}