From 8dcbf137e86e0b647590d6d8e5051cd420af1e23 Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:47:43 -0600 Subject: [PATCH 01/11] fix(huffman): Handle edge cases and improve error handling - Change HuffmanDictionary::new() to return Option for safer API - Add proper handling for empty alphabet (returns None) - Add special case handling for single-symbol alphabets - Replace unwrap() calls with ? operator in decode() for better error handling - Add #[inline(always)] optimization for frequently called get_bit() - Add comprehensive tests for edge cases - Improve documentation with usage examples BREAKING CHANGE: HuffmanDictionary::new() now returns Option --- src/general/huffman_encoding.rs | 114 ++++++++++++++++++++++++++++---- 1 file changed, 100 insertions(+), 14 deletions(-) diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index fc26d3cb5ee..8a6e5b6ee3f 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -77,10 +77,50 @@ pub struct HuffmanDictionary { } impl HuffmanDictionary { - /// The list of alphabet symbols and their respective frequency should - /// be given as input - pub fn new(alphabet: &[(T, u64)]) -> Self { + /// Creates a new Huffman dictionary from alphabet symbols and their frequencies. + /// + /// Returns `None` if the alphabet is empty. + /// + /// # Arguments + /// * `alphabet` - A slice of tuples containing symbols and their frequencies + /// + /// # Example + /// ``` + /// # use the_algorithms_rust::general::HuffmanDictionary; + /// let freq = vec![('a', 5), ('b', 2), ('c', 1)]; + /// let dict = HuffmanDictionary::new(&freq).unwrap(); + /// + pub fn new(alphabet: &[(T, u64)]) -> Option { + if alphabet.is_empty() { + return None; + } + let mut alph: BTreeMap = BTreeMap::new(); + + // Special case: single symbol + if alphabet.len() == 1 { + let (symbol, _freq) = alphabet[0]; + alph.insert( + symbol, + HuffmanValue { + value: 0, + bits: 1, // Must use at least 1 bit per symbol + }, + ); + + let root = HuffmanNode { + left: None, + right: None, + symbol: Some(symbol), + frequency: alphabet[0].1, + }; + + return Some(HuffmanDictionary { + alphabet: alph, + root, + }); + } + let mut queue: BinaryHeap> = BinaryHeap::new(); for (symbol, freq) in alphabet.iter() { queue.push(HuffmanNode { @@ -101,11 +141,14 @@ impl HuffmanDictionary { frequency: sm_freq, }); } - let root = queue.pop().unwrap(); - HuffmanNode::get_alphabet(0, 0, &root, &mut alph); - HuffmanDictionary { - alphabet: alph, - root, + if let Some(root) = queue.pop() { + HuffmanNode::get_alphabet(0, 0, &root, &mut alph); + Some(HuffmanDictionary { + alphabet: alph, + root, + }) + } else { + None } } pub fn encode(&self, data: &[T]) -> HuffmanEncoding { @@ -143,27 +186,48 @@ impl HuffmanEncoding { } self.num_bits += data.bits as u64; } + + #[inline(always)] fn get_bit(&self, pos: u64) -> bool { (self.data[(pos >> 6) as usize] & (1 << (pos & 63))) != 0 } + /// In case the encoding is invalid, `None` is returned pub fn decode(&self, dict: &HuffmanDictionary) -> Option> { + // Handle empty encoding + if self.num_bits == 0 { + return Some(vec![]); + } + + // Special case: single symbol in dictionary + if dict.alphabet.len() == 1 { + //all bits represent the same symbol + let symbol = dict.alphabet.keys().next()?; + let result = vec![*symbol; self.num_bits as usize]; + return Some(result); + } + + // Normal case: multiple symbols let mut state = &dict.root; let mut result: Vec = vec![]; + for i in 0..self.num_bits { - if state.symbol.is_some() { - result.push(state.symbol.unwrap()); + if let Some(symbol) = state.symbol { + result.push(symbol); state = &dict.root; } state = if self.get_bit(i) { - state.right.as_ref().unwrap() + state.right.as_ref()? } else { - state.left.as_ref().unwrap() + state.left.as_ref()? } } + + // Check if we ended on a symbol if self.num_bits > 0 { result.push(state.symbol?); } + Some(result) } } @@ -181,12 +245,34 @@ mod tests { .for_each(|(b, &cnt)| result.push((b as u8, cnt))); result } + + #[test] + fn empty_text() { + let text = ""; + let bytes = text.as_bytes(); + let freq = get_frequency(bytes); + let dict = HuffmanDictionary::new(&freq); + assert!(dict.is_none()); + } + + #[test] + fn one_symbol_text() { + let text = "aaaa"; + let bytes = text.as_bytes(); + let freq = get_frequency(bytes); + let dict = HuffmanDictionary::new(&freq).unwrap(); + let encoded = dict.encode(bytes); + assert_eq!(encoded.num_bits, 4); + let decoded = encoded.decode(&dict).unwrap(); + assert_eq!(decoded, bytes); + } + #[test] fn small_text() { let text = "Hello world"; let bytes = text.as_bytes(); let freq = get_frequency(bytes); - let dict = HuffmanDictionary::new(&freq); + let dict = HuffmanDictionary::new(&freq).unwrap(); let encoded = dict.encode(bytes); assert_eq!(encoded.num_bits, 32); let decoded = encoded.decode(&dict).unwrap(); @@ -208,7 +294,7 @@ mod tests { ); let bytes = text.as_bytes(); let freq = get_frequency(bytes); - let dict = HuffmanDictionary::new(&freq); + let dict = HuffmanDictionary::new(&freq).unwrap(); let encoded = dict.encode(bytes); assert_eq!(encoded.num_bits, 2372); let decoded = encoded.decode(&dict).unwrap(); From 77ec98a4434deab7c72dcbc2b4a3d737468cec21 Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:14:12 -0600 Subject: [PATCH 02/11] clippy --- src/general/huffman_encoding.rs | 6 +++--- src/machine_learning/loss_function/kl_divergence_loss.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index 8a6e5b6ee3f..41f9886cc0f 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -78,12 +78,12 @@ pub struct HuffmanDictionary { impl HuffmanDictionary { /// Creates a new Huffman dictionary from alphabet symbols and their frequencies. - /// + /// /// Returns `None` if the alphabet is empty. - /// + /// /// # Arguments /// * `alphabet` - A slice of tuples containing symbols and their frequencies - /// + /// /// # Example /// ``` /// # use the_algorithms_rust::general::HuffmanDictionary; diff --git a/src/machine_learning/loss_function/kl_divergence_loss.rs b/src/machine_learning/loss_function/kl_divergence_loss.rs index f477607b20f..918bdc338e6 100644 --- a/src/machine_learning/loss_function/kl_divergence_loss.rs +++ b/src/machine_learning/loss_function/kl_divergence_loss.rs @@ -16,7 +16,7 @@ pub fn kld_loss(actual: &[f64], predicted: &[f64]) -> f64 { let loss: f64 = actual .iter() .zip(predicted.iter()) - .map(|(&a, &p)| ((a + eps) * ((a + eps) / (p + eps)).ln())) + .map(|(&a, &p)| (a + eps) * ((a + eps) / (p + eps)).ln()) .sum(); loss } From 44d02d70f8c6038737aa010202d3ac208788007f Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:23:18 -0600 Subject: [PATCH 03/11] inline(always) -> inline --- src/general/huffman_encoding.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index 41f9886cc0f..5dd623ec4d0 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -187,7 +187,7 @@ impl HuffmanEncoding { self.num_bits += data.bits as u64; } - #[inline(always)] + #[inline] fn get_bit(&self, pos: u64) -> bool { (self.data[(pos >> 6) as usize] & (1 << (pos & 63))) != 0 } From 9f0df7e596931c1df4d2dd8f80607aef7aef1a0a Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Thu, 25 Sep 2025 09:21:55 -0600 Subject: [PATCH 04/11] Test: Increase coverage for huffman_encoding.rs decode method Adds two new test cases to ensure 100% patch coverage for HuffmanEncoding::decode: 1. test_decode_empty_encoding_struct: Covers the edge case where num_bits == 0. 2. minimal_decode_end_check: Ensures the final 'if self.num_bits > 0' check in the multi-symbol decode path is fully covered. --- src/general/huffman_encoding.rs | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index 5dd623ec4d0..80dd1a0b77c 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -267,6 +267,39 @@ mod tests { assert_eq!(decoded, bytes); } + #[test] + fn test_decode_empty_encoding_struct() { + // Create a minimal but VALID HuffmanDictionary. + // This is required because decode() expects a dictionary, even though + // the content of the dictionary doesn't matter when num_bits == 0. + let freq = vec![('a' as u8, 1)]; + let dict = HuffmanDictionary::new(&freq).unwrap(); + + // Manually create the target state: an encoding with 0 bits. + let empty_encoding = HuffmanEncoding { + data: vec![], + num_bits: 0, + }; + + let result = empty_encoding.decode(&dict); + + assert_eq!(result, Some(vec![])); + } + + #[test] + fn minimal_decode_end_check() { + let freq = vec![('a' as u8, 1), ('b' as u8, 1)]; + let bytes = b"ab"; + + let dict = HuffmanDictionary::new(&freq).unwrap(); + let encoded = dict.encode(bytes); + + // This decode will go through the main loop and hit the final 'if self.num_bits > 0' check. + let decoded = encoded.decode(&dict).unwrap(); + + assert_eq!(decoded, bytes); + } + #[test] fn small_text() { let text = "Hello world"; From 87ebdc5db7112fab7ecb3ebdf3a55425efe12bf0 Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Thu, 25 Sep 2025 09:28:52 -0600 Subject: [PATCH 05/11] ->byte literal --- src/general/huffman_encoding.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index 80dd1a0b77c..50e5a8f5045 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -288,7 +288,7 @@ mod tests { #[test] fn minimal_decode_end_check() { - let freq = vec![('a' as u8, 1), ('b' as u8, 1)]; + let freq = vec![(b'a' as u8, 1), (b'b' as u8, 1)]; let bytes = b"ab"; let dict = HuffmanDictionary::new(&freq).unwrap(); From 7e22cb47a0db677ca8b5e28f76409d5918b8dbaf Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Thu, 25 Sep 2025 09:30:23 -0600 Subject: [PATCH 06/11] -> byte literal --- src/general/huffman_encoding.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index 50e5a8f5045..ad6c69f7b9a 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -272,7 +272,7 @@ mod tests { // Create a minimal but VALID HuffmanDictionary. // This is required because decode() expects a dictionary, even though // the content of the dictionary doesn't matter when num_bits == 0. - let freq = vec![('a' as u8, 1)]; + let freq = vec![(b'a' as u8, 1)]; let dict = HuffmanDictionary::new(&freq).unwrap(); // Manually create the target state: an encoding with 0 bits. From 7c63c02e540c91819b09fd63fd7144380b9b86d6 Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Thu, 25 Sep 2025 09:42:06 -0600 Subject: [PATCH 07/11] Refactor: Fix clippy lints in new huffman tests Corrects 'char-lit-as-u8' and 'unnecessary-cast' lints in the newly added coverage tests to satisfy GitHub Actions. --- src/general/huffman_encoding.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index ad6c69f7b9a..62644b857ed 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -272,7 +272,7 @@ mod tests { // Create a minimal but VALID HuffmanDictionary. // This is required because decode() expects a dictionary, even though // the content of the dictionary doesn't matter when num_bits == 0. - let freq = vec![(b'a' as u8, 1)]; + let freq = vec![(b'a', 1)]; let dict = HuffmanDictionary::new(&freq).unwrap(); // Manually create the target state: an encoding with 0 bits. @@ -288,7 +288,7 @@ mod tests { #[test] fn minimal_decode_end_check() { - let freq = vec![(b'a' as u8, 1), (b'b' as u8, 1)]; + let freq = vec![(b'a', 1), (b'b', 1)]; let bytes = b"ab"; let dict = HuffmanDictionary::new(&freq).unwrap(); From 5c5d022dedba0888ea0f1d012215933ad297de51 Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Thu, 25 Sep 2025 09:45:26 -0600 Subject: [PATCH 08/11] fmt --- src/general/huffman_encoding.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index 62644b857ed..a7cab1540d7 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -288,7 +288,7 @@ mod tests { #[test] fn minimal_decode_end_check() { - let freq = vec![(b'a', 1), (b'b', 1)]; + let freq = vec![(b'a', 1), (b'b', 1)]; let bytes = b"ab"; let dict = HuffmanDictionary::new(&freq).unwrap(); From edc2003feacdd89604923c3f5cfe87fb09998549 Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Thu, 25 Sep 2025 10:05:31 -0600 Subject: [PATCH 09/11] add test for code coverage: fn test_decode_corrupted_stream_dead_end --- src/general/huffman_encoding.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index a7cab1540d7..4582d419f8f 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -300,6 +300,36 @@ mod tests { assert_eq!(decoded, bytes); } + #[test] + fn test_decode_corrupted_stream_dead_end() { + // Create a dictionary with three symbols to ensure a deeper tree. + // This makes hitting a dead-end (None pointer) easier. + let freq = vec![(b'a', 1), (b'b', 1), (b'c', 1)]; + let bytes = b"ab"; + let dict = HuffmanDictionary::new(&freq).unwrap(); + + let encoded = dict.encode(bytes); + + // Manually corrupt the stream to stop mid-symbol. + // We will truncate num_bits by a small amount (e.g., 1 bit). + // This forces the loop to stop on an *intermediate* node. + let corrupted_encoding = HuffmanEncoding { + data: encoded.data, + // Shorten the bit count by one. The total length of the 'ab' stream + // is likely 4 or 5 bits. This forces the loop to end one bit early, + // leaving the state on an internal node. + num_bits: encoded + .num_bits + .checked_sub(1) + .expect("Encoding should be > 0 bits"), + }; + + // Assert that the decode fails gracefully. + // The loop finishes, the final 'if self.num_bits > 0' executes, + // and result.push(state.symbol?) fails because state.symbol is None. + assert_eq!(corrupted_encoding.decode(&dict), None); + } + #[test] fn small_text() { let text = "Hello world"; From d64d56d5bc8ea35ba2910787849576dfb0d2b5e4 Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Fri, 26 Sep 2025 13:03:57 -0600 Subject: [PATCH 10/11] git commit -m "Revert unrelated style fix in KLD loss as requested by reviewer" --- src/machine_learning/loss_function/kl_divergence_loss.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/machine_learning/loss_function/kl_divergence_loss.rs b/src/machine_learning/loss_function/kl_divergence_loss.rs index 918bdc338e6..f477607b20f 100644 --- a/src/machine_learning/loss_function/kl_divergence_loss.rs +++ b/src/machine_learning/loss_function/kl_divergence_loss.rs @@ -16,7 +16,7 @@ pub fn kld_loss(actual: &[f64], predicted: &[f64]) -> f64 { let loss: f64 = actual .iter() .zip(predicted.iter()) - .map(|(&a, &p)| (a + eps) * ((a + eps) / (p + eps)).ln()) + .map(|(&a, &p)| ((a + eps) * ((a + eps) / (p + eps)).ln())) .sum(); loss } From cb6fe85f7c19da1a0052c8d73a85bbc15e4bbc0e Mon Sep 17 00:00:00 2001 From: Jesper Olsen <43079279+jesper-olsen@users.noreply.github.com> Date: Fri, 26 Sep 2025 13:09:09 -0600 Subject: [PATCH 11/11] fix warning - Unnecessary parentheses around closure body --- src/machine_learning/loss_function/kl_divergence_loss.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/machine_learning/loss_function/kl_divergence_loss.rs b/src/machine_learning/loss_function/kl_divergence_loss.rs index f477607b20f..918bdc338e6 100644 --- a/src/machine_learning/loss_function/kl_divergence_loss.rs +++ b/src/machine_learning/loss_function/kl_divergence_loss.rs @@ -16,7 +16,7 @@ pub fn kld_loss(actual: &[f64], predicted: &[f64]) -> f64 { let loss: f64 = actual .iter() .zip(predicted.iter()) - .map(|(&a, &p)| ((a + eps) * ((a + eps) / (p + eps)).ln())) + .map(|(&a, &p)| (a + eps) * ((a + eps) / (p + eps)).ln()) .sum(); loss }