//! ONNX inference with NPU acceleration

use crate::ai::NpuDevice;
use crate::error::{Result, AppError};
use ndarray::Array2;
use ort::session::Session;
use ort::value::Value;
use tokenizers::Tokenizer;

/// Text classifier using an ONNX model with NPU acceleration
pub struct OnnxClassifier {
    session: std::cell::RefCell<Session>,
    tokenizer: Tokenizer,
    npu_device: NpuDevice,
    max_length: usize,
}

impl OnnxClassifier {
    /// Create a new ONNX classifier with NPU acceleration
    pub fn new(model_path: &str, tokenizer_path: &str) -> Result<Self> {
        let npu_device = NpuDevice::detect();

        log::info!("Loading ONNX model: {}", model_path);
        log::info!(
            "NPU Device: {} (available: {})",
            npu_device.device_name(),
            npu_device.is_available()
        );

        // Create an ONNX session on the NPU if available
        let session = npu_device.create_session(model_path)?;

        log::info!("Loading tokenizer: {}", tokenizer_path);
        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| AppError::Analysis(format!("Failed to load tokenizer: {}", e)))?;

        Ok(Self {
            session: std::cell::RefCell::new(session),
            tokenizer,
            npu_device,
            max_length: 128,
        })
    }

    /// Check whether the NPU is being used
    pub fn is_using_npu(&self) -> bool {
        self.npu_device.is_available()
    }

    /// Get device information
    pub fn device_info(&self) -> String {
        self.npu_device.device_name().to_string()
    }

    /// Tokenize input text into padded input IDs and an attention mask
    fn tokenize(&self, text: &str) -> Result<(Vec<i64>, Vec<i64>)> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| AppError::Analysis(format!("Tokenization failed: {}", e)))?;

        let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
        let mut attention_mask: Vec<i64> = encoding
            .get_attention_mask()
            .iter()
            .map(|&x| x as i64)
            .collect();

        // Pad or truncate to max_length
        if input_ids.len() > self.max_length {
            input_ids.truncate(self.max_length);
            attention_mask.truncate(self.max_length);
        } else {
            let padding = self.max_length - input_ids.len();
            input_ids.extend(vec![0; padding]);
            attention_mask.extend(vec![0; padding]);
        }

        Ok((input_ids, attention_mask))
    }

    /// Run inference on input text and return the raw logits
    pub fn predict(&self, text: &str) -> Result<Vec<f32>> {
        // Tokenize input
        let (input_ids, attention_mask) = self.tokenize(text)?;

        // Convert to ndarray (batch_size=1, seq_length=max_length)
        let input_ids_array = Array2::from_shape_vec((1, self.max_length), input_ids)
            .map_err(|e| AppError::Analysis(format!("Array creation failed: {}", e)))?;
        let attention_mask_array = Array2::from_shape_vec((1, self.max_length), attention_mask)
            .map_err(|e| AppError::Analysis(format!("Array creation failed: {}", e)))?;

        // Create ONNX values
        let input_ids_value = Value::from_array(input_ids_array)
            .map_err(|e| AppError::Analysis(format!("Failed to create input_ids value: {}", e)))?;
        let attention_mask_value = Value::from_array(attention_mask_array).map_err(|e| {
            AppError::Analysis(format!("Failed to create attention_mask value: {}", e))
        })?;

        // Run inference
        let mut session = self.session.borrow_mut();
        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_value,
                "attention_mask" => attention_mask_value
            ])
            .map_err(|e| AppError::Analysis(format!("Inference failed: {}", e)))?;

        // Extract logits
        let logits = outputs["logits"]
            .try_extract_tensor::<f32>()
            .map_err(|e| AppError::Analysis(format!("Failed to extract logits: {}", e)))?;

        // Convert to Vec<f32>
        let (_shape, data) = logits;
        let logits_vec: Vec<f32> = data.to_vec();

        Ok(logits_vec)
    }
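
    // The methods below post-process the raw logits from `predict`:
    // `classify` takes the argmax and `classify_with_probabilities`
    // applies a softmax. A usage sketch (hedged; the paths here are
    // illustrative, not files shipped with the crate):
    //
    //     let clf = OnnxClassifier::new("model.onnx", "tokenizer.json")?;
    //     let class_idx = clf.classify("some input text")?;
    //     let probs = clf.classify_with_probabilities("some input text")?;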

    /// Classify text and return the predicted class index
    pub fn classify(&self, text: &str) -> Result<usize> {
        let logits = self.predict(text)?;

        // Find the index of the maximum logit; total_cmp avoids the panic
        // that partial_cmp().unwrap() would hit on NaN logits
        let predicted_class = logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(idx, _)| idx)
            .ok_or_else(|| AppError::Analysis("No predictions found".to_string()))?;

        Ok(predicted_class)
    }

    /// Classify text and return probabilities for all classes
    pub fn classify_with_probabilities(&self, text: &str) -> Result<Vec<f32>> {
        let logits = self.predict(text)?;

        // Apply softmax to get probabilities; subtracting the max logit
        // first keeps exp() numerically stable
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
        let sum_exp: f32 = exp_logits.iter().sum();
        let probabilities: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();

        Ok(probabilities)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_classifier_creation() {
        let model_path = "models/distilbert-base.onnx";
        let tokenizer_path = "models/distilbert-tokenizer.json";

        if Path::new(model_path).exists() && Path::new(tokenizer_path).exists() {
            let classifier = OnnxClassifier::new(model_path, tokenizer_path);
            assert!(classifier.is_ok());
        }
    }
}
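
// A hedged sketch of an additional test, mirroring the file-existence guard
// in `test_classifier_creation` above: it only runs when the same local
// model files are present, and checks that `classify_with_probabilities`
// returns a valid softmax distribution.
#[cfg(test)]
mod probability_tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_probabilities_form_distribution() {
        let model_path = "models/distilbert-base.onnx";
        let tokenizer_path = "models/distilbert-tokenizer.json";

        if Path::new(model_path).exists() && Path::new(tokenizer_path).exists() {
            let classifier = OnnxClassifier::new(model_path, tokenizer_path)
                .expect("classifier should load when model files are present");
            let probs = classifier
                .classify_with_probabilities("hello world")
                .expect("inference should succeed");

            // Every probability lies in [0, 1] and the vector sums to ~1
            assert!(!probs.is_empty());
            assert!(probs.iter().all(|&p| (0.0..=1.0).contains(&p)));
            let sum: f32 = probs.iter().sum();
            assert!((sum - 1.0).abs() < 1e-4);
        }
    }
}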