diff --git a/Cargo.toml b/Cargo.toml
index 336d92b..9539a86 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -39,6 +39,7 @@ sha2 = "0.10"
 # Serialization
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
+base64 = "0.22"
 
 # Time management
 chrono = { version = "0.4", features = ["serde"] }
diff --git a/src/ai/mod.rs b/src/ai/mod.rs
index e4eaaad..163ec11 100644
--- a/src/ai/mod.rs
+++ b/src/ai/mod.rs
@@ -1,6 +1,10 @@
 /// AI/ML module with NPU support
 pub mod classifier;
 pub mod npu;
+pub mod models;
+pub mod vision;
 
 pub use classifier::NpuClassifier;
 pub use npu::NpuDevice;
+pub use models::{AvailableModels, ModelConfig, ModelDownloader};
+pub use vision::{ImageAnalyzer, ImageAnalysis};
diff --git a/src/ai/models.rs b/src/ai/models.rs
new file mode 100644
index 0000000..64ce9ff
--- /dev/null
+++ b/src/ai/models.rs
@@ -0,0 +1,149 @@
+/// AI Model management and downloading
+use std::path::PathBuf;
+use std::fs;
+use crate::error::{Result, AppError};
+
+/// Model configuration
+#[derive(Debug, Clone)]
+pub struct ModelConfig {
+    pub name: String,
+    pub url: String,
+    pub filename: String,
+    pub size_mb: u64,
+    pub description: String,
+}
+
+/// Available models
+pub struct AvailableModels;
+
+impl AvailableModels {
+    /// Mistral-7B-Instruct quantized (Q4) - Optimized for NPU
+    /// Size: ~4GB, good balance between quality and speed
+    pub fn mistral_7b_q4() -> ModelConfig {
+        ModelConfig {
+            name: "mistral-7b-instruct-q4".to_string(),
+            url: "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
+            filename: "mistral-7b-instruct-q4.gguf".to_string(),
+            size_mb: 4368,
+            description: "Mistral 7B Instruct Q4 - Text analysis and classification".to_string(),
+        }
+    }
+
+    /// CLIP for image understanding
+    pub fn clip_vit() -> ModelConfig {
+        ModelConfig {
+            name: "clip-vit-base".to_string(),
+            url: "https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/onnx/model.onnx".to_string(),
+            filename: "clip-vit-base.onnx".to_string(),
+            size_mb: 350,
+            description: "CLIP ViT - Image and text embeddings".to_string(),
+        }
+    }
+
+    /// MiniLM for lightweight text embeddings
+    pub fn minilm() -> ModelConfig {
+        ModelConfig {
+            name: "minilm-l6".to_string(),
+            url: "https://huggingface.co/optimum/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx".to_string(),
+            filename: "minilm-l6.onnx".to_string(),
+            size_mb: 90,
+            description: "MiniLM L6 - Fast text embeddings for classification".to_string(),
+        }
+    }
+}
+
+/// Model downloader
+pub struct ModelDownloader {
+    models_dir: PathBuf,
+}
+
+impl ModelDownloader {
+    pub fn new() -> Result<Self> {
+        let models_dir = PathBuf::from("models");
+        if !models_dir.exists() {
+            fs::create_dir_all(&models_dir)?;
+        }
+
+        Ok(Self { models_dir })
+    }
+
+    /// Check if a model is already downloaded
+    pub fn is_downloaded(&self, config: &ModelConfig) -> bool {
+        let model_path = self.models_dir.join(&config.filename);
+        model_path.exists()
+    }
+
+    /// Get the path to a model
+    pub fn get_model_path(&self, config: &ModelConfig) -> PathBuf {
+        self.models_dir.join(&config.filename)
+    }
+
+    /// Download a model (placeholder - requires actual HTTP client)
+    pub fn download(&self, config: &ModelConfig) -> Result<PathBuf> {
+        let model_path = self.models_dir.join(&config.filename);
+
+        if self.is_downloaded(config) {
+            log::info!("Model already downloaded: {}", config.name);
+            return Ok(model_path);
+        }
+
+        log::info!("Downloading model: {} ({} MB)", config.name, config.size_mb);
+        log::info!("URL: {}", config.url);
+        log::info!("Please download manually to: {}", model_path.display());
+
+        Err(AppError::Analysis(format!(
+            "Manual download required:\n\
+             1. Download from: {}\n\
+             2. Save to: {}\n\
+             3. Run again",
+            config.url,
+            model_path.display()
+        )))
+    }
+
+    /// List downloaded models
+    pub fn list_downloaded(&self) -> Result<Vec<String>> {
+        let mut models = Vec::new();
+
+        if let Ok(entries) = fs::read_dir(&self.models_dir) {
+            for entry in entries.flatten() {
+                if let Ok(file_type) = entry.file_type() {
+                    if file_type.is_file() {
+                        if let Some(filename) = entry.file_name().to_str() {
+                            models.push(filename.to_string());
+                        }
+                    }
+                }
+            }
+        }
+
+        Ok(models)
+    }
+}
+
+impl Default for ModelDownloader {
+    fn default() -> Self {
+        Self::new().expect("Failed to create ModelDownloader")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_model_configs() {
+        let mistral = AvailableModels::mistral_7b_q4();
+        assert_eq!(mistral.name, "mistral-7b-instruct-q4");
+        assert!(mistral.size_mb > 0);
+
+        let clip = AvailableModels::clip_vit();
+        assert_eq!(clip.name, "clip-vit-base");
+    }
+
+    #[test]
+    fn test_downloader_creation() {
+        let downloader = ModelDownloader::new();
+        assert!(downloader.is_ok());
+    }
+}
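Note on the download() placeholder above: it only logs manual-download instructions and then returns an error. A minimal sketch of an automated fetch, assuming a reqwest dependency with the "blocking" feature were added (not part of this diff; the mapping onto AppError::Analysis is illustrative):

    // Sketch only: assumes reqwest = { version = "0.12", features = ["blocking"] }
    // in Cargo.toml, which this diff does not add. Goes inside impl ModelDownloader.
    pub fn download(&self, config: &ModelConfig) -> Result<PathBuf> {
        let model_path = self.models_dir.join(&config.filename);
        if self.is_downloaded(config) {
            return Ok(model_path);
        }

        log::info!("Downloading {} ({} MB)...", config.name, config.size_mb);
        let mut response = reqwest::blocking::get(config.url.as_str())
            .and_then(|r| r.error_for_status())
            .map_err(|e| AppError::Analysis(format!("Download failed: {}", e)))?;

        // Stream straight to disk so a ~4 GB GGUF file never sits in memory.
        let mut file = fs::File::create(&model_path)?;
        std::io::copy(&mut response, &mut file)?;

        Ok(model_path)
    }

Because the CLI handler added in src/main.rs below runs inside the runtime started by #[tokio::main], the blocking call would likely need to be wrapped in tokio::task::spawn_blocking, or swapped for reqwest's async client.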
diff --git a/src/ai/vision.rs b/src/ai/vision.rs
new file mode 100644
index 0000000..23981e3
--- /dev/null
+++ b/src/ai/vision.rs
@@ -0,0 +1,80 @@
+/// Vision and image analysis module
+use image::DynamicImage;
+use crate::error::{Result, AppError};
+
+/// Image analysis result
+#[derive(Debug, Clone)]
+pub struct ImageAnalysis {
+    pub ocr_text: String,
+    pub detected_objects: Vec<String>,
+    pub scene_description: String,
+    pub confidence: f32,
+}
+
+/// Image analyzer with NPU acceleration
+pub struct ImageAnalyzer {
+    ocr_enabled: bool,
+}
+
+impl ImageAnalyzer {
+    pub fn new() -> Self {
+        Self {
+            ocr_enabled: true,
+        }
+    }
+
+    /// Analyze a screenshot image
+    pub fn analyze(&self, image_data: &[u8]) -> Result<ImageAnalysis> {
+        // Decode image
+        let img = image::load_from_memory(image_data)
+            .map_err(|e| AppError::Image(format!("Failed to decode image: {}", e)))?;
+
+        // Extract text (OCR)
+        let ocr_text = self.extract_text(&img)?;
+
+        // For now, return basic analysis
+        // TODO: Add CLIP model for scene understanding
+        Ok(ImageAnalysis {
+            ocr_text,
+            detected_objects: Vec::new(),
+            scene_description: String::new(),
+            confidence: 0.8,
+        })
+    }
+
+    /// Extract text from image using Windows OCR
+    #[cfg(windows)]
+    fn extract_text(&self, _img: &DynamicImage) -> Result<String> {
+        // TODO: Integrate Windows.Media.Ocr API
+        // For now, return empty - will be implemented in next iteration
+        Ok(String::new())
+    }
+
+    #[cfg(not(windows))]
+    fn extract_text(&self, _img: &DynamicImage) -> Result<String> {
+        Ok(String::new())
+    }
+
+    /// Analyze with CLIP model
+    pub fn analyze_with_clip(&self, _image_data: &[u8], _text_query: &str) -> Result<f32> {
+        // TODO: Implement CLIP similarity
+        Ok(0.5)
+    }
+}
+
+impl Default for ImageAnalyzer {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_analyzer_creation() {
+        let analyzer = ImageAnalyzer::new();
+        assert!(analyzer.ocr_enabled);
+    }
+}
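Note on the extract_text() stub above: a rough sketch of the planned Windows.Media.Ocr wiring, assuming the windows crate with (at least) the Media_Ocr, Graphics_Imaging, and Storage_Streams features, none of which this diff adds; the exact feature set and API paths should be checked against the windows crate version in use:

    // Sketch only: WinRT OCR over encoded screenshot bytes (PNG/JPEG), Windows only.
    #[cfg(windows)]
    fn ocr_bytes(image_bytes: &[u8]) -> windows::core::Result<String> {
        use windows::{
            Graphics::Imaging::BitmapDecoder,
            Media::Ocr::OcrEngine,
            Storage::Streams::{DataWriter, InMemoryRandomAccessStream},
        };

        // Copy the encoded bytes into an in-memory WinRT stream.
        let stream = InMemoryRandomAccessStream::new()?;
        let writer = DataWriter::CreateDataWriter(&stream.GetOutputStreamAt(0)?)?;
        writer.WriteBytes(image_bytes)?;
        writer.StoreAsync()?.get()?;
        writer.FlushAsync()?.get()?;
        stream.Seek(0)?;

        // Decode to a SoftwareBitmap and run OCR with the user's profile languages.
        let decoder = BitmapDecoder::CreateAsync(&stream)?.get()?;
        let bitmap = decoder.GetSoftwareBitmapAsync()?.get()?;
        let engine = OcrEngine::TryCreateFromUserProfileLanguages()?;
        let result = engine.RecognizeAsync(&bitmap)?.get()?;
        Ok(result.Text()?.to_string())
    }

Since analyze() already holds the encoded screenshot bytes, handing those to the WinRT decoder avoids re-encoding the DynamicImage; converting windows::core::Error into AppError is left out of the sketch.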
diff --git a/src/main.rs b/src/main.rs
index 0336f53..c55cb2b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -90,6 +90,33 @@ enum Commands {
         #[arg(long, default_value = "2759")]
         port: u16,
     },
+
+    /// Manage AI models (download, list, info)
+    Models {
+        #[command(subcommand)]
+        action: ModelAction,
+    },
+}
+
+#[derive(Subcommand)]
+enum ModelAction {
+    /// List available models to download
+    List,
+
+    /// Show downloaded models
+    Downloaded,
+
+    /// Download a specific model
+    Download {
+        /// Model name (mistral, clip, minilm)
+        model: String,
+    },
+
+    /// Show model info
+    Info {
+        /// Model name
+        model: String,
+    },
 }
 
 #[tokio::main]
 async fn main() -> Result<()> {
@@ -129,6 +156,9 @@
         Commands::Serve { password, port } => {
            serve_dashboard(&config, password, port).await?;
         }
+        Commands::Models { action } => {
+            handle_models_command(action)?;
+        }
     }
 
     Ok(())
@@ -308,6 +338,107 @@ fn export_data(config: &config::Config, password: &str, output: PathBuf) -> Resu
     Ok(())
 }
 
+/// Handle models command
+fn handle_models_command(action: ModelAction) -> Result<()> {
+    let downloader = ai::ModelDownloader::new()?;
+
+    match action {
+        ModelAction::List => {
+            println!("\n=== Available AI Models ===\n");
+
+            let models = vec![
+                ai::AvailableModels::mistral_7b_q4(),
+                ai::AvailableModels::clip_vit(),
+                ai::AvailableModels::minilm(),
+            ];
+
+            for model in models {
+                let downloaded = if downloader.is_downloaded(&model) {
+                    " [DOWNLOADED]"
+                } else {
+                    ""
+                };
+
+                println!("{}{}", model.name, downloaded);
+                println!("  Description: {}", model.description);
+                println!("  Size: {} MB", model.size_mb);
+                println!("  Download: activity-tracker models download {}", model.name.split('-').next().unwrap());
+                println!();
+            }
+        }
+
+        ModelAction::Downloaded => {
+            let models = downloader.list_downloaded()?;
+
+            if models.is_empty() {
+                println!("\nNo models downloaded yet.");
+                println!("Use 'activity-tracker models list' to see available models.");
+            } else {
+                println!("\n=== Downloaded Models ===\n");
+                for model in models {
+                    println!("  {}", model);
+                }
+            }
+        }
+
+        ModelAction::Download { model } => {
+            let config = match model.as_str() {
+                "mistral" => ai::AvailableModels::mistral_7b_q4(),
+                "clip" => ai::AvailableModels::clip_vit(),
+                "minilm" => ai::AvailableModels::minilm(),
+                _ => {
+                    println!("Unknown model: {}", model);
+                    println!("Available: mistral, clip, minilm");
+                    return Ok(());
+                }
+            };
+
+            match downloader.download(&config) {
+                Ok(path) => {
+                    println!("Model ready: {}", path.display());
+                }
+                Err(e) => {
+                    println!("\n{}", e);
+                }
+            }
+        }
+
+        ModelAction::Info { model } => {
+            let config = match model.as_str() {
+                "mistral" => ai::AvailableModels::mistral_7b_q4(),
+                "clip" => ai::AvailableModels::clip_vit(),
+                "minilm" => ai::AvailableModels::minilm(),
+                _ => {
+                    println!("Unknown model: {}", model);
+                    return Ok(());
+                }
+            };
+
+            let downloaded = if downloader.is_downloaded(&config) {
+                "Yes"
+            } else {
+                "No"
+            };
+
+            println!("\n=== Model Info: {} ===\n", config.name);
+            println!("Description: {}", config.description);
+            println!("Size: {} MB", config.size_mb);
+            println!("Downloaded: {}", downloaded);
+            println!("URL: {}", config.url);
+
+            if !downloader.is_downloaded(&config) {
+                println!("\nTo download:");
+                println!("  activity-tracker models download {}", model);
+            } else {
+                let path = downloader.get_model_path(&config);
+                println!("\nPath: {}", path.display());
+            }
+        }
+    }
+
+    Ok(())
+}
+
 /// Start web dashboard server
 async fn serve_dashboard(
     config: &config::Config,
diff --git a/src/web/data_export.rs b/src/web/data_export.rs
new file mode 100644
index 0000000..5228f8a
--- /dev/null
+++ b/src/web/data_export.rs
@@ -0,0 +1,126 @@
+/// Data export and retrieval endpoints
+use axum::{
+    extract::{State, Path, Query},
+    response::Json,
+    http::StatusCode,
+};
+use serde::{Deserialize, Serialize};
+use chrono::{DateTime, Utc};
+use base64::{Engine as _, engine::general_purpose};
+
+use crate::storage::Database;
+use super::state::AppState;
+
+#[derive(Debug, Deserialize)]
+pub struct ExportQuery {
+    #[serde(default)]
+    include_screenshots: bool,
+    #[serde(default)]
+    include_analysis: bool,
+}
+
+#[derive(Debug, Serialize)]
+pub struct CaptureDetail {
+    pub id: String,
+    pub timestamp: DateTime<Utc>,
+    pub window_title: String,
+    pub window_process: String,
+    pub is_active: bool,
+    pub category: Option<String>,
+    pub confidence: Option<f32>,
+    pub has_screenshot: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub screenshot_base64: Option<String>,
+}
+
+#[derive(Debug, Serialize)]
+pub struct DataExport {
+    pub total_captures: usize,
+    pub captures: Vec<CaptureDetail>,
+    pub export_timestamp: DateTime<Utc>,
+}
+
+/// GET /api/captures - Get all captures (paginated)
+pub async fn get_captures(
+    State(state): State<AppState>,
+    Query(query): Query<ExportQuery>,
+) -> std::result::Result<Json<DataExport>, (StatusCode, String)> {
+    let db = Database::new(&*state.db_path, &state.password)
+        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
+
+    // Get last 30 days
+    let end = Utc::now();
+    let start = end - chrono::Duration::days(30);
+
+    let stored_captures = db.get_captures_by_date_range(start, end)
+        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
+
+    let captures: Vec<CaptureDetail> = stored_captures
+        .into_iter()
+        .map(|c| {
+            let has_screenshot = c.screenshot_data.is_some();
+            let screenshot_base64 = if query.include_screenshots {
+                c.screenshot_data.as_ref().map(|data| general_purpose::STANDARD.encode(data))
+            } else {
+                None
+            };
+
+            CaptureDetail {
+                id: c.capture_id,
+                timestamp: c.timestamp,
+                window_title: c.window_title,
+                window_process: c.window_process,
+                is_active: c.is_active,
+                category: c.category,
+                confidence: c.confidence,
+                has_screenshot,
+                screenshot_base64,
+            }
+        })
+        .collect();
+
+    Ok(Json(DataExport {
+        total_captures: captures.len(),
+        captures,
+        export_timestamp: Utc::now(),
+    }))
+}
+
+/// GET /api/captures/:id - Get specific capture with screenshot
+pub async fn get_capture_by_id(
+    State(state): State<AppState>,
+    Path(capture_id): Path<String>,
+) -> std::result::Result<Json<CaptureDetail>, (StatusCode, String)> {
+    let db = Database::new(&*state.db_path, &state.password)
+        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
+
+    let capture = db.get_capture(&capture_id)
+        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
+        .ok_or_else(|| (StatusCode::NOT_FOUND, "Capture not found".to_string()))?;
+
+    let screenshot_base64 = capture.screenshot_data.as_ref().map(|data| general_purpose::STANDARD.encode(data));
+
+    Ok(Json(CaptureDetail {
+        id: capture.capture_id,
+        timestamp: capture.timestamp,
+        window_title: capture.window_title,
+        window_process: capture.window_process,
+        is_active: capture.is_active,
+        category: capture.category,
+        confidence: capture.confidence,
+        has_screenshot: capture.screenshot_data.is_some(),
+        screenshot_base64,
+    }))
+}
+
+/// GET /api/export/full - Export everything as JSON
+pub async fn export_full_data(
+    State(state): State<AppState>,
+) -> std::result::Result<Json<DataExport>, (StatusCode, String)> {
+    let query = ExportQuery {
+        include_screenshots: false,
+        include_analysis: true,
+    };
+
+    get_captures(State(state), Query(query)).await
+}
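Note for consumers of the new endpoints: a small client-side sketch of fetching one capture and decoding its screenshot. It assumes the routes end up mounted under /api as the handler comments above state, the default dashboard port 2759, and a reqwest dependency with the "json" feature; whatever authentication the dashboard enforces is omitted:

    // Sketch only: reqwest is an assumed client dependency, not part of this diff.
    use base64::{Engine as _, engine::general_purpose};

    async fn save_screenshot(capture_id: &str) -> std::result::Result<(), Box<dyn std::error::Error>> {
        let url = format!("http://127.0.0.1:2759/api/captures/{}", capture_id);
        let capture: serde_json::Value = reqwest::get(url).await?.json().await?;

        if let Some(b64) = capture["screenshot_base64"].as_str() {
            // Same engine the server used to encode: base64 0.22 STANDARD.
            let bytes = general_purpose::STANDARD.decode(b64)?;
            // Extension is a guess; the capture pipeline's image format is not shown in this diff.
            std::fs::write(format!("{}.png", capture_id), bytes)?;
        }
        Ok(())
    }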
diff --git a/src/web/mod.rs b/src/web/mod.rs
index 140c8b9..e4c6022 100644
--- a/src/web/mod.rs
+++ b/src/web/mod.rs
@@ -2,6 +2,7 @@
 pub mod server;
 pub mod routes;
 pub mod state;
+pub mod data_export;
 
 pub use server::serve_dashboard;
 pub use state::AppState;
diff --git a/src/web/server.rs b/src/web/server.rs
index 15c5b13..3bb7e07 100644
--- a/src/web/server.rs
+++ b/src/web/server.rs
@@ -12,7 +12,7 @@ use std::net::SocketAddr;
 use std::path::PathBuf;
 
 use crate::error::Result;
-use super::{routes, state::AppState};
+use super::{routes, data_export, state::AppState};
 
 /// Start the web dashboard server
 pub async fn serve_dashboard(
@@ -33,6 +33,9 @@
         .route("/health", get(routes::health_check))
         .route("/stats", get(routes::get_stats))
         .route("/dashboard", get(routes::get_dashboard_data))
+        .route("/captures", get(data_export::get_captures))
+        .route("/captures/:id", get(data_export::get_capture_by_id))
+        .route("/export/full", get(data_export::export_full_data))
         .with_state(state.clone());
 
     // Build main application