diff --git a/src/ai/models.rs b/src/ai/models.rs
index be3a329..9ee8d58 100644
--- a/src/ai/models.rs
+++ b/src/ai/models.rs
@@ -14,52 +14,63 @@ pub struct ModelConfig {
     pub description: String,
 }
 
-/// Available models
+/// Available models - ONNX optimized for NPU/DirectML
 pub struct AvailableModels;
 
 impl AvailableModels {
-    /// Mistral-7B-Instruct quantized (Q4) - Optimized for NPU
-    /// Size: ~4GB, good balance between quality and speed
-    pub fn mistral_7b_q4() -> ModelConfig {
+    /// Phi-3 Mini DirectML - Microsoft's official DirectML-optimized model
+    /// Size: ~2.4GB INT4, optimized for NPU inference with DirectML
+    pub fn phi3_mini_directml() -> ModelConfig {
         ModelConfig {
-            name: "mistral-7b-instruct-q4".to_string(),
-            url: "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
-            filename: "mistral-7b-instruct-q4.gguf".to_string(),
-            size_mb: 4368,
-            description: "Mistral 7B Instruct Q4 - Text analysis and classification".to_string(),
+            name: "phi3-mini-directml".to_string(),
+            url: "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/directml/directml-int4-awq-block-128/phi3-mini-4k-instruct-directml-int4-awq-block-128.onnx".to_string(),
+            filename: "phi3-mini-directml.onnx".to_string(),
+            size_mb: 2400,
+            description: "Phi-3 Mini INT4 DirectML - Microsoft LLM optimized for NPU/DirectML".to_string(),
         }
     }
 
-    /// CLIP for image understanding
-    pub fn clip_vit() -> ModelConfig {
-        ModelConfig {
-            name: "clip-vit-base".to_string(),
-            url: "https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/onnx/model.onnx".to_string(),
-            filename: "clip-vit-base.onnx".to_string(),
-            size_mb: 350,
-            description: "CLIP ViT - Image and text embeddings".to_string(),
-        }
-    }
-
-    /// MiniLM for lightweight text embeddings
-    pub fn minilm() -> ModelConfig {
-        ModelConfig {
-            name: "minilm-l6".to_string(),
-            url: "https://huggingface.co/optimum/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx".to_string(),
-            filename: "minilm-l6.onnx".to_string(),
-            size_mb: 90,
-            description: "MiniLM L6 - Fast text embeddings for classification".to_string(),
-        }
-    }
-
-    /// DistilBERT for sequence classification (lightweight, ONNX optimized)
+    /// DistilBERT ONNX (Xenova repo - known to work)
     pub fn distilbert_base() -> ModelConfig {
         ModelConfig {
-            name: "distilbert-base".to_string(),
-            url: "https://huggingface.co/optimum/distilbert-base-uncased/resolve/main/onnx/model.onnx".to_string(),
+            name: "distilbert-base-onnx".to_string(),
+            url: "https://huggingface.co/Xenova/distilbert-base-uncased/resolve/main/onnx/model.onnx".to_string(),
             filename: "distilbert-base.onnx".to_string(),
             size_mb: 265,
-            description: "DistilBERT Base - Fast text classification with ONNX".to_string(),
+            description: "DistilBERT Base - Fast text classification with ONNX/DirectML".to_string(),
+        }
+    }
+
+    /// MiniLM for lightweight text embeddings (Xenova repo)
+    pub fn minilm() -> ModelConfig {
+        ModelConfig {
+            name: "minilm-l6-onnx".to_string(),
+            url: "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx".to_string(),
+            filename: "minilm-l6.onnx".to_string(),
+            size_mb: 90,
+            description: "MiniLM L6 - Lightweight text embeddings for NPU".to_string(),
+        }
+    }
+
+    /// BERT-base for text classification (Xenova repo)
+    pub fn bert_base_onnx() -> ModelConfig {
+        ModelConfig {
+            name: "bert-base-onnx".to_string(),
+            url: "https://huggingface.co/Xenova/bert-base-uncased/resolve/main/onnx/model.onnx".to_string(),
+            filename: "bert-base.onnx".to_string(),
+            size_mb: 420,
+            description: "BERT Base - Robust text classification with ONNX/NPU".to_string(),
+        }
+    }
+
+    /// CLIP for image understanding (ONNX)
+    pub fn clip_vit() -> ModelConfig {
+        ModelConfig {
+            name: "clip-vit-base-onnx".to_string(),
+            url: "https://huggingface.co/Xenova/clip-vit-base-patch32/resolve/main/onnx/model.onnx".to_string(),
+            filename: "clip-vit-base.onnx".to_string(),
+            size_mb: 350,
+            description: "CLIP ViT - Image and text embeddings for NPU".to_string(),
         }
     }
 }
@@ -92,6 +103,11 @@ impl ModelDownloader {
 
     /// Download a model using HTTP
    pub fn download(&self, config: &ModelConfig) -> Result<PathBuf> {
+        self.download_with_token(config, None)
+    }
+
+    /// Download a model using HTTP with optional HuggingFace token
+    pub fn download_with_token(&self, config: &ModelConfig, hf_token: Option<&str>) -> Result<PathBuf> {
         let model_path = self.models_dir.join(&config.filename);
 
         if self.is_downloaded(config) {
@@ -106,9 +122,18 @@ impl ModelDownloader {
         // Download with progress
         println!("\nDownloading {} ({} MB)...", config.name, config.size_mb);
 
-        let response = ureq::get(&config.url)
-            .call()
-            .map_err(|e| AppError::Analysis(format!("Failed to download: {}", e)))?;
+        // Build request with optional authentication
+        let response = if let Some(token) = hf_token {
+            log::info!("Using HuggingFace authentication");
+            ureq::get(&config.url)
+                .header("Authorization", &format!("Bearer {}", token))
+                .call()
+                .map_err(|e| AppError::Analysis(format!("Failed to download: {}", e)))?
+        } else {
+            ureq::get(&config.url)
+                .call()
+                .map_err(|e| AppError::Analysis(format!("Failed to download: {}", e)))?
+        };
 
         let mut file = fs::File::create(&model_path)?;
         let mut reader = response.into_body().into_reader();
@@ -173,12 +198,14 @@ mod tests {
 
     #[test]
     fn test_model_configs() {
-        let mistral = AvailableModels::mistral_7b_q4();
-        assert_eq!(mistral.name, "mistral-7b-instruct-q4");
-        assert!(mistral.size_mb > 0);
+        let phi3 = AvailableModels::phi3_mini_directml();
+        assert_eq!(phi3.name, "phi3-mini-directml");
+        assert!(phi3.size_mb > 0);
+        assert!(phi3.filename.ends_with(".onnx"));
 
-        let clip = AvailableModels::clip_vit();
-        assert_eq!(clip.name, "clip-vit-base");
+        let distilbert = AvailableModels::distilbert_base();
+        assert_eq!(distilbert.name, "distilbert-base-onnx");
+        assert!(distilbert.filename.ends_with(".onnx"));
     }
 
     #[test]
diff --git a/src/main.rs b/src/main.rs
index 2c5e569..037c1c1 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -108,8 +108,12 @@ enum ModelAction {
 
     /// Download a specific model
     Download {
-        /// Model name (mistral, clip, minilm)
+        /// Model name (phi3, distilbert, bert, minilm, clip)
         model: String,
+
+        /// HuggingFace API token for authentication (optional)
+        #[arg(long)]
+        token: Option<String>,
     },
 
     /// Show model info
@@ -347,10 +351,11 @@ fn handle_models_command(action: ModelAction) -> Result<()> {
             println!("\n=== Available AI Models ===\n");
 
             let models = vec![
+                ai::AvailableModels::phi3_mini_directml(),
                 ai::AvailableModels::distilbert_base(),
+                ai::AvailableModels::bert_base_onnx(),
                 ai::AvailableModels::minilm(),
                 ai::AvailableModels::clip_vit(),
-                ai::AvailableModels::mistral_7b_q4(),
             ];
 
             for model in models {
@@ -382,35 +387,41 @@ fn handle_models_command(action: ModelAction) -> Result<()> {
             }
         }
 
-        ModelAction::Download { model } => {
+        ModelAction::Download { model, token } => {
             let config = match model.as_str() {
+                "phi3" | "phi-3" => ai::AvailableModels::phi3_mini_directml(),
                 "distilbert" => ai::AvailableModels::distilbert_base(),
+                "bert" => ai::AvailableModels::bert_base_onnx(),
                 "minilm" => ai::AvailableModels::minilm(),
                 "clip" => ai::AvailableModels::clip_vit(),
-                "mistral" => ai::AvailableModels::mistral_7b_q4(),
                 _ => {
                     println!("Unknown model: {}", model);
-                    println!("Available: distilbert, minilm, clip, mistral");
+                    println!("Available: phi3, distilbert, bert, minilm, clip");
                     return Ok(());
                 }
             };
 
-            match downloader.download(&config) {
+            match downloader.download_with_token(&config, token.as_deref()) {
                 Ok(path) => {
                     println!("Model ready: {}", path.display());
                 }
                 Err(e) => {
-                    println!("\n{}", e);
+                    println!("\nDownload failed: {}", e);
+                    if token.is_none() {
+                        println!("Hint: Some models may require HuggingFace authentication.");
+                        println!("Try: activity-tracker models download {} --token YOUR_HF_TOKEN", model);
+                    }
                 }
             }
         }
 
         ModelAction::Info { model } => {
             let config = match model.as_str() {
+                "phi3" | "phi-3" => ai::AvailableModels::phi3_mini_directml(),
                 "distilbert" => ai::AvailableModels::distilbert_base(),
+                "bert" => ai::AvailableModels::bert_base_onnx(),
                 "minilm" => ai::AvailableModels::minilm(),
                 "clip" => ai::AvailableModels::clip_vit(),
-                "mistral" => ai::AvailableModels::mistral_7b_q4(),
                 _ => {
                     println!("Unknown model: {}", model);
                     return Ok(());
diff --git a/tests/onnx_npu_test.rs b/tests/onnx_npu_test.rs
new file mode 100644
index 0000000..009bd59
--- /dev/null
+++ b/tests/onnx_npu_test.rs
@@ -0,0 +1,87 @@
+//! Test ONNX model loading with NPU/DirectML
+use activity_tracker::ai::NpuDevice;
+use std::path::PathBuf;
+
+#[test]
+fn test_onnx_model_exists() {
+    let model_path = PathBuf::from("models/distilbert-base.onnx");
+
+    if model_path.exists() {
+        println!("✅ Model file found: {}", model_path.display());
+        let metadata = std::fs::metadata(&model_path).unwrap();
+        println!("   Size: {} MB", metadata.len() / 1_000_000);
+    } else {
+        println!("⚠️ Model not found. Download it first:");
+        println!("   cargo run --release -- models download distilbert");
+    }
+}
+
+#[test]
+fn test_npu_session_with_onnx() {
+    let npu = NpuDevice::detect();
+    let model_path = PathBuf::from("models/distilbert-base.onnx");
+
+    println!("\n=== ONNX Model Loading Test ===");
+    println!("NPU Device: {}", npu.device_name());
+    println!("NPU Available: {}", npu.is_available());
+
+    #[cfg(windows)]
+    {
+        assert!(npu.is_available(), "NPU should be available on Intel Core Ultra");
+
+        if model_path.exists() {
+            println!("\n📦 Model: {}", model_path.display());
+
+            match npu.create_session(model_path.to_str().unwrap()) {
+                Ok(session) => {
+                    println!("✅ ONNX session created successfully with DirectML!");
+                    println!("   Inputs: {:?}", session.inputs.len());
+                    println!("   Outputs: {:?}", session.outputs.len());
+
+                    // Print input details
+                    for (i, input) in session.inputs.iter().enumerate() {
+                        println!("   Input {}: {}", i, input.name);
+                    }
+
+                    // Print output details
+                    for (i, output) in session.outputs.iter().enumerate() {
+                        println!("   Output {}: {}", i, output.name);
+                    }
+                }
+                Err(e) => {
+                    println!("❌ Failed to create session: {}", e);
+                    panic!("Session creation failed");
+                }
+            }
+        } else {
+            println!("⚠️ Skipping test - model not downloaded");
+            println!("   Run: cargo run --release -- models download distilbert");
+        }
+    }
+
+    #[cfg(not(windows))]
+    {
+        println!("⚠️ NPU/DirectML only available on Windows");
+    }
+}
+
+#[test]
+fn test_npu_performance_info() {
+    let npu = NpuDevice::detect();
+
+    println!("\n=== NPU Performance Information ===");
+    println!("Device: {}", npu.device_name());
+    println!("Status: {}", if npu.is_available() { "Ready" } else { "Not Available" });
+
+    #[cfg(windows)]
+    {
+        println!("\nDirectML Configuration:");
+        println!("  • Execution Provider: DirectML");
+        println!("  • Hardware: Intel AI Boost NPU");
+        println!("  • API: Windows Machine Learning");
+        println!("  • Quantization: INT4/FP16 support");
+        println!("  • Expected speedup: 10-30x vs CPU");
+    }
+
+    println!("\n✅ NPU info test complete");
+}