//! Simple BERT tokenizer without external dependencies

use crate::error::Result;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader};

pub struct SimpleTokenizer {
    vocab: HashMap<String, i64>,
    max_length: usize,
    cls_token_id: i64,
    sep_token_id: i64,
    pad_token_id: i64,
}

impl SimpleTokenizer {
    /// Load tokenizer from a vocab file (one token per line; line index = token ID)
    pub fn from_vocab_file(vocab_path: &str, max_length: usize) -> Result<Self> {
        let file = File::open(vocab_path)?;
        let reader = BufReader::new(file);

        let mut vocab = HashMap::new();
        for (idx, line) in reader.lines().enumerate() {
            let token = line?;
            vocab.insert(token, idx as i64);
        }

        // Get special token IDs, falling back to the standard BERT-uncased IDs
        let cls_token_id = *vocab.get("[CLS]").unwrap_or(&101);
        let sep_token_id = *vocab.get("[SEP]").unwrap_or(&102);
        let pad_token_id = *vocab.get("[PAD]").unwrap_or(&0);

        Ok(Self {
            vocab,
            max_length,
            cls_token_id,
            sep_token_id,
            pad_token_id,
        })
    }

    /// Tokenize text using simple whitespace and punctuation splitting
    pub fn encode(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
        let mut input_ids = vec![self.cls_token_id];
        let mut attention_mask = vec![1];

        // Simple tokenization: lowercase and split
        let text_lower = text.to_lowercase();

        // Split on whitespace and common punctuation
        let tokens: Vec<&str> = text_lower
            .split(|c: char| c.is_whitespace() || ".,!?;:()[]{}".contains(c))
            .filter(|s| !s.is_empty())
            .collect();

        for token in tokens {
            // Try exact match first
            if let Some(&token_id) = self.vocab.get(token) {
                input_ids.push(token_id);
                attention_mask.push(1);
            } else {
                // Try subword tokenization (simple greedy longest-match approach)
                let mut remaining = token;
                while !remaining.is_empty() && input_ids.len() < self.max_length - 1 {
                    let mut found = false;
                    // The first piece of a word is looked up as-is; later pieces
                    // use the "##" WordPiece continuation prefix
                    let at_word_start = remaining.len() == token.len();

                    // Try longest match first
                    for len in (1..=remaining.len()).rev() {
                        // Only slice at valid UTF-8 character boundaries
                        if !remaining.is_char_boundary(len) {
                            continue;
                        }
                        let substr = &remaining[..len];
                        let lookup_key = if at_word_start {
                            substr.to_string()
                        } else {
                            format!("##{}", substr) // WordPiece continuation
                        };

                        if let Some(&token_id) = self.vocab.get(&lookup_key) {
                            input_ids.push(token_id);
                            attention_mask.push(1);
                            remaining = &remaining[len..];
                            found = true;
                            break;
                        }
                    }

                    if !found {
                        // Unknown token - use [UNK]
                        if let Some(&unk_id) = self.vocab.get("[UNK]") {
                            input_ids.push(unk_id);
                            attention_mask.push(1);
                        }
                        break;
                    }
                }
            }

            if input_ids.len() >= self.max_length - 1 {
                break;
            }
        }

        // Add SEP token
        input_ids.push(self.sep_token_id);
        attention_mask.push(1);

        // Pad to max_length
        while input_ids.len() < self.max_length {
            input_ids.push(self.pad_token_id);
            attention_mask.push(0);
        }

        // Truncate if needed
        input_ids.truncate(self.max_length);
        attention_mask.truncate(self.max_length);

        (input_ids, attention_mask)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_simple_tokenizer() {
        // Create a minimal vocab file
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "[PAD]").unwrap();
        writeln!(temp_file, "[UNK]").unwrap();
        writeln!(temp_file, "[CLS]").unwrap();
        writeln!(temp_file, "[SEP]").unwrap();
        writeln!(temp_file, "hello").unwrap();
        writeln!(temp_file, "world").unwrap();
        writeln!(temp_file, "test").unwrap();

        let tokenizer = SimpleTokenizer::from_vocab_file(
            temp_file.path().to_str().unwrap(),
            10,
        )
        .unwrap();

        let (input_ids, attention_mask) = tokenizer.encode("hello world");

        // Should have: [CLS] hello world [SEP] [PAD]...
        assert_eq!(input_ids.len(), 10);
        assert_eq!(attention_mask.len(), 10);
        assert_eq!(input_ids[0], tokenizer.cls_token_id); // [CLS]
        assert_eq!(attention_mask[0], 1);
    }
}
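
// Illustrative sketch: a second test module showing how the greedy
// longest-match loop above splits an out-of-vocabulary word into WordPiece
// subwords. The "play" / "##ing" vocab entries are hypothetical, chosen only
// to exercise the "##" continuation lookup; the expected IDs follow from the
// line order in the temporary vocab file ([PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3,
// play=4, ##ing=5).
#[cfg(test)]
mod wordpiece_split_example {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn splits_unknown_word_into_subwords() {
        let mut temp_file = NamedTempFile::new().unwrap();
        for token in ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "play", "##ing"] {
            writeln!(temp_file, "{}", token).unwrap();
        }

        let tokenizer = SimpleTokenizer::from_vocab_file(
            temp_file.path().to_str().unwrap(),
            8,
        )
        .unwrap();

        // "Playing" is lowercased, has no exact vocab match, and is split
        // greedily into "play" + "##ing", then wrapped in [CLS] ... [SEP]
        // and padded to max_length = 8.
        let (input_ids, attention_mask) = tokenizer.encode("Playing");
        assert_eq!(input_ids, vec![2, 4, 5, 3, 0, 0, 0, 0]);
        assert_eq!(attention_mask, vec![1, 1, 1, 1, 0, 0, 0, 0]);
    }
}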