Add ONNX inference with tokenization support

Implement complete ONNX inference pipeline with NPU acceleration:
- Add OnnxClassifier for text classification via ONNX Runtime
- Integrate HuggingFace tokenizers for text preprocessing
- Support tokenization with padding/truncation
- Implement classification with probabilities (softmax)
- Add distilbert_tokenizer() model config for download

Features:
- Tokenize text input to input_ids and attention_mask
- Run NPU-accelerated inference via DirectML
- Extract logits and convert to probabilities
- RefCell pattern for session management
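
A minimal usage sketch (the paths mirror the test defaults below; assumes the
model and tokenizer files have already been downloaded into models/):

```rust
use crate::ai::OnnxClassifier;

fn run() -> crate::error::Result<()> {
    // Load the ONNX model and the HuggingFace tokenizer config
    let classifier = OnnxClassifier::new(
        "models/distilbert-base.onnx",
        "models/distilbert-tokenizer.json",
    )?;
    // Predicted class index, and softmax probabilities over all classes
    let class_idx = classifier.classify("example input text")?;
    let probabilities = classifier.classify_with_probabilities("example input text")?;
    log::info!("predicted class {} with probabilities {:?}", class_idx, probabilities);
    Ok(())
}
```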

Note: the current blocker is a Windows C runtime mismatch at link time
between esaxx-rs (static /MT) and ONNX Runtime (dynamic /MD). The code
compiles but linking fails; a resolution is in progress.
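
One possible direction (untested; the feature names below are assumptions to
check against the ort version in use) is to stop linking onnxruntime.lib at
build time via ort's `load-dynamic` feature, so the /MD-built ONNX Runtime is
only loaded as a DLL at runtime:

```toml
# Cargo.toml sketch — assumed feature names; load-dynamic defers loading
# ONNX Runtime to runtime instead of linking the import library at build time.
[dependencies]
ort = { version = "2", features = ["load-dynamic", "directml"] }
```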

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Augustin 2025-10-16 19:16:51 +02:00
parent e17a4dd9d0
commit e528b10a0a
3 changed files with 170 additions and 0 deletions

src/ai/inference.rs (new file)

@@ -0,0 +1,157 @@
//! ONNX inference with NPU acceleration
use crate::ai::NpuDevice;
use crate::error::{Result, AppError};
use ndarray::Array2;
use ort::session::Session;
use ort::value::Value;
use tokenizers::Tokenizer;
/// Text classifier using ONNX model with NPU
pub struct OnnxClassifier {
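// ort's Session::run takes &mut self, so RefCell gives interior mutability behind the classifier's &self methods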
session: std::cell::RefCell<Session>,
tokenizer: Tokenizer,
npu_device: NpuDevice,
max_length: usize,
}
impl OnnxClassifier {
/// Create a new ONNX classifier with NPU acceleration
pub fn new(model_path: &str, tokenizer_path: &str) -> Result<Self> {
let npu_device = NpuDevice::detect();
log::info!("Loading ONNX model: {}", model_path);
log::info!("NPU Device: {} (available: {})", npu_device.device_name(), npu_device.is_available());
// Create ONNX session with NPU if available
let session = npu_device.create_session(model_path)?;
log::info!("Loading tokenizer: {}", tokenizer_path);
let tokenizer = Tokenizer::from_file(tokenizer_path)
.map_err(|e| AppError::Analysis(format!("Failed to load tokenizer: {}", e)))?;
Ok(Self {
session: std::cell::RefCell::new(session),
tokenizer,
npu_device,
max_length: 128,
})
}
/// Check if NPU is being used
pub fn is_using_npu(&self) -> bool {
self.npu_device.is_available()
}
/// Get device information
pub fn device_info(&self) -> String {
self.npu_device.device_name().to_string()
}
/// Tokenize input text
fn tokenize(&self, text: &str) -> Result<(Vec<i64>, Vec<i64>)> {
let encoding = self.tokenizer
.encode(text, true)
.map_err(|e| AppError::Analysis(format!("Tokenization failed: {}", e)))?;
let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
let mut attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&x| x as i64).collect();
// Pad or truncate to max_length
if input_ids.len() > self.max_length {
input_ids.truncate(self.max_length);
attention_mask.truncate(self.max_length);
} else {
let padding = self.max_length - input_ids.len();
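// Pad with token id 0 ([PAD] for DistilBERT) and attention mask 0 so padded positions are ignored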
input_ids.extend(vec![0; padding]);
attention_mask.extend(vec![0; padding]);
}
Ok((input_ids, attention_mask))
}
/// Run inference on input text
pub fn predict(&self, text: &str) -> Result<Vec<f32>> {
// Tokenize input
let (input_ids, attention_mask) = self.tokenize(text)?;
// Convert to ndarray (batch_size=1, seq_length=max_length)
let input_ids_array = Array2::from_shape_vec(
(1, self.max_length),
input_ids,
).map_err(|e| AppError::Analysis(format!("Array creation failed: {}", e)))?;
let attention_mask_array = Array2::from_shape_vec(
(1, self.max_length),
attention_mask,
).map_err(|e| AppError::Analysis(format!("Array creation failed: {}", e)))?;
// Create ONNX values
let input_ids_value = Value::from_array(input_ids_array)
.map_err(|e| AppError::Analysis(format!("Failed to create input_ids value: {}", e)))?;
let attention_mask_value = Value::from_array(attention_mask_array)
.map_err(|e| AppError::Analysis(format!("Failed to create attention_mask value: {}", e)))?;
// Run inference
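// Input names must match the ONNX graph's inputs ("input_ids", "attention_mask")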
let mut session = self.session.borrow_mut();
let outputs = session
.run(ort::inputs!["input_ids" => input_ids_value, "attention_mask" => attention_mask_value])
.map_err(|e| AppError::Analysis(format!("Inference failed: {}", e)))?;
// Extract logits
let logits = outputs["logits"]
.try_extract_tensor::<f32>()
.map_err(|e| AppError::Analysis(format!("Failed to extract logits: {}", e)))?;
// Convert to Vec<f32>
let (_shape, data) = logits;
let logits_vec: Vec<f32> = data.to_vec();
Ok(logits_vec)
}
/// Classify text and return the predicted class index
pub fn classify(&self, text: &str) -> Result<usize> {
let logits = self.predict(text)?;
// Find the index of the maximum value
let predicted_class = logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(idx, _)| idx)
.ok_or_else(|| AppError::Analysis("No predictions found".to_string()))?;
Ok(predicted_class)
}
/// Classify text and return probabilities for all classes
pub fn classify_with_probabilities(&self, text: &str) -> Result<Vec<f32>> {
let logits = self.predict(text)?;
// Apply softmax to get probabilities
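// Subtract the max logit before exponentiating for numerical stability (avoids overflow in exp)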
let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
let sum_exp: f32 = exp_logits.iter().sum();
let probabilities: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();
Ok(probabilities)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::Path;
#[test]
fn test_classifier_creation() {
let model_path = "models/distilbert-base.onnx";
let tokenizer_path = "models/distilbert-tokenizer.json";
if Path::new(model_path).exists() && Path::new(tokenizer_path).exists() {
let classifier = OnnxClassifier::new(model_path, tokenizer_path);
assert!(classifier.is_ok());
}
}
}

src/ai/mod.rs

@@ -3,8 +3,10 @@ pub mod classifier;
pub mod npu;
pub mod models;
pub mod vision;
pub mod inference;
pub use classifier::NpuClassifier;
pub use npu::NpuDevice;
pub use models::{AvailableModels, ModelConfig, ModelDownloader};
pub use vision::{ImageAnalyzer, ImageAnalysis};
pub use inference::OnnxClassifier;

src/ai/models.rs

@@ -41,6 +41,17 @@ impl AvailableModels {
}
}
/// DistilBERT Tokenizer
pub fn distilbert_tokenizer() -> ModelConfig {
ModelConfig {
name: "distilbert-tokenizer".to_string(),
url: "https://huggingface.co/Xenova/distilbert-base-uncased/resolve/main/tokenizer.json".to_string(),
filename: "distilbert-tokenizer.json".to_string(),
size_mb: 1,
description: "DistilBERT Tokenizer - Text preprocessing".to_string(),
}
}
/// MiniLM for lightweight text embeddings (Xenova repo)
pub fn minilm() -> ModelConfig {
ModelConfig {