diff --git a/library/alloc/src/vec/extract_if.rs b/library/alloc/src/vec/extract_if.rs index a456d3d9e602d..cb9e14f554d41 100644 --- a/library/alloc/src/vec/extract_if.rs +++ b/library/alloc/src/vec/extract_if.rs @@ -64,27 +64,37 @@ where type Item = T; fn next(&mut self) -> Option { - unsafe { - while self.idx < self.end { - let i = self.idx; - let v = slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len); - let drained = (self.pred)(&mut v[i]); - // Update the index *after* the predicate is called. If the index - // is updated prior and the predicate panics, the element at this - // index would be leaked. - self.idx += 1; - if drained { - self.del += 1; - return Some(ptr::read(&v[i])); - } else if self.del > 0 { - let del = self.del; - let src: *const T = &v[i]; - let dst: *mut T = &mut v[i - del]; - ptr::copy_nonoverlapping(src, dst, 1); + while self.idx < self.end { + let i = self.idx; + // SAFETY: + // We know that `i < self.end` from the if guard and that `self.end <= self.old_len` from + // the validity of `Self`. Therefore `i` points to an element within `vec`. + // + // Additionally, the i-th element is valid because each element is visited at most once + // and it is the first time we access vec[i]. + // + // Note: we can't use `vec.get_unchecked_mut(i)` here since the precondition for that + // function is that i < vec.len(), but we've set vec's length to zero. + let cur = unsafe { &mut *self.vec.as_mut_ptr().add(i) }; + let drained = (self.pred)(cur); + // Update the index *after* the predicate is called. If the index + // is updated prior and the predicate panics, the element at this + // index would be leaked. + self.idx += 1; + if drained { + self.del += 1; + // SAFETY: We never touch this element again after returning it. + return Some(unsafe { ptr::read(cur) }); + } else if self.del > 0 { + // SAFETY: `self.del` > 0, so the hole slot must not overlap with current element. + // We use copy for move, and never touch this element again. + unsafe { + let hole_slot = self.vec.as_mut_ptr().add(i - self.del); + ptr::copy_nonoverlapping(cur, hole_slot, 1); } } - None } + None } fn size_hint(&self) -> (usize, Option) { @@ -95,14 +105,18 @@ where #[stable(feature = "extract_if", since = "1.87.0")] impl Drop for ExtractIf<'_, T, F, A> { fn drop(&mut self) { - unsafe { - if self.idx < self.old_len && self.del > 0 { - let ptr = self.vec.as_mut_ptr(); - let src = ptr.add(self.idx); - let dst = src.sub(self.del); - let tail_len = self.old_len - self.idx; - src.copy_to(dst, tail_len); + if self.del > 0 { + // SAFETY: Trailing unchecked items must be valid since we never touch them. + unsafe { + ptr::copy( + self.vec.as_ptr().add(self.idx), + self.vec.as_mut_ptr().add(self.idx - self.del), + self.old_len - self.idx, + ); } + } + // SAFETY: After filling holes, all items are in contiguous memory. + unsafe { self.vec.set_len(self.old_len - self.del); } } diff --git a/src/tools/miri/tests/pass/vec.rs b/src/tools/miri/tests/pass/vec.rs index 3e526813bb457..8b1b1e143b16d 100644 --- a/src/tools/miri/tests/pass/vec.rs +++ b/src/tools/miri/tests/pass/vec.rs @@ -169,6 +169,15 @@ fn miri_issue_2759() { input.replace_range(0..0, "0"); } +/// This was skirting the edge of UB, let's make sure it remains on the sound side. +/// Context: . +fn extract_if() { + let mut v = vec![Box::new(0u64), Box::new(1u64)]; + for item in v.extract_if(.., |x| **x == 0) { + drop(item); + } +} + fn main() { assert_eq!(vec_reallocate().len(), 5); @@ -199,4 +208,5 @@ fn main() { swap_remove(); reverse(); miri_issue_2759(); + extract_if(); }