/// Test ONNX inference with DistilBERT and NPU acceleration use activity_tracker::ai::OnnxClassifier; fn main() -> Result<(), Box> { // Initialize logger env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) .init(); println!("\n=== ONNX Inference Test with NPU ===\n"); // Model paths let model_path = "models/distilbert-base.onnx"; let vocab_path = "models/distilbert-vocab.txt"; // Check if files exist if !std::path::Path::new(model_path).exists() { eprintln!("❌ Model not found: {}", model_path); eprintln!("Run: cargo run --release -- models download distilbert"); return Ok(()); } if !std::path::Path::new(vocab_path).exists() { eprintln!("❌ Vocabulary not found: {}", vocab_path); return Ok(()); } println!("📦 Loading model and vocabulary..."); // Create classifier let classifier = match OnnxClassifier::new(model_path, vocab_path) { Ok(c) => c, Err(e) => { eprintln!("❌ Failed to create classifier: {}", e); return Err(e.into()); } }; println!("✅ Classifier created successfully!"); println!("🔧 NPU Device: {}", classifier.device_info()); println!("⚡ Using NPU: {}\n", classifier.is_using_npu()); // Test sentences let test_sentences = vec![ "This is a great movie, I really enjoyed it!", "The weather is nice today.", "I am working on a machine learning project.", "The food was terrible and the service was slow.", "Artificial intelligence is transforming the world.", ]; println!("🧪 Running inference on test sentences:\n"); for (i, sentence) in test_sentences.iter().enumerate() { println!("{}. \"{}\"", i + 1, sentence); // Get predictions match classifier.classify_with_probabilities(sentence) { Ok(probabilities) => { println!(" Probabilities:"); for (class_idx, prob) in probabilities.iter().enumerate().take(5) { println!(" Class {}: {:.4} ({:.1}%)", class_idx, prob, prob * 100.0); } // Get top prediction if let Some((top_class, top_prob)) = probabilities .iter() .enumerate() .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) { println!(" ✨ Top prediction: Class {} ({:.1}%)", top_class, top_prob * 100.0); } } Err(e) => { eprintln!(" ❌ Prediction failed: {}", e); } } println!(); } println!("✅ Inference test completed successfully!"); println!("\n=== Test Summary ==="); println!("• NPU Acceleration: {}", if classifier.is_using_npu() { "Enabled ⚡" } else { "Disabled (CPU fallback)" }); println!("• Model: DistilBERT (ONNX)"); println!("• Device: {}", classifier.device_info()); println!("• Sentences tested: {}", test_sentences.len()); Ok(()) }