activity-tracker/src/ai/inference.rs
Augustin b61c8e31a8 Fix: Replace tokenizers crate with custom SimpleTokenizer
Resolve the Windows linker C runtime mismatch by implementing a custom
tokenizer that doesn't depend on esaxx-rs (which links the static C runtime).

Changes:
- Remove tokenizers crate dependency (caused the /MT vs /MD runtime conflict)
- Add custom SimpleTokenizer in src/ai/tokenizer.rs
  - Loads vocab.txt files directly
  - Implements WordPiece-style subword tokenization (sketched below)
  - Pure Rust, no C++ dependencies
  - Handles [CLS], [SEP], [PAD], [UNK] special tokens

- Update OnnxClassifier to use SimpleTokenizer
- Update ModelConfig to use vocab.txt instead of tokenizer.json
- Rename distilbert_tokenizer() to distilbert_vocab()
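
As a rough sketch of that WordPiece step (illustrative only: the function
name, vocab type, and "##" continuation handling below are assumptions about
src/ai/tokenizer.rs, not its exact code), each whitespace-split word is
consumed by a greedy longest-match loop that falls back to [UNK]:

    use std::collections::HashMap;

    /// Greedy longest-match WordPiece over a single word.
    fn wordpiece(word: &str, vocab: &HashMap<String, i64>, unk_id: i64) -> Vec<i64> {
        let mut ids = Vec::new();
        let chars: Vec<char> = word.chars().collect();
        let mut start = 0;
        while start < chars.len() {
            // Try the longest candidate first; shrink until a vocab entry matches.
            let mut end = chars.len();
            let mut matched = None;
            while end > start {
                let piece: String = chars[start..end].iter().collect();
                let key = if start == 0 { piece } else { format!("##{}", piece) };
                if let Some(&id) = vocab.get(&key) {
                    matched = Some((id, end));
                    break;
                }
                end -= 1;
            }
            match matched {
                Some((id, next)) => {
                    ids.push(id);
                    start = next;
                }
                // No subword matches at all: map the whole word to [UNK].
                None => return vec![unk_id],
            }
        }
        ids
    }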

Build status:
- Compiles successfully
- Links without C runtime conflicts
- Executable works correctly
- All previous functionality preserved

This resolves the LNK2038 error completely while maintaining full
ONNX inference capability with NPU acceleration.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-16 19:38:44 +02:00

//! ONNX inference with NPU acceleration

use crate::ai::{NpuDevice, SimpleTokenizer};
use crate::error::{Result, AppError};
use ndarray::Array2;
use ort::session::Session;
use ort::value::Value;

/// Text classifier using ONNX model with NPU
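///
/// A minimal usage sketch (illustrative only: the crate/module path, the
/// model and vocab file locations, and the input text are assumptions, not
/// guaranteed by this crate):
///
/// ```no_run
/// use activity_tracker::ai::OnnxClassifier;
///
/// let classifier = OnnxClassifier::new("models/distilbert-base.onnx", "models/vocab.txt")
///     .expect("model and vocabulary should load");
/// let class_idx = classifier.classify("wrote the weekly status report")
///     .expect("inference should succeed");
/// println!("predicted class index: {}", class_idx);
/// ```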
pub struct OnnxClassifier {
    session: std::cell::RefCell<Session>,
    tokenizer: SimpleTokenizer,
    npu_device: NpuDevice,
}

impl OnnxClassifier {
    /// Create a new ONNX classifier with NPU acceleration
    pub fn new(model_path: &str, vocab_path: &str) -> Result<Self> {
        let npu_device = NpuDevice::detect();
        log::info!("Loading ONNX model: {}", model_path);
        log::info!(
            "NPU Device: {} (available: {})",
            npu_device.device_name(),
            npu_device.is_available()
        );

        // Create ONNX session with NPU if available
        let session = npu_device.create_session(model_path)?;

        log::info!("Loading vocabulary: {}", vocab_path);
        // Load our custom tokenizer
        let tokenizer = SimpleTokenizer::from_vocab_file(vocab_path, 128)?;

        Ok(Self {
            session: std::cell::RefCell::new(session),
            tokenizer,
            npu_device,
        })
    }

    /// Check if NPU is being used
    pub fn is_using_npu(&self) -> bool {
        self.npu_device.is_available()
    }

    /// Get device information
    pub fn device_info(&self) -> String {
        self.npu_device.device_name().to_string()
    }

    /// Tokenize input text
    fn tokenize(&self, text: &str) -> Result<(Vec<i64>, Vec<i64>)> {
        // Use our custom tokenizer
        let (input_ids, attention_mask) = self.tokenizer.encode(text);
        Ok((input_ids, attention_mask))
    }

    /// Run inference on input text
    pub fn predict(&self, text: &str) -> Result<Vec<f32>> {
        // Tokenize input
        let (input_ids, attention_mask) = self.tokenize(text)?;
        let seq_length = input_ids.len();

        // Convert to ndarray (batch_size=1, seq_length)
        let input_ids_array = Array2::from_shape_vec((1, seq_length), input_ids)
            .map_err(|e| AppError::Analysis(format!("Array creation failed: {}", e)))?;
        let attention_mask_array = Array2::from_shape_vec((1, seq_length), attention_mask)
            .map_err(|e| AppError::Analysis(format!("Array creation failed: {}", e)))?;

        // Create ONNX values
        let input_ids_value = Value::from_array(input_ids_array)
            .map_err(|e| AppError::Analysis(format!("Failed to create input_ids value: {}", e)))?;
        let attention_mask_value = Value::from_array(attention_mask_array)
            .map_err(|e| AppError::Analysis(format!("Failed to create attention_mask value: {}", e)))?;

        // Run inference
        let mut session = self.session.borrow_mut();
        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_value,
                "attention_mask" => attention_mask_value
            ])
            .map_err(|e| AppError::Analysis(format!("Inference failed: {}", e)))?;

        // Extract logits
        let logits = outputs["logits"]
            .try_extract_tensor::<f32>()
            .map_err(|e| AppError::Analysis(format!("Failed to extract logits: {}", e)))?;

        // Convert to Vec<f32>
        let (_shape, data) = logits;
        let logits_vec: Vec<f32> = data.to_vec();
        Ok(logits_vec)
    }

    /// Classify text and return the predicted class index
    pub fn classify(&self, text: &str) -> Result<usize> {
        let logits = self.predict(text)?;

        // Find the index of the maximum value
        let predicted_class = logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .ok_or_else(|| AppError::Analysis("No predictions found".to_string()))?;

        Ok(predicted_class)
    }

    /// Classify text and return probabilities for all classes
    pub fn classify_with_probabilities(&self, text: &str) -> Result<Vec<f32>> {
        let logits = self.predict(text)?;

        // Apply softmax to get probabilities (subtract the max logit first for numerical stability)
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
        let sum_exp: f32 = exp_logits.iter().sum();
        let probabilities: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();

        Ok(probabilities)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_classifier_creation() {
        // The vocab path is assumed to follow the new vocab.txt layout
        // introduced by this commit; adjust if ModelConfig points elsewhere.
        let model_path = "models/distilbert-base.onnx";
        let vocab_path = "models/vocab.txt";
        if Path::new(model_path).exists() && Path::new(vocab_path).exists() {
            let classifier = OnnxClassifier::new(model_path, vocab_path);
            assert!(classifier.is_ok());
        }
    }
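
    // A hedged sketch of a second test: it exercises the softmax path when the
    // model assets are present; the paths are assumptions (matching the vocab.txt
    // layout introduced by this commit) and the test is skipped when they are missing.
    #[test]
    fn test_probabilities_sum_to_one() {
        let model_path = "models/distilbert-base.onnx";
        let vocab_path = "models/vocab.txt";
        if Path::new(model_path).exists() && Path::new(vocab_path).exists() {
            let classifier = OnnxClassifier::new(model_path, vocab_path)
                .expect("classifier should load when assets exist");
            let probs = classifier
                .classify_with_probabilities("wrote the weekly status report")
                .expect("inference should succeed");
            let sum: f32 = probs.iter().sum();
            assert!((sum - 1.0).abs() < 1e-4, "softmax output should sum to ~1");
        }
    }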
}