diff --git a/src/data_structures/floyds_algorithm.rs b/src/data_structures/floyds_algorithm.rs index b475d07d963..bdedada24e9 100644 --- a/src/data_structures/floyds_algorithm.rs +++ b/src/data_structures/floyds_algorithm.rs @@ -91,5 +91,12 @@ mod tests { assert!(has_cycle(&linked_list)); assert_eq!(detect_cycle(&linked_list), Some(3)); + + // Break the cycle before the list is dropped + unsafe { + if let Some(mut tail) = linked_list.tail { + tail.as_mut().next = None; + } + } } } diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index fc26d3cb5ee..5dd623ec4d0 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -77,10 +77,50 @@ pub struct HuffmanDictionary { } impl HuffmanDictionary { - /// The list of alphabet symbols and their respective frequency should - /// be given as input - pub fn new(alphabet: &[(T, u64)]) -> Self { + /// Creates a new Huffman dictionary from alphabet symbols and their frequencies. + /// + /// Returns `None` if the alphabet is empty. + /// + /// # Arguments + /// * `alphabet` - A slice of tuples containing symbols and their frequencies + /// + /// # Example + /// ``` + /// # use the_algorithms_rust::general::HuffmanDictionary; + /// let freq = vec![('a', 5), ('b', 2), ('c', 1)]; + /// let dict = HuffmanDictionary::new(&freq).unwrap(); + /// + pub fn new(alphabet: &[(T, u64)]) -> Option { + if alphabet.is_empty() { + return None; + } + let mut alph: BTreeMap = BTreeMap::new(); + + // Special case: single symbol + if alphabet.len() == 1 { + let (symbol, _freq) = alphabet[0]; + alph.insert( + symbol, + HuffmanValue { + value: 0, + bits: 1, // Must use at least 1 bit per symbol + }, + ); + + let root = HuffmanNode { + left: None, + right: None, + symbol: Some(symbol), + frequency: alphabet[0].1, + }; + + return Some(HuffmanDictionary { + alphabet: alph, + root, + }); + } + let mut queue: BinaryHeap> = BinaryHeap::new(); for (symbol, freq) in alphabet.iter() { queue.push(HuffmanNode { @@ -101,11 +141,14 @@ impl HuffmanDictionary { frequency: sm_freq, }); } - let root = queue.pop().unwrap(); - HuffmanNode::get_alphabet(0, 0, &root, &mut alph); - HuffmanDictionary { - alphabet: alph, - root, + if let Some(root) = queue.pop() { + HuffmanNode::get_alphabet(0, 0, &root, &mut alph); + Some(HuffmanDictionary { + alphabet: alph, + root, + }) + } else { + None } } pub fn encode(&self, data: &[T]) -> HuffmanEncoding { @@ -143,27 +186,48 @@ impl HuffmanEncoding { } self.num_bits += data.bits as u64; } + + #[inline] fn get_bit(&self, pos: u64) -> bool { (self.data[(pos >> 6) as usize] & (1 << (pos & 63))) != 0 } + /// In case the encoding is invalid, `None` is returned pub fn decode(&self, dict: &HuffmanDictionary) -> Option> { + // Handle empty encoding + if self.num_bits == 0 { + return Some(vec![]); + } + + // Special case: single symbol in dictionary + if dict.alphabet.len() == 1 { + //all bits represent the same symbol + let symbol = dict.alphabet.keys().next()?; + let result = vec![*symbol; self.num_bits as usize]; + return Some(result); + } + + // Normal case: multiple symbols let mut state = &dict.root; let mut result: Vec = vec![]; + for i in 0..self.num_bits { - if state.symbol.is_some() { - result.push(state.symbol.unwrap()); + if let Some(symbol) = state.symbol { + result.push(symbol); state = &dict.root; } state = if self.get_bit(i) { - state.right.as_ref().unwrap() + state.right.as_ref()? } else { - state.left.as_ref().unwrap() + state.left.as_ref()? } } + + // Check if we ended on a symbol if self.num_bits > 0 { result.push(state.symbol?); } + Some(result) } } @@ -181,12 +245,34 @@ mod tests { .for_each(|(b, &cnt)| result.push((b as u8, cnt))); result } + + #[test] + fn empty_text() { + let text = ""; + let bytes = text.as_bytes(); + let freq = get_frequency(bytes); + let dict = HuffmanDictionary::new(&freq); + assert!(dict.is_none()); + } + + #[test] + fn one_symbol_text() { + let text = "aaaa"; + let bytes = text.as_bytes(); + let freq = get_frequency(bytes); + let dict = HuffmanDictionary::new(&freq).unwrap(); + let encoded = dict.encode(bytes); + assert_eq!(encoded.num_bits, 4); + let decoded = encoded.decode(&dict).unwrap(); + assert_eq!(decoded, bytes); + } + #[test] fn small_text() { let text = "Hello world"; let bytes = text.as_bytes(); let freq = get_frequency(bytes); - let dict = HuffmanDictionary::new(&freq); + let dict = HuffmanDictionary::new(&freq).unwrap(); let encoded = dict.encode(bytes); assert_eq!(encoded.num_bits, 32); let decoded = encoded.decode(&dict).unwrap(); @@ -208,7 +294,7 @@ mod tests { ); let bytes = text.as_bytes(); let freq = get_frequency(bytes); - let dict = HuffmanDictionary::new(&freq); + let dict = HuffmanDictionary::new(&freq).unwrap(); let encoded = dict.encode(bytes); assert_eq!(encoded.num_bits, 2372); let decoded = encoded.decode(&dict).unwrap(); diff --git a/src/machine_learning/loss_function/kl_divergence_loss.rs b/src/machine_learning/loss_function/kl_divergence_loss.rs index f477607b20f..918bdc338e6 100644 --- a/src/machine_learning/loss_function/kl_divergence_loss.rs +++ b/src/machine_learning/loss_function/kl_divergence_loss.rs @@ -16,7 +16,7 @@ pub fn kld_loss(actual: &[f64], predicted: &[f64]) -> f64 { let loss: f64 = actual .iter() .zip(predicted.iter()) - .map(|(&a, &p)| ((a + eps) * ((a + eps) / (p + eps)).ln())) + .map(|(&a, &p)| (a + eps) * ((a + eps) / (p + eps)).ln()) .sum(); loss }