diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index fc26d3cb5ee..4582d419f8f 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -77,10 +77,50 @@ pub struct HuffmanDictionary { } impl HuffmanDictionary { - /// The list of alphabet symbols and their respective frequency should - /// be given as input - pub fn new(alphabet: &[(T, u64)]) -> Self { + /// Creates a new Huffman dictionary from alphabet symbols and their frequencies. + /// + /// Returns `None` if the alphabet is empty. + /// + /// # Arguments + /// * `alphabet` - A slice of tuples containing symbols and their frequencies + /// + /// # Example + /// ``` + /// # use the_algorithms_rust::general::HuffmanDictionary; + /// let freq = vec![('a', 5), ('b', 2), ('c', 1)]; + /// let dict = HuffmanDictionary::new(&freq).unwrap(); + /// + pub fn new(alphabet: &[(T, u64)]) -> Option { + if alphabet.is_empty() { + return None; + } + let mut alph: BTreeMap = BTreeMap::new(); + + // Special case: single symbol + if alphabet.len() == 1 { + let (symbol, _freq) = alphabet[0]; + alph.insert( + symbol, + HuffmanValue { + value: 0, + bits: 1, // Must use at least 1 bit per symbol + }, + ); + + let root = HuffmanNode { + left: None, + right: None, + symbol: Some(symbol), + frequency: alphabet[0].1, + }; + + return Some(HuffmanDictionary { + alphabet: alph, + root, + }); + } + let mut queue: BinaryHeap> = BinaryHeap::new(); for (symbol, freq) in alphabet.iter() { queue.push(HuffmanNode { @@ -101,11 +141,14 @@ impl HuffmanDictionary { frequency: sm_freq, }); } - let root = queue.pop().unwrap(); - HuffmanNode::get_alphabet(0, 0, &root, &mut alph); - HuffmanDictionary { - alphabet: alph, - root, + if let Some(root) = queue.pop() { + HuffmanNode::get_alphabet(0, 0, &root, &mut alph); + Some(HuffmanDictionary { + alphabet: alph, + root, + }) + } else { + None } } pub fn encode(&self, data: &[T]) -> HuffmanEncoding { @@ -143,27 +186,48 @@ impl HuffmanEncoding { } self.num_bits += data.bits as u64; } + + #[inline] fn get_bit(&self, pos: u64) -> bool { (self.data[(pos >> 6) as usize] & (1 << (pos & 63))) != 0 } + /// In case the encoding is invalid, `None` is returned pub fn decode(&self, dict: &HuffmanDictionary) -> Option> { + // Handle empty encoding + if self.num_bits == 0 { + return Some(vec![]); + } + + // Special case: single symbol in dictionary + if dict.alphabet.len() == 1 { + //all bits represent the same symbol + let symbol = dict.alphabet.keys().next()?; + let result = vec![*symbol; self.num_bits as usize]; + return Some(result); + } + + // Normal case: multiple symbols let mut state = &dict.root; let mut result: Vec = vec![]; + for i in 0..self.num_bits { - if state.symbol.is_some() { - result.push(state.symbol.unwrap()); + if let Some(symbol) = state.symbol { + result.push(symbol); state = &dict.root; } state = if self.get_bit(i) { - state.right.as_ref().unwrap() + state.right.as_ref()? } else { - state.left.as_ref().unwrap() + state.left.as_ref()? } } + + // Check if we ended on a symbol if self.num_bits > 0 { result.push(state.symbol?); } + Some(result) } } @@ -181,12 +245,97 @@ mod tests { .for_each(|(b, &cnt)| result.push((b as u8, cnt))); result } + + #[test] + fn empty_text() { + let text = ""; + let bytes = text.as_bytes(); + let freq = get_frequency(bytes); + let dict = HuffmanDictionary::new(&freq); + assert!(dict.is_none()); + } + + #[test] + fn one_symbol_text() { + let text = "aaaa"; + let bytes = text.as_bytes(); + let freq = get_frequency(bytes); + let dict = HuffmanDictionary::new(&freq).unwrap(); + let encoded = dict.encode(bytes); + assert_eq!(encoded.num_bits, 4); + let decoded = encoded.decode(&dict).unwrap(); + assert_eq!(decoded, bytes); + } + + #[test] + fn test_decode_empty_encoding_struct() { + // Create a minimal but VALID HuffmanDictionary. + // This is required because decode() expects a dictionary, even though + // the content of the dictionary doesn't matter when num_bits == 0. + let freq = vec![(b'a', 1)]; + let dict = HuffmanDictionary::new(&freq).unwrap(); + + // Manually create the target state: an encoding with 0 bits. + let empty_encoding = HuffmanEncoding { + data: vec![], + num_bits: 0, + }; + + let result = empty_encoding.decode(&dict); + + assert_eq!(result, Some(vec![])); + } + + #[test] + fn minimal_decode_end_check() { + let freq = vec![(b'a', 1), (b'b', 1)]; + let bytes = b"ab"; + + let dict = HuffmanDictionary::new(&freq).unwrap(); + let encoded = dict.encode(bytes); + + // This decode will go through the main loop and hit the final 'if self.num_bits > 0' check. + let decoded = encoded.decode(&dict).unwrap(); + + assert_eq!(decoded, bytes); + } + + #[test] + fn test_decode_corrupted_stream_dead_end() { + // Create a dictionary with three symbols to ensure a deeper tree. + // This makes hitting a dead-end (None pointer) easier. + let freq = vec![(b'a', 1), (b'b', 1), (b'c', 1)]; + let bytes = b"ab"; + let dict = HuffmanDictionary::new(&freq).unwrap(); + + let encoded = dict.encode(bytes); + + // Manually corrupt the stream to stop mid-symbol. + // We will truncate num_bits by a small amount (e.g., 1 bit). + // This forces the loop to stop on an *intermediate* node. + let corrupted_encoding = HuffmanEncoding { + data: encoded.data, + // Shorten the bit count by one. The total length of the 'ab' stream + // is likely 4 or 5 bits. This forces the loop to end one bit early, + // leaving the state on an internal node. + num_bits: encoded + .num_bits + .checked_sub(1) + .expect("Encoding should be > 0 bits"), + }; + + // Assert that the decode fails gracefully. + // The loop finishes, the final 'if self.num_bits > 0' executes, + // and result.push(state.symbol?) fails because state.symbol is None. + assert_eq!(corrupted_encoding.decode(&dict), None); + } + #[test] fn small_text() { let text = "Hello world"; let bytes = text.as_bytes(); let freq = get_frequency(bytes); - let dict = HuffmanDictionary::new(&freq); + let dict = HuffmanDictionary::new(&freq).unwrap(); let encoded = dict.encode(bytes); assert_eq!(encoded.num_bits, 32); let decoded = encoded.decode(&dict).unwrap(); @@ -208,7 +357,7 @@ mod tests { ); let bytes = text.as_bytes(); let freq = get_frequency(bytes); - let dict = HuffmanDictionary::new(&freq); + let dict = HuffmanDictionary::new(&freq).unwrap(); let encoded = dict.encode(bytes); assert_eq!(encoded.num_bits, 2372); let decoded = encoded.decode(&dict).unwrap(); diff --git a/src/machine_learning/loss_function/kl_divergence_loss.rs b/src/machine_learning/loss_function/kl_divergence_loss.rs index f477607b20f..918bdc338e6 100644 --- a/src/machine_learning/loss_function/kl_divergence_loss.rs +++ b/src/machine_learning/loss_function/kl_divergence_loss.rs @@ -16,7 +16,7 @@ pub fn kld_loss(actual: &[f64], predicted: &[f64]) -> f64 { let loss: f64 = actual .iter() .zip(predicted.iter()) - .map(|(&a, &p)| ((a + eps) * ((a + eps) / (p + eps)).ln())) + .map(|(&a, &p)| (a + eps) * ((a + eps) / (p + eps)).ln()) .sum(); loss }