Skip to content

Commit 0e48d4a

Browse files
authored
fix: retain not set len if predicate panics (#413)
* test: add retain method test from alloc::String https://github.com/rust-lang/rust/blob/73c278fd936c8eab4c8c6c6cb638da091b1e6740/library/alloc/tests/string.rs#L466-L499 * fix: retain not set len if predicate panics
1 parent 0e17e08 commit 0e48d4a

File tree

2 files changed

+58
-9
lines changed

2 files changed

+58
-9
lines changed

compact_str/src/lib.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,24 +1232,38 @@ impl CompactString {
12321232
pub fn retain(&mut self, mut predicate: impl FnMut(char) -> bool) {
12331233
// We iterate over the string, and copy character by character.
12341234

1235-
let s = self.as_mut_str();
1236-
let mut dest_idx = 0;
1237-
let mut src_idx = 0;
1238-
while let Some(ch) = s[src_idx..].chars().next() {
1235+
struct SetLenOnDrop<'a> {
1236+
self_: &'a mut CompactString,
1237+
src_idx: usize,
1238+
dst_idx: usize,
1239+
}
1240+
1241+
let mut g = SetLenOnDrop {
1242+
self_: self,
1243+
src_idx: 0,
1244+
dst_idx: 0,
1245+
};
1246+
let s = g.self_.as_mut_str();
1247+
while let Some(ch) = s[g.src_idx..].chars().next() {
12391248
let ch_len = ch.len_utf8();
12401249
if predicate(ch) {
12411250
// SAFETY: We know that both indices are valid, and that we don't split a char.
12421251
unsafe {
12431252
let p = s.as_mut_ptr();
1244-
core::ptr::copy(p.add(src_idx), p.add(dest_idx), ch_len);
1253+
core::ptr::copy(p.add(g.src_idx), p.add(g.dst_idx), ch_len);
12451254
}
1246-
dest_idx += ch_len;
1255+
g.dst_idx += ch_len;
12471256
}
1248-
src_idx += ch_len;
1257+
g.src_idx += ch_len;
12491258
}
12501259

1251-
// SAFETY: We know that the index is a valid position to break the string.
1252-
unsafe { self.set_len(dest_idx) };
1260+
impl Drop for SetLenOnDrop<'_> {
1261+
fn drop(&mut self) {
1262+
// SAFETY: We know that the index is a valid position to break the string.
1263+
unsafe { self.self_.set_len(self.dst_idx) };
1264+
}
1265+
}
1266+
drop(g);
12531267
}
12541268

12551269
/// Decode a bytes slice as UTF-8 string, replacing any illegal codepoints

compact_str/src/tests.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,6 +1367,41 @@ fn test_insert(to_compact: fn(&'static str) -> CompactString) {
13671367
);
13681368
}
13691369

1370+
#[test]
1371+
#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")]
1372+
fn test_retain() {
1373+
let mut s = CompactString::from("α_β_γ");
1374+
1375+
s.retain(|_| true);
1376+
assert_eq!(s, "α_β_γ");
1377+
1378+
s.retain(|c| c != '_');
1379+
assert_eq!(s, "αβγ");
1380+
1381+
s.retain(|c| c != 'β');
1382+
assert_eq!(s, "αγ");
1383+
1384+
s.retain(|c| c == 'α');
1385+
assert_eq!(s, "α");
1386+
1387+
s.retain(|_| false);
1388+
assert_eq!(s, "");
1389+
1390+
let mut s = CompactString::from("0è0");
1391+
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
1392+
let mut count = 0;
1393+
s.retain(|_| {
1394+
count += 1;
1395+
match count {
1396+
1 => false,
1397+
2 => true,
1398+
_ => panic!(),
1399+
}
1400+
});
1401+
}));
1402+
assert!(std::str::from_utf8(s.as_bytes()).is_ok());
1403+
}
1404+
13701405
#[test]
13711406
fn test_remove() {
13721407
let mut control = String::from("🦄🦀hello🎶world🇺🇸");

0 commit comments

Comments
 (0)