package agent

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/muyue/muyue/internal/config"
)

// imageHTTPTimeout bounds both the generation request and the follow-up
// image download so a stalled server cannot block a tool call forever.
const imageHTTPTimeout = 120 * time.Second

// ImageGenerationTool generates an image from a text prompt by calling an
// OpenAI-style "/images/generations" endpoint and stores the result in a
// local directory.
type ImageGenerationTool struct {
	apiKey  string
	baseURL string
	// model is copied from the active provider. Execute does not read it;
	// the request always asks for "dall-e-3".
	model   string
	saveDir string
}

// NewImageGenerationTool builds a tool from cfg.
//
// It creates the "images" directory under config.ConfigDir() if needed and
// takes the API key, base URL and model from the first provider marked
// Active. When no base URL is obtained that way it falls back to
// "https://api.openai.com/v1". Trailing slashes are stripped from the base
// URL so the endpoint path can be appended directly.
func NewImageGenerationTool(cfg *config.MuyueConfig) (*ImageGenerationTool, error) {
	configDir, err := config.ConfigDir()
	if err != nil {
		return nil, fmt.Errorf("resolving config dir: %w", err)
	}

	saveDir := filepath.Join(configDir, "images")
	if err := os.MkdirAll(saveDir, 0755); err != nil {
		return nil, fmt.Errorf("creating images dir: %w", err)
	}

	var apiKey, baseURL, model string
	for _, p := range cfg.AI.Providers {
		if p.Active {
			apiKey = p.APIKey
			baseURL = p.BaseURL
			model = p.Model
			break
		}
	}
	if baseURL == "" {
		baseURL = "https://api.openai.com/v1"
	}

	return &ImageGenerationTool{
		apiKey:  apiKey,
		baseURL: strings.TrimRight(baseURL, "/"),
		model:   model,
		saveDir: saveDir,
	}, nil
}

// Name returns the identifier under which the tool is exposed.
func (t *ImageGenerationTool) Name() string {
	return "generate_image"
}

// Description returns a human-readable summary of what the tool does.
func (t *ImageGenerationTool) Description() string {
	return "Generate an image from a text prompt using DALL-E or compatible API. Returns a local URL to the generated image."
}

// Parameters returns a JSON-Schema-style description of the arguments
// accepted by Execute: a required "prompt" plus optional "size" and "style".
func (t *ImageGenerationTool) Parameters() map[string]interface{} {
	return map[string]interface{}{
		"type": "object",
		"properties": map[string]interface{}{
			"prompt": map[string]interface{}{
				"type":        "string",
				"description": "Description of the image to generate",
			},
			"size": map[string]interface{}{
				"type":        "string",
				"description": "Image size: 1024x1024, 1024x1792, or 1792x1024",
				"default":     "1024x1024",
			},
			"style": map[string]interface{}{
				"type":        "string",
				"description": "Style: vivid or natural",
				"default":     "vivid",
			},
		},
		"required": []string{"prompt"},
	}
}

// Execute requests one image for args["prompt"], saves it under the tool's
// image directory and returns a JSON object with the keys "url"
// ("/api/images/<filename>"), "revised_prompt" and "size".
//
// args["size"] and args["style"] are optional and default to "1024x1024"
// and "vivid". Missing or non-string values are treated as empty.
//
// An error is returned when the prompt is empty, the HTTP request fails or
// answers with a status other than 200, the response cannot be parsed or
// contains no image, or the image cannot be written to disk.
func (t *ImageGenerationTool) Execute(args map[string]interface{}) (string, error) {
	// A failed type assertion yields "", which is handled like a missing value.
	prompt, _ := args["prompt"].(string)
	if prompt == "" {
		return "", fmt.Errorf("prompt is required")
	}
	size, _ := args["size"].(string)
	if size == "" {
		size = "1024x1024"
	}
	style, _ := args["style"].(string)
	if style == "" {
		style = "vivid"
	}

	reqBody := map[string]interface{}{
		"model":  "dall-e-3",
		"prompt": prompt,
		"size":   size,
		"style":  style,
		"n":      1,
	}
	bodyBytes, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("marshal request: %w", err)
	}

	url := t.baseURL + "/images/generations"
	req, err := http.NewRequest("POST", url, bytes.NewReader(bodyBytes))
	if err != nil {
		return "", fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if t.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+t.apiKey)
	}

	client := &http.Client{Timeout: imageHTTPTimeout}
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("send request: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, string(respBody))
	}

	var genResp struct {
		Data []struct {
			URL           string `json:"url"`
			B64JSON       string `json:"b64_json"`
			RevisedPrompt string `json:"revised_prompt"`
		} `json:"data"`
	}
	if err := json.Unmarshal(respBody, &genResp); err != nil {
		return "", fmt.Errorf("parse response: %w", err)
	}
	if len(genResp.Data) == 0 {
		return "", fmt.Errorf("no image returned")
	}

	img := genResp.Data[0]
	filename := fmt.Sprintf("img-%d.png", time.Now().UnixNano())
	localPath := filepath.Join(t.saveDir, filename)

	switch {
	case img.B64JSON != "":
		if err := t.saveBase64Image(img.B64JSON, localPath); err != nil {
			return "", fmt.Errorf("save image: %w", err)
		}
	case img.URL != "":
		if err := t.downloadImage(img.URL, localPath); err != nil {
			return "", fmt.Errorf("download image: %w", err)
		}
	default:
		// Without this guard the caller would receive a URL pointing at a
		// file that was never written.
		return "", fmt.Errorf("image response contained neither url nor b64_json")
	}

	result := map[string]interface{}{
		"url":            "/api/images/" + filename,
		"revised_prompt": img.RevisedPrompt,
		"size":           size,
	}
	resultJSON, err := json.Marshal(result)
	if err != nil {
		return "", fmt.Errorf("marshal result: %w", err)
	}
	return string(resultJSON), nil
}

// saveBase64Image decodes a standard-base64 image payload and writes it to
// localPath.
func (t *ImageGenerationTool) saveBase64Image(b64, localPath string) error {
	data, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		return fmt.Errorf("decode base64: %w", err)
	}
	return t.writeImage(localPath, bytes.NewReader(data))
}

// downloadImage fetches url and stores the response body at localPath. Any
// status other than 200 is reported as an error and nothing is written.
func (t *ImageGenerationTool) downloadImage(url, localPath string) error {
	client := &http.Client{Timeout: imageHTTPTimeout}
	resp, err := client.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("download failed: %d", resp.StatusCode)
	}
	return t.writeImage(localPath, resp.Body)
}

// writeImage streams src into a newly created file at localPath. If copying
// or closing fails, the partially written file is removed.
func (t *ImageGenerationTool) writeImage(localPath string, src io.Reader) (err error) {
	f, err := os.Create(localPath)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close can mean buffered data never reached the disk, so
		// it is reported unless an earlier error already explains the failure.
		if cerr := f.Close(); err == nil {
			err = cerr
		}
		if err != nil {
			// Best-effort cleanup; the original error is the one worth returning.
			_ = os.Remove(localPath)
		}
	}()

	_, err = io.Copy(f, src)
	return err
}