Files
MuyueWorkspace/internal/agent/tools.go
Augustin 61da8039bc feat(agent): refactor AI chat with streaming, agent registry, and tool execution
- Replace old tool-call regex with proper agent registry
- Add streaming chat via SSE (handleStreamChat / handleNonStreamChat)
- Add internal/agent package with tool definitions and execution
- Add orchestrator with system prompt and tool scaffolding
- Add internal/agent/ directory
- Studio.jsx: streaming chat with thinking indicator and tool result rendering
- global.css: chat bubble styles, streaming animation, thinking dots
- handlers_chat.go: full rewrite using new agent/orchestrator architecture

💘 Generated with Crush

Assisted-by: MiniMax-M2.7 via Crush <crush@charm.land>
2026-04-23 19:47:00 +02:00

219 lines
4.9 KiB
Go

package agent
import (
	"context"
	"encoding"
	"encoding/json"
	"fmt"
	"reflect"
	"sort"
	"strings"
)
// ToolCall is one request to run a tool: an identifier for the call, the
// name of the tool to run, and its arguments kept as raw, undecoded JSON.
type ToolCall struct {
	ID        string          `json:"id"`
	Name      string          `json:"name"`
	Arguments json.RawMessage `json:"arguments"`
}
// ToolResponse is the result of running a tool.
type ToolResponse struct {
	// Content is the textual output of the tool, or an error message.
	Content string `json:"content"`
	// IsError reports whether Content describes a failure.
	IsError bool `json:"is_error"`
	// Meta holds optional key/value annotations; it is left out of the JSON
	// encoding when nil or empty.
	Meta map[string]string `json:"meta,omitempty"`
}
// TextResponse wraps content in a successful (non-error) ToolResponse that
// carries no metadata.
func TextResponse(content string) ToolResponse {
	var resp ToolResponse
	resp.Content = content
	return resp
}
// TextErrorResponse wraps msg in a ToolResponse flagged as an error.
func TextErrorResponse(msg string) ToolResponse {
	return ToolResponse{
		IsError: true,
		Content: msg,
	}
}
// ToolDefinition describes a tool: the metadata that is serialized as JSON
// plus the Go function that executes it.
type ToolDefinition struct {
	// Name identifies the tool; Registry indexes tools by this value.
	Name string `json:"name"`
	// Description is free-form text explaining what the tool does.
	Description string `json:"description"`
	// Params is the JSON Schema of the tool's arguments, stored as raw JSON.
	Params json.RawMessage `json:"parameters"`
	// Handler runs the tool on the raw JSON arguments. It is tagged "-"
	// because encoding/json cannot encode func values: without the tag,
	// json.Marshal of a ToolDefinition always fails, even when Handler is nil.
	Handler func(ctx context.Context, args json.RawMessage) (ToolResponse, error) `json:"-"`
}
// Execute runs the tool's handler on call.Arguments.
//
// Failures are reported through the returned ToolResponse, never through the
// Go error: a non-nil error from the handler becomes an error response carrying
// the error text, and a definition with no handler (or a nil definition)
// yields an error response instead of panicking. The returned error is
// currently always nil; it is kept for interface compatibility.
func (td *ToolDefinition) Execute(ctx context.Context, call ToolCall) (ToolResponse, error) {
	if td == nil || td.Handler == nil {
		return TextErrorResponse(fmt.Sprintf("tool %q has no handler", call.Name)), nil
	}
	resp, err := td.Handler(ctx, call.Arguments)
	if err != nil {
		return TextErrorResponse(err.Error()), nil
	}
	return resp, nil
}
// ToOpenAITool renders the definition in the OpenAI "tools" entry shape:
// {"type":"function","function":{"name":…,"description":…,"parameters":…}}.
//
// When Params is nil or empty, an empty object schema is advertised instead:
// a nil json.RawMessage encodes as JSON null, and an empty non-nil one makes
// json.Marshal fail.
func (td *ToolDefinition) ToOpenAITool() map[string]interface{} {
	params := td.Params
	if len(params) == 0 {
		params = json.RawMessage(`{"type":"object","properties":{}}`)
	}
	return map[string]interface{}{
		"type": "function",
		"function": map[string]interface{}{
			"name":        td.Name,
			"description": td.Description,
			"parameters":  params,
		},
	}
}
// NewTool builds a ToolDefinition for a handler that takes typed parameters.
//
// The parameter schema is produced by generateSchema from the zero value of
// P. At call time the raw JSON arguments are decoded into a fresh P and then
// passed to handler. Arguments that fail to decode are returned as an error
// ToolResponse (with a nil Go error) and handler is not called. Empty or
// whitespace-only arguments mean "no arguments": handler receives the zero
// value of P.
//
// NewTool returns an error if handler is nil or schema generation fails.
func NewTool[P any](name, description string, handler func(ctx context.Context, params P) (ToolResponse, error)) (*ToolDefinition, error) {
	if handler == nil {
		return nil, fmt.Errorf("new tool %s: nil handler", name)
	}
	var zero P
	paramsSchema, err := generateSchema(zero)
	if err != nil {
		return nil, fmt.Errorf("generate schema for %s: %w", name, err)
	}
	wrappedHandler := func(ctx context.Context, raw json.RawMessage) (ToolResponse, error) {
		var params P
		// json.Unmarshal rejects empty input ("unexpected end of JSON input"),
		// which would turn a call with no arguments into an error.
		if strings.TrimSpace(string(raw)) != "" {
			if err := json.Unmarshal(raw, &params); err != nil {
				return TextErrorResponse(fmt.Sprintf("invalid arguments: %v", err)), nil
			}
		}
		return handler(ctx, params)
	}
	return &ToolDefinition{
		Name:        name,
		Description: description,
		Params:      paramsSchema,
		Handler:     wrappedHandler,
	}, nil
}
// Registry is a collection of tool definitions indexed by tool name.
type Registry struct {
	// tools maps a tool's Name to its definition.
	tools map[string]*ToolDefinition
}
// NewRegistry returns an empty Registry with its tool map allocated.
func NewRegistry() *Registry {
	r := new(Registry)
	r.tools = map[string]*ToolDefinition{}
	return r
}
// Register adds tool to the registry under tool.Name.
//
// It returns an error, leaving the registry unchanged, if tool is nil, has an
// empty name, has no handler, or if a tool with the same name is already
// registered. A zero-value Registry is usable: its map is allocated on the
// first successful registration.
func (r *Registry) Register(tool *ToolDefinition) error {
	switch {
	case tool == nil:
		return fmt.Errorf("register tool: nil definition")
	case tool.Name == "":
		return fmt.Errorf("register tool: empty name")
	case tool.Handler == nil:
		return fmt.Errorf("tool %q has no handler", tool.Name)
	}
	if _, exists := r.tools[tool.Name]; exists {
		return fmt.Errorf("tool %q already registered", tool.Name)
	}
	if r.tools == nil {
		r.tools = make(map[string]*ToolDefinition)
	}
	r.tools[tool.Name] = tool
	return nil
}
// Get looks up the tool registered under name and reports whether it exists.
// A missing name yields (nil, false).
func (r *Registry) Get(name string) (*ToolDefinition, bool) {
	if tool, found := r.tools[name]; found {
		return tool, true
	}
	return nil, false
}
// All returns every registered tool, ordered by registered name.
//
// Sorting by name instead of following Go's randomized map iteration order
// makes the result identical across calls, so anything built from it is
// reproducible. The returned slice is freshly allocated on each call.
func (r *Registry) All() []*ToolDefinition {
	names := make([]string, 0, len(r.tools))
	for name := range r.tools {
		names = append(names, name)
	}
	sort.Strings(names)

	out := make([]*ToolDefinition, 0, len(names))
	for _, name := range names {
		out = append(out, r.tools[name])
	}
	return out
}
// OpenAITools renders every registered tool with ToOpenAITool, ordered by
// registered name.
//
// The ordering is deterministic, unlike raw map iteration, so the same
// registry always produces the same tool list.
func (r *Registry) OpenAITools() []map[string]interface{} {
	names := make([]string, 0, len(r.tools))
	for name := range r.tools {
		names = append(names, name)
	}
	sort.Strings(names)

	out := make([]map[string]interface{}, 0, len(names))
	for _, name := range names {
		out = append(out, r.tools[name].ToOpenAITool())
	}
	return out
}
// Execute routes call to the tool registered under call.Name and returns that
// tool's result. A name with no registered tool is reported as an error
// ToolResponse with a nil Go error.
func (r *Registry) Execute(ctx context.Context, call ToolCall) (ToolResponse, error) {
	if tool, found := r.tools[call.Name]; found {
		return tool.Execute(ctx, call)
	}
	return TextErrorResponse(fmt.Sprintf("unknown tool: %s", call.Name)), nil
}
// textMarshalerType is the reflect.Type of encoding.TextMarshaler. encoding/json
// encodes values implementing it as JSON strings (time.Time is the common
// case), so schemas must describe them as "string" rather than by Go kind.
var textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()

// generateSchema derives a JSON Schema "object" describing the arguments a
// tool accepts, from the type of v (normally the zero value of a parameter
// struct).
//
// Property names and field visibility follow encoding/json:
//   - the `json` tag name is used, falling back to the Go field name;
//   - fields tagged "-" and unexported fields are skipped;
//   - fields of untagged embedded structs (exported or not) are promoted into
//     the parent object, and a shallower field wins a name clash.
//
// Fields tagged omitempty or omitzero are optional; all others are listed as
// required. An optional `description` struct tag is copied into the property
// schema. Nested structs become nested object schemas, and slices and arrays
// carry an "items" schema. Recursive types stop expanding at the first repeat.
//
// If v is nil, or is neither a struct nor a pointer to one, the result is an
// empty object schema.
func generateSchema(v interface{}) (json.RawMessage, error) {
	t := reflect.TypeOf(v)
	if t == nil {
		return json.RawMessage(`{"type":"object","properties":{}}`), nil
	}
	if t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return json.RawMessage(`{"type":"object","properties":{}}`), nil
	}
	data, err := json.Marshal(structSchema(t, make(map[reflect.Type]bool)))
	if err != nil {
		return nil, err
	}
	return json.RawMessage(data), nil
}

// structSchema builds the object schema for struct type t. visiting holds the
// types currently being expanded, so that a recursive type such as
// `type Node struct{ Kids []Node }` terminates.
func structSchema(t reflect.Type, visiting map[reflect.Type]bool) map[string]interface{} {
	visiting[t] = true
	defer delete(visiting, t)

	props := make(map[string]interface{})
	var required []string

	// Visit embedded structs breadth-first, so that fields at a shallower
	// embedding depth claim their names first. seen stops embedding cycles
	// such as `type A struct{ *A }`.
	seen := map[reflect.Type]bool{t: true}
	level := []reflect.Type{t}
	for len(level) > 0 {
		var next []reflect.Type
		for _, st := range level {
			for _, et := range addFieldSchemas(st, props, &required, visiting) {
				if !seen[et] {
					seen[et] = true
					next = append(next, et)
				}
			}
		}
		level = next
	}

	schema := map[string]interface{}{
		"type":       "object",
		"properties": props,
	}
	if len(required) > 0 {
		schema["required"] = required
	}
	return schema
}

// addFieldSchemas adds a property, and a required entry unless the field is
// optional, for each JSON-visible field declared directly on struct type st.
// Names already present in props are left alone. It returns the untagged
// embedded struct types whose fields should be promoted next.
func addFieldSchemas(st reflect.Type, props map[string]interface{}, required *[]string, visiting map[reflect.Type]bool) []reflect.Type {
	var embedded []reflect.Type
	for i := 0; i < st.NumField(); i++ {
		field := st.Field(i)
		tag := field.Tag.Get("json")
		if tag == "-" {
			continue
		}
		name, opts, _ := strings.Cut(tag, ",")

		if field.Anonymous && name == "" {
			ft := field.Type
			if ft.Kind() == reflect.Pointer {
				// json.Unmarshal cannot allocate an embedded pointer to an
				// unexported type, so its fields can never be set.
				if !field.IsExported() {
					continue
				}
				ft = ft.Elem()
			}
			if ft.Kind() == reflect.Struct {
				embedded = append(embedded, ft)
				continue
			}
		}
		if !field.IsExported() {
			continue
		}
		if name == "" {
			name = field.Name
		}
		if _, taken := props[name]; taken {
			continue
		}

		prop := typeSchema(field.Type, visiting)
		if desc := field.Tag.Get("description"); desc != "" {
			prop["description"] = desc
		}
		props[name] = prop

		optional := false
		for _, opt := range strings.Split(opts, ",") {
			if opt == "omitempty" || opt == "omitzero" {
				optional = true
			}
		}
		if !optional {
			*required = append(*required, name)
		}
	}
	return embedded
}

// typeSchema returns the schema for one field type. One level of pointer is
// looked through. Slices and arrays get an "items" schema, and structs become
// nested object schemas. Types already being expanded are described by their
// "type" alone, and everything else maps to {"type": goTypeToJSON(t)}.
func typeSchema(t reflect.Type, visiting map[reflect.Type]bool) map[string]interface{} {
	if t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	jsonType := goTypeToJSON(t)
	if visiting[t] {
		return map[string]interface{}{"type": jsonType}
	}
	switch {
	case jsonType == "array":
		visiting[t] = true // guards self-referential slices such as `type S []S`
		defer delete(visiting, t)
		return map[string]interface{}{
			"type":  "array",
			"items": typeSchema(t.Elem(), visiting),
		}
	case jsonType == "object" && t.Kind() == reflect.Struct:
		return structSchema(t, visiting)
	default:
		return map[string]interface{}{"type": jsonType}
	}
}

// goTypeToJSON maps a Go type to the JSON Schema "type" keyword matching how
// encoding/json encodes it. One level of pointer is looked through, since
// json encodes *T like T. TextMarshaler types (e.g. time.Time) and []byte
// (base64) are "string". Slices and arrays are "array", and maps and structs
// are "object". Interfaces and any remaining kinds fall back to "string".
func goTypeToJSON(t reflect.Type) string {
	if t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	if t.Implements(textMarshalerType) || reflect.PointerTo(t).Implements(textMarshalerType) {
		return "string"
	}
	switch t.Kind() {
	case reflect.String:
		return "string"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return "integer"
	case reflect.Float32, reflect.Float64:
		return "number"
	case reflect.Bool:
		return "boolean"
	case reflect.Slice:
		if t.Elem().Kind() == reflect.Uint8 {
			return "string"
		}
		return "array"
	case reflect.Array:
		return "array"
	case reflect.Map, reflect.Struct:
		return "object"
	default:
		return "string"
	}
}