Replace message-count context windows with token-budget based ones for both studio and shell. Add /api/ai/task endpoint for background tool check/install/update. Enhance sudo blocking to catch piped/chained elevation commands. Add SSH password support via sshpass and connection editing UI. Remove realTokens persistence in favor of consumption tracking. Bump to 0.4.1. 💘 Generated with Crush Assisted-by: GLM-5.1 via Crush <crush@charm.land>
434 lines
12 KiB
Go
434 lines
12 KiB
Go
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
"github.com/muyue/muyue/internal/agent"
|
|
"github.com/muyue/muyue/internal/orchestrator"
|
|
)
|
|
|
|
// thinkingTagRegex matches a reasoning element: an opening tag that begins
// with "<think" or "<Think" (any attributes allowed) through the nearest
// closing "</think>" or "</Think>". The (?s) flag lets the body span lines.
var thinkingTagRegex = regexp.MustCompile(`(?s)<[Tt]hink[^>]*>.*?</[Tt]hink>`)
|
|
// fileMentionRegex matches an "@" followed by a run of non-whitespace
// characters that ends in ".<alphanumeric extension>". Capture group 1 is
// the text after the "@".
var fileMentionRegex = regexp.MustCompile(`@(\S+\.[a-zA-Z0-9]+)`)
|
|
|
|
// ImageAttachment is one image sent along with a chat message.
type ImageAttachment struct {
	// Data is the image payload; it is forwarded unchanged as the
	// "image_url" field of the VLM request and handed to saveImage.
	Data string `json:"data"`
	// Filename is the name shown in logs, prompts and the stored message.
	Filename string `json:"filename"`
	// MimeType is the declared media type, passed through to saveImage.
	MimeType string `json:"mime_type"`
}
|
|
|
|
func resolveFileMentions(text string) string {
|
|
return fileMentionRegex.ReplaceAllStringFunc(text, func(match string) string {
|
|
filePath := match[1:]
|
|
if strings.HasPrefix(filePath, "~/") {
|
|
if home, err := os.UserHomeDir(); err == nil {
|
|
filePath = filepath.Join(home, filePath[2:])
|
|
}
|
|
}
|
|
if !filepath.IsAbs(filePath) {
|
|
if home, err := os.UserHomeDir(); err == nil {
|
|
filePath = filepath.Join(home, filePath)
|
|
}
|
|
}
|
|
data, err := os.ReadFile(filePath)
|
|
if err != nil {
|
|
return match + fmt.Sprintf(" (erreur: fichier non trouve)")
|
|
}
|
|
content := string(data)
|
|
if len(content) > 50000 {
|
|
content = content[:50000] + "\n... (tronque a 50Ko)"
|
|
}
|
|
return fmt.Sprintf("[Fichier: %s]\n%s\n[Fin du fichier: %s]", filepath.Base(filePath), content, filepath.Base(filePath))
|
|
})
|
|
}
|
|
|
|
// vlmClient is the shared HTTP client used for image-description requests;
// every request is bounded by a 60-second overall timeout.
var vlmClient = &http.Client{Timeout: 60 * time.Second}
|
|
|
|
func (s *Server) describeImages(images []ImageAttachment) []string {
|
|
var apiKey string
|
|
for i := range s.config.AI.Providers {
|
|
if s.config.AI.Providers[i].Active {
|
|
apiKey = s.config.AI.Providers[i].APIKey
|
|
break
|
|
}
|
|
}
|
|
if apiKey == "" {
|
|
log.Printf("[vlm] no API key found for image description")
|
|
return nil
|
|
}
|
|
|
|
descriptions := make([]string, 0, len(images))
|
|
for i, img := range images {
|
|
desc, err := s.callVLM(apiKey, img)
|
|
if err != nil {
|
|
log.Printf("[vlm] image %d (%s) failed: %v", i+1, img.Filename, err)
|
|
descriptions = append(descriptions, fmt.Sprintf("(description unavailable: %v)", err))
|
|
} else {
|
|
descriptions = append(descriptions, desc)
|
|
}
|
|
}
|
|
return descriptions
|
|
}
|
|
|
|
func (s *Server) callVLM(apiKey string, img ImageAttachment) (string, error) {
|
|
payload := map[string]string{
|
|
"prompt": "Describe this image in detail. Include all text, UI elements, code, diagrams, or data visible. Be thorough and specific.",
|
|
"image_url": img.Data,
|
|
}
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return "", fmt.Errorf("marshal vlm request: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 55*time.Second)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.minimax.io/v1/coding_plan/vlm", bytes.NewReader(body))
|
|
if err != nil {
|
|
return "", fmt.Errorf("create vlm request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
|
|
resp, err := vlmClient.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("vlm request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read vlm response: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return "", fmt.Errorf("vlm API error (%d): %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var result struct {
|
|
Content string `json:"content"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
|
return "", fmt.Errorf("parse vlm response: %w", err)
|
|
}
|
|
|
|
if result.Content == "" {
|
|
return "(empty description)", nil
|
|
}
|
|
return result.Content, nil
|
|
}
|
|
|
|
func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
writeError(w, "POST only", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
var body struct {
|
|
Message string `json:"message"`
|
|
Stream bool `json:"stream"`
|
|
Images []ImageAttachment `json:"images"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
writeError(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
if body.Message == "" {
|
|
writeError(w, "no message", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
if len(body.Images) > 3 {
|
|
writeError(w, "max 3 images", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
enrichedMessage := resolveFileMentions(body.Message)
|
|
|
|
var imageIDs []string
|
|
if len(body.Images) > 0 {
|
|
descriptions := s.describeImages(body.Images)
|
|
var imgContext strings.Builder
|
|
for i, desc := range descriptions {
|
|
imgContext.WriteString(fmt.Sprintf("\n[Image %d (%s): %s]\n", i+1, body.Images[i].Filename, desc))
|
|
|
|
id, err := saveImage(body.Images[i].Data, body.Images[i].Filename, body.Images[i].MimeType)
|
|
if err != nil {
|
|
log.Printf("[images] failed to save %s: %v", body.Images[i].Filename, err)
|
|
} else {
|
|
imageIDs = append(imageIDs, id)
|
|
}
|
|
}
|
|
enrichedMessage = imgContext.String() + enrichedMessage
|
|
}
|
|
|
|
displayMsg := body.Message
|
|
if len(body.Images) > 0 {
|
|
imgNames := make([]string, len(body.Images))
|
|
for i, img := range body.Images {
|
|
imgNames[i] = img.Filename
|
|
}
|
|
displayMsg += " [" + strings.Join(imgNames, ", ") + "]"
|
|
}
|
|
|
|
if len(imageIDs) > 0 {
|
|
s.convStore.AddWithImages("user", displayMsg, imageIDs)
|
|
} else {
|
|
s.convStore.Add("user", displayMsg)
|
|
}
|
|
|
|
if s.convStore.NeedsSummarization() {
|
|
s.autoSummarize()
|
|
}
|
|
|
|
orb, err := orchestrator.New(s.config)
|
|
if err != nil {
|
|
writeError(w, err.Error(), http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
var studioPrompt strings.Builder
|
|
studioPrompt.WriteString(agent.StudioSystemPrompt())
|
|
studioPrompt.WriteString(fmt.Sprintf("\nDate: %s\nHeure: %s\n", time.Now().Format("02/01/2006"), time.Now().Format("15:04:05")))
|
|
canSudo := !agent.NeedsSudoPassword()
|
|
studioPrompt.WriteString(fmt.Sprintf("Root: %t\n", !canSudo))
|
|
if !canSudo {
|
|
studioPrompt.WriteString("⚠️ Session sans sudo sans mot de passe — les commandes sudo/doas nécessitent une autorisation. N'utilise PAS sudo ou doas sans demander.\n")
|
|
} else {
|
|
studioPrompt.WriteString("⚠️ Session avec privilèges sudo sans mot de passe — les commandes sudo s'exécuteront directement.\n")
|
|
}
|
|
orb.SetSystemPrompt(studioPrompt.String())
|
|
orb.SetTools(s.agentToolsJSON)
|
|
|
|
if body.Stream {
|
|
s.handleStreamChat(w, orb, enrichedMessage)
|
|
} else {
|
|
s.handleNonStreamChat(w, orb, enrichedMessage)
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleStreamChat(w http.ResponseWriter, orb *orchestrator.Orchestrator, userMessage string) {
|
|
SetupSSEHeaders(w)
|
|
flusher, canFlush := w.(http.Flusher)
|
|
|
|
|
|
sseWriter := NewSSEWriter(w)
|
|
|
|
|
|
ctx := context.Background()
|
|
messages := s.buildContextMessages(userMessage)
|
|
|
|
engine := NewChatEngine(orb, s.agentRegistry, s.agentToolsJSON)
|
|
engine.OnChunk(func(data map[string]interface{}) {
|
|
if data == nil {
|
|
return
|
|
}
|
|
sseWriter.Write(data)
|
|
if canFlush {
|
|
flusher.Flush()
|
|
}
|
|
})
|
|
|
|
finalContent, allToolCalls, allToolResults, err := engine.RunWithTools(ctx, messages)
|
|
if err != nil {
|
|
sseWriter.Write(map[string]interface{}{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
storeContent := finalContent
|
|
if len(allToolCalls) > 0 {
|
|
storeObj := map[string]interface{}{
|
|
"content": storeContent,
|
|
"tool_calls": allToolCalls,
|
|
"tool_results": allToolResults,
|
|
}
|
|
storeJSON, _ := json.Marshal(storeObj)
|
|
storeContent = string(storeJSON)
|
|
}
|
|
s.convStore.Add("assistant", storeContent)
|
|
|
|
s.consumption.Record(engine.ProviderName(), engine.TotalTokens)
|
|
|
|
sseWriter.Write(map[string]interface{}{"done": "true"})
|
|
}
|
|
|
|
func (s *Server) handleNonStreamChat(w http.ResponseWriter, orb *orchestrator.Orchestrator, userMessage string) {
|
|
ctx := context.Background()
|
|
messages := s.buildContextMessages(userMessage)
|
|
|
|
engine := NewChatEngine(orb, s.agentRegistry, s.agentToolsJSON)
|
|
finalContent, err := engine.RunNonStream(ctx, messages)
|
|
if err != nil {
|
|
writeError(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
s.convStore.Add("assistant", finalContent)
|
|
|
|
s.consumption.Record(engine.ProviderName(), engine.TotalTokens)
|
|
|
|
writeJSON(w, map[string]string{"content": finalContent})
|
|
}
|
|
|
|
func cleanThinkingTags(content string) string {
|
|
return strings.TrimSpace(thinkingTagRegex.ReplaceAllString(content, ""))
|
|
}
|
|
|
|
func (s *Server) buildContextMessages(userMessage string) []orchestrator.Message {
|
|
history := s.convStore.Get()
|
|
|
|
sysPromptTokens := utf8.RuneCountInString(agent.StudioSystemPrompt())/charsPerToken + 50
|
|
toolsTokens := utf8.RuneCountInString(string(s.agentToolsJSON)) / charsPerToken
|
|
responseMargin := 4000
|
|
userMsgTokens := utf8.RuneCountInString(userMessage) / charsPerToken
|
|
|
|
overhead := sysPromptTokens + toolsTokens + responseMargin + userMsgTokens
|
|
available := contextWindowTokens - overhead
|
|
if available < 1000 {
|
|
available = 1000
|
|
}
|
|
|
|
included := 0
|
|
tokensUsed := 0
|
|
for i := len(history) - 1; i >= 0; i-- {
|
|
msgTokens := utf8.RuneCountInString(history[i].Content) / charsPerToken
|
|
if msgTokens == 0 {
|
|
msgTokens = 1
|
|
}
|
|
if tokensUsed+msgTokens > available {
|
|
break
|
|
}
|
|
tokensUsed += msgTokens
|
|
included++
|
|
}
|
|
|
|
start := len(history) - included
|
|
if start < 0 {
|
|
start = 0
|
|
}
|
|
|
|
if start > 0 {
|
|
log.Printf("[studio] context budget: %d/%d tokens, including %d/%d messages (dropped %d older)", tokensUsed+overhead, contextWindowTokens, included, len(history), start)
|
|
}
|
|
|
|
messages := make([]orchestrator.Message, 0, included+2)
|
|
|
|
summary := s.convStore.GetSummary()
|
|
if summary != "" && start > 0 {
|
|
messages = append(messages, orchestrator.Message{
|
|
Role: "system",
|
|
Content: orchestrator.TextContent("Résumé de la conversation précédente:\n" + summary),
|
|
})
|
|
}
|
|
|
|
for _, m := range history[start:] {
|
|
content := m.Content
|
|
if m.Role == "assistant" {
|
|
var parsed struct {
|
|
Content string `json:"content"`
|
|
ToolCalls []struct {
|
|
ToolCallID string `json:"tool_call_id"`
|
|
Name string `json:"name"`
|
|
Args string `json:"args"`
|
|
} `json:"tool_calls"`
|
|
}
|
|
if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Content != "" {
|
|
content = parsed.Content
|
|
}
|
|
}
|
|
role := m.Role
|
|
if role == "system" {
|
|
continue
|
|
}
|
|
messages = append(messages, orchestrator.Message{
|
|
Role: role,
|
|
Content: orchestrator.TextContent(content),
|
|
})
|
|
}
|
|
|
|
messages = append(messages, orchestrator.Message{
|
|
Role: "user",
|
|
Content: orchestrator.TextContent(userMessage),
|
|
})
|
|
|
|
return messages
|
|
}
|
|
|
|
func (s *Server) autoSummarize() {
|
|
messages := s.convStore.Get()
|
|
if len(messages) < 10 {
|
|
return
|
|
}
|
|
|
|
half := len(messages) / 2
|
|
var oldText string
|
|
for _, m := range messages[:half] {
|
|
oldText += m.Role + ": " + m.Content + "\n\n"
|
|
}
|
|
|
|
summary := s.convStore.GetSummary()
|
|
if summary != "" {
|
|
oldText = "Résumé précédent:\n" + summary + "\n\nNouveaux échanges:\n" + oldText
|
|
}
|
|
|
|
orb, err := orchestrator.New(s.config)
|
|
if err != nil {
|
|
return
|
|
}
|
|
orb.SetSystemPrompt(summarizePrompt)
|
|
|
|
result, err := orb.Send(oldText)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
s.convStore.SetSummary(result)
|
|
s.convStore.TrimOld(len(messages) - half)
|
|
s.convStore.Add("system", "[Conversation résumée automatiquement]")
|
|
}
|
|
|
|
func (s *Server) handleChatHistory(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "GET" {
|
|
writeError(w, "GET only", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
messages := s.convStore.Get()
|
|
writeJSON(w, map[string]interface{}{
|
|
"messages": messages,
|
|
"tokens": s.convStore.ApproxTokenCount(),
|
|
"max_tokens": contextWindowTokens,
|
|
"summarize_at": int(float64(contextWindowTokens) * summarizeRatio),
|
|
"summary": s.convStore.GetSummary(),
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleChatClear(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
writeError(w, "POST only", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
s.convStore.Clear()
|
|
writeJSON(w, map[string]string{"status": "ok"})
|
|
}
|
|
|
|
func (s *Server) handleChatSummarize(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
writeError(w, "POST only", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
s.autoSummarize()
|
|
writeJSON(w, map[string]interface{}{
|
|
"status": "ok",
|
|
"tokens": s.convStore.ApproxTokenCount(),
|
|
"summary": s.convStore.GetSummary(),
|
|
})
|
|
}
|