Skip to content

Commit d15fd49

Browse files
Vectorize find_end (#4943)
Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent a2240ab commit d15fd49

File tree

4 files changed

+448
-37
lines changed

4 files changed

+448
-37
lines changed

benchmarks/src/search.cpp

Lines changed: 78 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <vector>
1313
using namespace std::string_view_literals;
1414

15-
const char src_haystack[] =
15+
constexpr std::string_view common_src_data =
1616
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nullam mollis imperdiet massa, at dapibus elit interdum "
1717
"ac. In eget sollicitudin mi. Nam at tellus at sapien tincidunt sollicitudin vel non eros. Pellentesque nunc nunc, "
1818
"ullamcorper eu accumsan at, pulvinar non turpis. Quisque vel mauris pulvinar, pretium purus vel, ultricies erat. "
@@ -43,13 +43,42 @@ const char src_haystack[] =
4343
"euismod eros, ut posuere ligula ullamcorper id. Nullam aliquet malesuada est at dignissim. Pellentesque finibus "
4444
"sagittis libero nec bibendum. Phasellus dolor ipsum, finibus quis turpis quis, mollis interdum felis.";
4545

46-
constexpr std::array patterns = {
47-
"aliquet"sv,
48-
"aliquet malesuada"sv,
46+
// Builds a Size-element char array in which every element is '*'.
// When Last_is_different is true, the final element is replaced by '!'.
// Note: with Last_is_different == true, Size must be non-zero because back() is called on the array.
template <size_t Size, bool Last_is_different>
47+
constexpr auto make_fill_pattern_array() {
48+
std::array<char, Size> result;
49+
// Every element is written here before the array is read or returned.
result.fill('*');
50+
51+
if constexpr (Last_is_different) {
52+
result.back() = '!';
53+
}
54+
55+
return result;
56+
}
57+
58+
// Compile-time array holding the fill pattern for a given Size / Last_is_different combination.
// fill_pattern_view is initialized from this array.
template <size_t Size, bool Last_is_different>
59+
constexpr std::array fill_pattern_array = make_fill_pattern_array<Size, Last_is_different>();
60+
61+
// std::string_view initialized from the matching fill_pattern_array specialization,
// so fill patterns can be stored in the same string_view fields as the text-based cases.
template <size_t Size, bool Last_is_different>
62+
constexpr std::string_view fill_pattern_view = fill_pattern_array<Size, Last_is_different>;
63+
64+
// One benchmark case: the text to search in and the sequence to search for.
struct data_and_pattern {
65+
// Used by the benchmark functions as the haystack.
std::string_view data;
66+
// Used by the benchmark functions as the needle.
std::string_view pattern;
67+
};
68+
69+
// Benchmark cases. The benchmark functions index this table with state.range(),
// so the position of each entry is its benchmark argument.
constexpr data_and_pattern patterns[] = {
70+
/* 0. Small, closer to end */ {common_src_data, "aliquet"sv},
71+
/* 1. Large, closer to end */ {common_src_data, "aliquet malesuada"sv},
72+
/* 2. Small, closer to begin */ {common_src_data, "pulvinar"sv},
73+
/* 3. Large, closer to begin */ {common_src_data, "dapibus elit interdum"sv},
74+
75+
// The "evil" cases search 3000 '*' characters for a run of '*' that ends in '!'.
// The needle never occurs, but at every position all elements except the last one compare equal.
/* 4. Small, evil */ {fill_pattern_view<3000, false>, fill_pattern_view<7, true>},
76+
/* 5. Large, evil */ {fill_pattern_view<3000, false>, fill_pattern_view<20, true>},
4977
};
5078

5179
void c_strstr(benchmark::State& state) {
52-
const auto& src_needle = patterns[static_cast<size_t>(state.range())];
80+
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
81+
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;
5382

5483
const std::string haystack(std::begin(src_haystack), std::end(src_haystack));
5584
const std::string needle(std::begin(src_needle), std::end(src_needle));
@@ -64,7 +93,8 @@ void c_strstr(benchmark::State& state) {
6493

6594
template <class T>
6695
void classic_search(benchmark::State& state) {
67-
const auto& src_needle = patterns[static_cast<size_t>(state.range())];
96+
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
97+
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;
6898

6999
const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
70100
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));
@@ -79,7 +109,8 @@ void classic_search(benchmark::State& state) {
79109

80110
template <class T>
81111
void ranges_search(benchmark::State& state) {
82-
const auto& src_needle = patterns[static_cast<size_t>(state.range())];
112+
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
113+
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;
83114

84115
const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
85116
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));
@@ -94,7 +125,8 @@ void ranges_search(benchmark::State& state) {
94125

95126
template <class T>
96127
void search_default_searcher(benchmark::State& state) {
97-
const auto& src_needle = patterns[static_cast<size_t>(state.range())];
128+
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
129+
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;
98130

99131
const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
100132
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));
@@ -107,26 +139,57 @@ void search_default_searcher(benchmark::State& state) {
107139
}
108140
}
109141

142+
// Benchmarks the iterator-pair overload of std::find_end.
// T is the element type of the vectors being searched; state.range() selects the entry of patterns to use.
template <class T>
143+
void classic_find_end(benchmark::State& state) {
144+
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
145+
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;
146+
147+
// Copy the characters into vectors of T once, outside the timed loop; each char is converted to T.
const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
148+
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));
149+
150+
for (auto _ : state) {
151+
// Presumably keeps the compiler from exploiting knowledge of the inputs across iterations.
benchmark::DoNotOptimize(haystack);
152+
benchmark::DoNotOptimize(needle);
153+
auto res = std::find_end(haystack.begin(), haystack.end(), needle.begin(), needle.end());
154+
// Mark the result as used so that the search itself is not eliminated as dead code.
benchmark::DoNotOptimize(res);
155+
}
156+
}
157+
158+
// Benchmarks std::ranges::find_end called with the two vectors as ranges.
// T is the element type of the vectors being searched; state.range() selects the entry of patterns to use.
template <class T>
159+
void ranges_find_end(benchmark::State& state) {
160+
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
161+
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;
162+
163+
// Copy the characters into vectors of T once, outside the timed loop; each char is converted to T.
const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
164+
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));
165+
166+
for (auto _ : state) {
167+
// Presumably keeps the compiler from exploiting knowledge of the inputs across iterations.
benchmark::DoNotOptimize(haystack);
168+
benchmark::DoNotOptimize(needle);
169+
auto res = std::ranges::find_end(haystack, needle);
170+
// Mark the result as used so that the search itself is not eliminated as dead code.
benchmark::DoNotOptimize(res);
171+
}
172+
}
173+
110174
// Argument setup applied to every benchmark registered below via ->Apply(common_args).
// The argument values are the valid indices into patterns: 0 through std::size(patterns) - 1.
void common_args(auto bm) {
111-
bm->Range(0, patterns.size() - 1);
175+
// Step of 1, so that every index is visited.
bm->DenseRange(0, std::size(patterns) - 1, 1);
112176
}
113177

114178
BENCHMARK(c_strstr)->Apply(common_args);
115179

116180
BENCHMARK(classic_search<std::uint8_t>)->Apply(common_args);
117181
BENCHMARK(classic_search<std::uint16_t>)->Apply(common_args);
118-
BENCHMARK(classic_search<std::uint32_t>)->Apply(common_args);
119-
BENCHMARK(classic_search<std::uint64_t>)->Apply(common_args);
120182

121183
BENCHMARK(ranges_search<std::uint8_t>)->Apply(common_args);
122184
BENCHMARK(ranges_search<std::uint16_t>)->Apply(common_args);
123-
BENCHMARK(ranges_search<std::uint32_t>)->Apply(common_args);
124-
BENCHMARK(ranges_search<std::uint64_t>)->Apply(common_args);
125185

126186
BENCHMARK(search_default_searcher<std::uint8_t>)->Apply(common_args);
127187
BENCHMARK(search_default_searcher<std::uint16_t>)->Apply(common_args);
128-
BENCHMARK(search_default_searcher<std::uint32_t>)->Apply(common_args);
129-
BENCHMARK(search_default_searcher<std::uint64_t>)->Apply(common_args);
130188

189+
BENCHMARK(classic_find_end<std::uint8_t>)->Apply(common_args);
190+
BENCHMARK(classic_find_end<std::uint16_t>)->Apply(common_args);
191+
192+
BENCHMARK(ranges_find_end<std::uint8_t>)->Apply(common_args);
193+
BENCHMARK(ranges_find_end<std::uint16_t>)->Apply(common_args);
131194

132195
BENCHMARK_MAIN();

stl/inc/algorithm

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ const void* __stdcall __std_find_last_trivial_2(const void* _First, const void*
5959
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
6060
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;
6161

62+
// find_end helpers for 1-byte and 2-byte elements; only the declarations are visible here.
// _Find_end_vectorized selects between them based on the element size.
// [_First1, _Last1) is the range to search; _First2 points at _Count2 elements to look for.
const void* __stdcall __std_find_end_1(
63+
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;
64+
const void* __stdcall __std_find_end_2(
65+
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;
66+
6267
__declspec(noalias) _Min_max_1i __stdcall __std_minmax_1i(const void* _First, const void* _Last) noexcept;
6368
__declspec(noalias) _Min_max_1u __stdcall __std_minmax_1u(const void* _First, const void* _Last) noexcept;
6469
__declspec(noalias) _Min_max_2i __stdcall __std_minmax_2i(const void* _First, const void* _Last) noexcept;
@@ -189,6 +194,19 @@ _Ty* _Find_last_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val
189194
}
190195
}
191196

197+
// Typed wrapper over the __std_find_end_N functions.
// Chooses the 1-byte or 2-byte function from sizeof(_Ty1) and converts the returned const void* back to _Ty1*.
// _Ty1 and _Ty2 must have the same size; any size other than 1 or 2 fails the internal static assertion.
// The ranges::find_end caller treats a result equal to _Last1 as "no match found".
template <class _Ty1, class _Ty2>
198+
_Ty1* _Find_end_vectorized(
199+
_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, const size_t _Count2) noexcept {
200+
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));
201+
if constexpr (sizeof(_Ty1) == 1) {
202+
// The helper returns const void*; restore the element type and the caller's constness.
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_find_end_1(_First1, _Last1, _First2, _Count2)));
203+
} else if constexpr (sizeof(_Ty1) == 2) {
204+
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_find_end_2(_First1, _Last1, _First2, _Count2)));
205+
} else {
206+
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
207+
}
208+
}
209+
192210
template <class _Ty, class _TVal1, class _TVal2>
193211
__declspec(noalias) void _Replace_vectorized(
194212
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
@@ -3194,6 +3212,26 @@ _NODISCARD _CONSTEXPR20 _FwdIt1 find_end(
31943212
if constexpr (_Is_ranges_random_iter_v<_FwdIt1> && _Is_ranges_random_iter_v<_FwdIt2>) {
31953213
const _Iter_diff_t<_FwdIt2> _Count2 = _ULast2 - _UFirst2;
31963214
if (_Count2 > 0 && _Count2 <= _ULast1 - _UFirst1) {
3215+
#if _USE_STD_VECTOR_ALGORITHMS
3216+
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
3217+
if (!_STD _Is_constant_evaluated()) {
3218+
const auto _Ptr1 = _STD _To_address(_UFirst1);
3219+
3220+
const auto _Ptr_res1 = _STD _Find_end_vectorized(
3221+
_Ptr1, _STD _To_address(_ULast1), _STD _To_address(_UFirst2), static_cast<size_t>(_Count2));
3222+
3223+
if constexpr (is_pointer_v<decltype(_UFirst1)>) {
3224+
_UFirst1 = _Ptr_res1;
3225+
} else {
3226+
_UFirst1 += _Ptr_res1 - _Ptr1;
3227+
}
3228+
3229+
_STD _Seek_wrapped(_First1, _UFirst1);
3230+
return _First1;
3231+
}
3232+
}
3233+
#endif // _USE_STD_VECTOR_ALGORITHMS
3234+
31973235
for (auto _UCandidate = _ULast1 - static_cast<_Iter_diff_t<_FwdIt1>>(_Count2);; --_UCandidate) {
31983236
if (_STD _Equal_rev_pred_unchecked(_UCandidate, _UFirst2, _ULast2, _STD _Pass_fn(_Pred))) {
31993237
_STD _Seek_wrapped(_First1, _UCandidate);
@@ -3297,6 +3335,34 @@ namespace ranges {
32973335

32983336
if (_Count2 > 0 && _Count2 <= _Count1) {
32993337
const auto _Count2_as1 = static_cast<iter_difference_t<_It1>>(_Count2);
3338+
#if _USE_STD_VECTOR_ALGORITHMS
3339+
if constexpr (_Vector_alg_in_search_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
3340+
&& is_same_v<_Pj2, identity>) {
3341+
if (!_STD is_constant_evaluated()) {
3342+
const auto _Ptr1 = _STD to_address(_First1);
3343+
const auto _Ptr2 = _STD to_address(_First2);
3344+
const auto _Ptr_last1 = _Ptr1 + _Count1;
3345+
3346+
const auto _Ptr_res1 =
3347+
_STD _Find_end_vectorized(_Ptr1, _Ptr_last1, _Ptr2, static_cast<size_t>(_Count2));
3348+
3349+
if constexpr (is_pointer_v<_It1>) {
3350+
if (_Ptr_res1 != _Ptr_last1) {
3351+
return {_Ptr_res1, _Ptr_res1 + _Count2};
3352+
} else {
3353+
return {_Ptr_res1, _Ptr_res1};
3354+
}
3355+
} else {
3356+
_First1 += _Ptr_res1 - _Ptr1;
3357+
if (_Ptr_res1 != _Ptr_last1) {
3358+
return {_First1, _First1 + _Count2_as1};
3359+
} else {
3360+
return {_First1, _First1};
3361+
}
3362+
}
3363+
}
3364+
}
3365+
#endif // _USE_STD_VECTOR_ALGORITHMS
33003366

33013367
for (auto _Candidate = _First1 + (_Count1 - _Count2_as1);; --_Candidate) {
33023368
auto _Match_and_mid1 =

0 commit comments

Comments
 (0)