From e528b10a0a12c4691aa8c21e7d244568ec7d0181 Mon Sep 17 00:00:00 2001
From: Augustin
Date: Thu, 16 Oct 2025 19:16:51 +0200
Subject: [PATCH] Add ONNX inference with tokenization support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implement a complete ONNX inference pipeline with NPU acceleration:
- Add OnnxClassifier for text classification via ONNX Runtime
- Integrate HuggingFace tokenizers for text preprocessing
- Support tokenization with padding/truncation
- Implement classification with probabilities (softmax)
- Add distilbert_tokenizer() model config for download

Features:
- Tokenize text input to input_ids and attention_mask
- Run NPU-accelerated inference via DirectML
- Extract logits and convert them to probabilities
- RefCell pattern for session management (see the usage sketch below)
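
Intended usage, as an illustrative sketch: the paths match the test
fixture in src/ai/inference.rs and assume the model and tokenizer files
have already been downloaded; the ? operator assumes a caller that
returns crate::error::Result.

    use crate::ai::OnnxClassifier;

    // Load the model and tokenizer; NPU is picked up automatically if present.
    let classifier = OnnxClassifier::new(
        "models/distilbert-base.onnx",
        "models/distilbert-tokenizer.json",
    )?;
    log::info!("Inference device: {}", classifier.device_info());

    // Argmax class index, and the full softmax distribution.
    let class = classifier.classify("example input")?;
    let probs = classifier.classify_with_probabilities("example input")?;
    log::info!("Predicted class {} (probabilities: {:?})", class, probs);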

Note: the current blocker is a C runtime mismatch in the Windows linker
between esaxx-rs (static CRT, /MT) and ONNX Runtime (dynamic CRT, /MD).
The code compiles, but linking fails. Resolution in progress.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
---
 src/ai/inference.rs | 157 ++++++++++++++++++++++++++++++++++++++++++++
 src/ai/mod.rs       |   2 +
 src/ai/models.rs    |  11 ++++
 3 files changed, 170 insertions(+)
 create mode 100644 src/ai/inference.rs

diff --git a/src/ai/inference.rs b/src/ai/inference.rs
new file mode 100644
index 0000000..cb3fb36
--- /dev/null
+++ b/src/ai/inference.rs
@@ -0,0 +1,157 @@
+/// ONNX inference with NPU acceleration
+use crate::ai::NpuDevice;
+use crate::error::{Result, AppError};
+use ndarray::Array2;
+use ort::session::Session;
+use ort::value::Value;
+use tokenizers::Tokenizer;
+
+/// Text classifier using an ONNX model with NPU acceleration
+pub struct OnnxClassifier {
+    session: std::cell::RefCell<Session>,
+    tokenizer: Tokenizer,
+    npu_device: NpuDevice,
+    max_length: usize,
+}
+
+impl OnnxClassifier {
+    /// Create a new ONNX classifier with NPU acceleration
+    pub fn new(model_path: &str, tokenizer_path: &str) -> Result<Self> {
+        let npu_device = NpuDevice::detect();
+
+        log::info!("Loading ONNX model: {}", model_path);
+        log::info!("NPU Device: {} (available: {})", npu_device.device_name(), npu_device.is_available());
+
+        // Create ONNX session with NPU if available
+        let session = npu_device.create_session(model_path)?;
+
+        log::info!("Loading tokenizer: {}", tokenizer_path);
+        let tokenizer = Tokenizer::from_file(tokenizer_path)
+            .map_err(|e| AppError::Analysis(format!("Failed to load tokenizer: {}", e)))?;
+
+        Ok(Self {
+            session: std::cell::RefCell::new(session),
+            tokenizer,
+            npu_device,
+            max_length: 128,
+        })
+    }
+
+    /// Check if the NPU is being used
+    pub fn is_using_npu(&self) -> bool {
+        self.npu_device.is_available()
+    }
+
+    /// Get device information
+    pub fn device_info(&self) -> String {
+        self.npu_device.device_name().to_string()
+    }
+
+    /// Tokenize input text into input_ids and attention_mask
+    fn tokenize(&self, text: &str) -> Result<(Vec<i64>, Vec<i64>)> {
+        let encoding = self.tokenizer
+            .encode(text, true)
+            .map_err(|e| AppError::Analysis(format!("Tokenization failed: {}", e)))?;
+
+        let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
+        let mut attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&x| x as i64).collect();
+
+        // Pad or truncate to max_length
+        if input_ids.len() > self.max_length {
+            input_ids.truncate(self.max_length);
+            attention_mask.truncate(self.max_length);
+        } else {
+            let padding = self.max_length - input_ids.len();
+            input_ids.extend(vec![0; padding]);
+            attention_mask.extend(vec![0; padding]);
+        }
+
+        Ok((input_ids, attention_mask))
+    }
+
+    /// Run inference on input text and return the raw logits
+    pub fn predict(&self, text: &str) -> Result<Vec<f32>> {
+        // Tokenize input
+        let (input_ids, attention_mask) = self.tokenize(text)?;
+
+        // Convert to ndarray (batch_size=1, seq_length=max_length)
+        let input_ids_array = Array2::from_shape_vec(
+            (1, self.max_length),
+            input_ids,
+        ).map_err(|e| AppError::Analysis(format!("Array creation failed: {}", e)))?;
+
+        let attention_mask_array = Array2::from_shape_vec(
+            (1, self.max_length),
+            attention_mask,
+        ).map_err(|e| AppError::Analysis(format!("Array creation failed: {}", e)))?;
+
+        // Create ONNX values
+        let input_ids_value = Value::from_array(input_ids_array)
+            .map_err(|e| AppError::Analysis(format!("Failed to create input_ids value: {}", e)))?;
+
+        let attention_mask_value = Value::from_array(attention_mask_array)
+            .map_err(|e| AppError::Analysis(format!("Failed to create attention_mask value: {}", e)))?;
+
+        // Run inference
+        let mut session = self.session.borrow_mut();
+        let outputs = session
+            .run(ort::inputs!["input_ids" => input_ids_value, "attention_mask" => attention_mask_value])
+            .map_err(|e| AppError::Analysis(format!("Inference failed: {}", e)))?;
+
+        // Extract logits
+        let logits = outputs["logits"]
+            .try_extract_tensor::<f32>()
+            .map_err(|e| AppError::Analysis(format!("Failed to extract logits: {}", e)))?;
+
+        // Convert to Vec<f32>
+        let (_shape, data) = logits;
+        let logits_vec: Vec<f32> = data.to_vec();
+
+        Ok(logits_vec)
+    }
+
+    /// Classify text and return the predicted class index
+    pub fn classify(&self, text: &str) -> Result<usize> {
+        let logits = self.predict(text)?;
+
+        // Find the index of the maximum value
+        let predicted_class = logits
+            .iter()
+            .enumerate()
+            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
+            .map(|(idx, _)| idx)
+            .ok_or_else(|| AppError::Analysis("No predictions found".to_string()))?;
+
+        Ok(predicted_class)
+    }
+
+    /// Classify text and return probabilities for all classes
+    pub fn classify_with_probabilities(&self, text: &str) -> Result<Vec<f32>> {
+        let logits = self.predict(text)?;
+
+        // Apply softmax to get probabilities
+        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
+        let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
+        let sum_exp: f32 = exp_logits.iter().sum();
+        let probabilities: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();
+
+        Ok(probabilities)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::path::Path;
+
+    #[test]
+    fn test_classifier_creation() {
+        let model_path = "models/distilbert-base.onnx";
+        let tokenizer_path = "models/distilbert-tokenizer.json";
+
+        if Path::new(model_path).exists() && Path::new(tokenizer_path).exists() {
+            let classifier = OnnxClassifier::new(model_path, tokenizer_path);
+            assert!(classifier.is_ok());
+        }
+    }
+}
diff --git a/src/ai/mod.rs b/src/ai/mod.rs
index 163ec11..76b60a6 100644
--- a/src/ai/mod.rs
+++ b/src/ai/mod.rs
@@ -3,8 +3,10 @@ pub mod classifier;
 pub mod npu;
 pub mod models;
 pub mod vision;
+pub mod inference;
 
 pub use classifier::NpuClassifier;
 pub use npu::NpuDevice;
 pub use models::{AvailableModels, ModelConfig, ModelDownloader};
 pub use vision::{ImageAnalyzer, ImageAnalysis};
+pub use inference::OnnxClassifier;
diff --git a/src/ai/models.rs b/src/ai/models.rs
index 9ee8d58..917d8f1 100644
--- a/src/ai/models.rs
+++ b/src/ai/models.rs
@@ -41,6 +41,17 @@ impl AvailableModels {
         }
     }
 
+    /// DistilBERT Tokenizer
+    pub fn distilbert_tokenizer() -> ModelConfig {
+        ModelConfig {
+            name: "distilbert-tokenizer".to_string(),
+            url: "https://huggingface.co/Xenova/distilbert-base-uncased/resolve/main/tokenizer.json".to_string(),
+            filename: "distilbert-tokenizer.json".to_string(),
+            size_mb: 1,
+            description: "DistilBERT Tokenizer - Text preprocessing".to_string(),
+        }
+    }
+
     /// MiniLM for lightweight text embeddings (Xenova repo)
     pub fn minilm() -> ModelConfig {
         ModelConfig {