Skip to content

Commit 8334fca

Browse files
MattStephansonstrega-nilStephanTLavavej
authored
<random>: Implement Lemire's fast integer generation (#3012)
Co-authored-by: Nicole Mazzuca <[email protected]> Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent 2f03bdf commit 8334fca

File tree

7 files changed

+321
-10
lines changed

7 files changed

+321
-10
lines changed

benchmarks/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,5 @@ function(add_benchmark name)
7171
target_link_libraries(benchmark-${name} PRIVATE benchmark::benchmark)
7272
endfunction()
7373

74-
add_benchmark(std_copy
75-
src/std_copy.cpp
76-
CXX_STANDARD 23
77-
)
74+
add_benchmark(std_copy src/std_copy.cpp)
75+
add_benchmark(random_integer_generation src/random_integer_generation.cpp)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include <benchmark/benchmark.h>
5+
#include <cstdint>
6+
#include <random>
7+
8+
/// Test URBGs alone
9+
10+
static void BM_mt19937(benchmark::State& state) {
11+
std::mt19937 gen;
12+
for (auto _ : state) {
13+
benchmark::DoNotOptimize(gen());
14+
}
15+
}
16+
BENCHMARK(BM_mt19937);
17+
18+
static void BM_mt19937_64(benchmark::State& state) {
19+
std::mt19937_64 gen;
20+
for (auto _ : state) {
21+
benchmark::DoNotOptimize(gen());
22+
}
23+
}
24+
BENCHMARK(BM_mt19937_64);
25+
26+
static void BM_lcg(benchmark::State& state) {
27+
std::minstd_rand gen;
28+
for (auto _ : state) {
29+
benchmark::DoNotOptimize(gen());
30+
}
31+
}
32+
BENCHMARK(BM_lcg);
33+
34+
std::uint32_t GetMax() {
35+
std::random_device gen;
36+
std::uniform_int_distribution<std::uint32_t> dist(10'000'000, 20'000'000);
37+
return dist(gen);
38+
}
39+
40+
static const std::uint32_t maximum = GetMax(); // random divisor to prevent strength reduction
41+
42+
/// Test mt19937
43+
44+
static void BM_raw_mt19937_old(benchmark::State& state) {
45+
std::mt19937 gen;
46+
std::_Rng_from_urng<std::uint32_t, decltype(gen)> rng(gen);
47+
for (auto _ : state) {
48+
benchmark::DoNotOptimize(rng(maximum));
49+
}
50+
}
51+
BENCHMARK(BM_raw_mt19937_old);
52+
53+
static void BM_raw_mt19937_new(benchmark::State& state) {
54+
std::mt19937 gen;
55+
std::_Rng_from_urng_v2<std::uint32_t, decltype(gen)> rng(gen);
56+
for (auto _ : state) {
57+
benchmark::DoNotOptimize(rng(maximum));
58+
}
59+
}
60+
BENCHMARK(BM_raw_mt19937_new);
61+
62+
/// Test mt19937_64
63+
64+
static void BM_raw_mt19937_64_old(benchmark::State& state) {
65+
std::mt19937_64 gen;
66+
std::_Rng_from_urng<std::uint64_t, decltype(gen)> rng(gen);
67+
for (auto _ : state) {
68+
benchmark::DoNotOptimize(rng(maximum));
69+
}
70+
}
71+
BENCHMARK(BM_raw_mt19937_64_old);
72+
73+
static void BM_raw_mt19937_64_new(benchmark::State& state) {
74+
std::mt19937_64 gen;
75+
std::_Rng_from_urng_v2<std::uint64_t, decltype(gen)> rng(gen);
76+
for (auto _ : state) {
77+
benchmark::DoNotOptimize(rng(maximum));
78+
}
79+
}
80+
BENCHMARK(BM_raw_mt19937_64_new);
81+
82+
/// Test minstd_rand
83+
84+
static void BM_raw_lcg_old(benchmark::State& state) {
85+
std::minstd_rand gen;
86+
std::_Rng_from_urng<std::uint32_t, decltype(gen)> rng(gen);
87+
for (auto _ : state) {
88+
benchmark::DoNotOptimize(rng(maximum));
89+
}
90+
}
91+
BENCHMARK(BM_raw_lcg_old);
92+
93+
static void BM_raw_lcg_new(benchmark::State& state) {
94+
std::minstd_rand gen;
95+
std::_Rng_from_urng_v2<std::uint32_t, decltype(gen)> rng(gen);
96+
for (auto _ : state) {
97+
benchmark::DoNotOptimize(rng(maximum));
98+
}
99+
}
100+
BENCHMARK(BM_raw_lcg_new);
101+
102+
BENCHMARK_MAIN();

stl/inc/random

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,7 +1844,9 @@ private:
18441844

18451845
template <class _Engine>
18461846
result_type _Eval(_Engine& _Eng, _Ty _Min, _Ty _Max) const { // compute next value in range [_Min, _Max]
1847-
_Rng_from_urng<_Uty, _Engine> _Generator(_Eng);
1847+
conditional_t<_Has_static_min_max<_Engine>::value, _Rng_from_urng_v2<_Uty, _Engine>,
1848+
_Rng_from_urng<_Uty, _Engine>>
1849+
_Generator(_Eng);
18481850

18491851
const _Uty _Umin = _Adjust(static_cast<_Uty>(_Min));
18501852
const _Uty _Umax = _Adjust(static_cast<_Uty>(_Max));
@@ -1862,7 +1864,7 @@ private:
18621864

18631865
static _Uty _Adjust(_Uty _Uval) { // convert signed ranges to unsigned ranges and vice versa
18641866
if constexpr (is_signed_v<_Ty>) {
1865-
const _Uty _Adjuster = (static_cast<_Uty>(-1) >> 1) + 1; // 2^(N-1)
1867+
constexpr _Uty _Adjuster = (static_cast<_Uty>(-1) >> 1) + 1; // 2^(N-1)
18661868

18671869
if (_Uval < _Adjuster) {
18681870
return static_cast<_Uty>(_Uval + _Adjuster);

stl/inc/xutility

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <yvals.h>
1010
#if _STL_COMPILER_PREPROCESSOR
1111

12+
#include <__msvc_int128.hpp>
1213
#include <__msvc_iter_core.hpp>
1314
#include <climits>
1415
#include <cstdlib>
@@ -6029,7 +6030,8 @@ public:
60296030

60306031
using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;
60316032

6032-
explicit _Rng_from_urng(_Urng& _Func) : _Ref(_Func), _Bits(CHAR_BIT * sizeof(_Udiff)), _Bmask(_Udiff(-1)) {
6033+
explicit _Rng_from_urng(_Urng& _Func)
6034+
: _Ref(_Func), _Bits(CHAR_BIT * sizeof(_Udiff)), _Bmask(static_cast<_Udiff>(-1)) {
60336035
for (; static_cast<_Udiff>((_Urng::max)() - (_Urng::min)()) < _Bmask; _Bmask >>= 1) {
60346036
--_Bits;
60356037
}
@@ -6040,7 +6042,7 @@ public:
60406042
_Udiff _Ret = 0; // random bits
60416043
_Udiff _Mask = 0; // 2^N - 1, _Ret is within [0, _Mask]
60426044

6043-
while (_Mask < _Udiff(_Index - 1)) { // need more random bits
6045+
while (_Mask < static_cast<_Udiff>(_Index - 1)) { // need more random bits
60446046
_Ret <<= _Bits - 1; // avoid full shift
60456047
_Ret <<= 1;
60466048
_Ret |= _Get_bits();
@@ -6050,7 +6052,7 @@ public:
60506052
}
60516053

60526054
// _Ret is [0, _Mask], _Index - 1 <= _Mask, return if unbiased
6053-
if (_Ret / _Index < _Mask / _Index || _Mask % _Index == _Udiff(_Index - 1)) {
6055+
if (_Ret / _Index < _Mask / _Index || _Mask % _Index == static_cast<_Udiff>(_Index - 1)) {
60546056
return static_cast<_Diff>(_Ret % _Index);
60556057
}
60566058
}
@@ -6074,7 +6076,7 @@ public:
60746076
private:
60756077
_Udiff _Get_bits() { // return a random value within [0, _Bmask]
60766078
for (;;) { // repeat until random value is in range
6077-
_Udiff _Val = static_cast<_Udiff>(_Ref() - (_Urng::min)());
6079+
const _Udiff _Val = static_cast<_Udiff>(_Ref() - (_Urng::min)());
60786080

60796081
if (_Val <= _Bmask) {
60806082
return _Val;
@@ -6087,6 +6089,128 @@ private:
60876089
_Udiff _Bmask; // 2^_Bits - 1
60886090
};
60896091

6092+
template <class _Diff, class _Urng>
6093+
class _Rng_from_urng_v2 { // wrap a URNG as an RNG
6094+
public:
6095+
using _Ty0 = make_unsigned_t<_Diff>;
6096+
using _Ty1 = _Invoke_result_t<_Urng&>;
6097+
6098+
using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;
6099+
static constexpr unsigned int _Udiff_bits = sizeof(_Udiff) * CHAR_BIT;
6100+
using _Uprod = conditional_t<_Udiff_bits <= 16, uint32_t, conditional_t<_Udiff_bits <= 32, uint64_t, _Unsigned128>>;
6101+
6102+
explicit _Rng_from_urng_v2(_Urng& _Func) : _Ref(_Func) {}
6103+
6104+
_Diff operator()(_Diff _Index) { // adapt _Urng closed range to [0, _Index)
6105+
// From Daniel Lemire, "Fast Random Integer Generation in an Interval", ACM Trans. Model. Comput. Simul. 29 (1),
6106+
// 2019.
6107+
//
6108+
// Algorithm 5 <-> This Code:
6109+
// m <-> _Product
6110+
// l <-> _Rem
6111+
// s <-> _Index
6112+
// t <-> _Threshold
6113+
// L <-> _Generated_bits
6114+
// 2^L - 1 <-> _Mask
6115+
6116+
_Udiff _Mask = _Bmask;
6117+
unsigned int _Niter = 1;
6118+
6119+
if constexpr (_Bits < _Udiff_bits) {
6120+
while (_Mask < static_cast<_Udiff>(_Index - 1)) {
6121+
_Mask <<= _Bits;
6122+
_Mask |= _Bmask;
6123+
++_Niter;
6124+
}
6125+
}
6126+
6127+
// x <- random integer in [0, 2^L)
6128+
// m <- x * s
6129+
auto _Product = _Get_random_product(_Index, _Niter);
6130+
// l <- m mod 2^L
6131+
auto _Rem = static_cast<_Udiff>(_Product) & _Mask;
6132+
6133+
if (_Rem < _Index) {
6134+
// t <- (2^L - s) mod s
6135+
const auto _Threshold = (_Mask - _Index + 1) % _Index;
6136+
while (_Rem < _Threshold) {
6137+
_Product = _Get_random_product(_Index, _Niter);
6138+
_Rem = static_cast<_Udiff>(_Product) & _Mask;
6139+
}
6140+
}
6141+
6142+
unsigned int _Generated_bits;
6143+
if constexpr (_Bits < _Udiff_bits) {
6144+
_Generated_bits = static_cast<unsigned int>(_Popcount(_Mask));
6145+
} else {
6146+
_Generated_bits = _Udiff_bits;
6147+
}
6148+
6149+
// m / 2^L
6150+
return static_cast<_Diff>(_Product >> _Generated_bits);
6151+
}
6152+
6153+
_Udiff _Get_all_bits() {
6154+
_Udiff _Ret = _Get_bits();
6155+
6156+
if constexpr (_Bits < _Udiff_bits) {
6157+
for (unsigned int _Num = _Bits; _Num < _Udiff_bits; _Num += _Bits) { // don't mask away any bits
6158+
_Ret <<= _Bits;
6159+
_Ret |= _Get_bits();
6160+
}
6161+
}
6162+
6163+
return _Ret;
6164+
}
6165+
6166+
_Rng_from_urng_v2(const _Rng_from_urng_v2&) = delete;
6167+
_Rng_from_urng_v2& operator=(const _Rng_from_urng_v2&) = delete;
6168+
6169+
private:
6170+
_Udiff _Get_bits() { // return a random value within [0, _Bmask]
6171+
static constexpr auto _Urng_min = (_Urng::min)();
6172+
for (;;) { // repeat until random value is in range
6173+
const _Udiff _Val = _Ref() - _Urng_min;
6174+
6175+
if (_Val <= _Bmask) {
6176+
return _Val;
6177+
}
6178+
}
6179+
}
6180+
6181+
static constexpr size_t _Calc_bits() {
6182+
auto _Bits_local = _Udiff_bits;
6183+
auto _Bmask_local = static_cast<_Udiff>(-1);
6184+
for (; (_Urng::max)() - (_Urng::min)() < _Bmask_local; _Bmask_local >>= 1) {
6185+
--_Bits_local;
6186+
}
6187+
6188+
return _Bits_local;
6189+
}
6190+
6191+
_Uprod _Get_random_product(const _Diff _Index, unsigned int _Niter) {
6192+
_Udiff _Ret = _Get_bits();
6193+
if constexpr (_Bits < _Udiff_bits) {
6194+
while (--_Niter > 0) {
6195+
_Ret <<= _Bits;
6196+
_Ret |= _Get_bits();
6197+
}
6198+
}
6199+
6200+
if constexpr (is_same_v<_Udiff, uint64_t>) {
6201+
uint64_t _High;
6202+
const auto _Low = _Base128::_UMul128(_Ret, static_cast<_Udiff>(_Index), _High);
6203+
return _Uprod{_Low, _High};
6204+
} else {
6205+
return _Uprod{_Ret} * _Uprod{_Index};
6206+
}
6207+
}
6208+
6209+
_Urng& _Ref; // reference to URNG
6210+
static constexpr size_t _Bits = _Calc_bits(); // number of random bits generated by _Get_bits()
6211+
static constexpr _Udiff _Bmask = static_cast<_Udiff>(-1) >> (_Udiff_bits - _Bits); // 2^_Bits - 1
6212+
};
6213+
60906214
extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xbad_alloc();
60916215
extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xinvalid_argument(_In_z_ const char*);
60926216
extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xlength_error(_In_z_ const char*);

tests/std/test.lst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ tests\Dev11_1150223_shared_mutex
156156
tests\Dev11_1158803_regex_thread_safety
157157
tests\Dev11_1180290_filesystem_error_code
158158
tests\GH_000177_forbidden_aliasing
159+
tests\GH_000178_uniform_int
159160
tests\GH_000342_filebuf_close
160161
tests\GH_000431_copy_move_family
161162
tests\GH_000431_equal_family
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
RUNALL_INCLUDE ..\usual_matrix.lst

0 commit comments

Comments
 (0)