9
9
#include <yvals.h>
10
10
#if _STL_COMPILER_PREPROCESSOR
11
11
12
+ #include <__msvc_int128.hpp>
12
13
#include <__msvc_iter_core.hpp>
13
14
#include <climits>
14
15
#include <cstdlib>
@@ -6029,7 +6030,8 @@ public:
6029
6030
6030
6031
using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;
6031
6032
6032
- explicit _Rng_from_urng(_Urng& _Func) : _Ref(_Func), _Bits(CHAR_BIT * sizeof(_Udiff)), _Bmask(_Udiff(-1)) {
6033
+ explicit _Rng_from_urng(_Urng& _Func)
6034
+ : _Ref(_Func), _Bits(CHAR_BIT * sizeof(_Udiff)), _Bmask(static_cast<_Udiff>(-1)) {
6033
6035
for (; static_cast<_Udiff>((_Urng::max)() - (_Urng::min)()) < _Bmask; _Bmask >>= 1) {
6034
6036
--_Bits;
6035
6037
}
@@ -6040,7 +6042,7 @@ public:
6040
6042
_Udiff _Ret = 0; // random bits
6041
6043
_Udiff _Mask = 0; // 2^N - 1, _Ret is within [0, _Mask]
6042
6044
6043
- while (_Mask < _Udiff(_Index - 1)) { // need more random bits
6045
+ while (_Mask < static_cast< _Udiff> (_Index - 1)) { // need more random bits
6044
6046
_Ret <<= _Bits - 1; // avoid full shift
6045
6047
_Ret <<= 1;
6046
6048
_Ret |= _Get_bits();
@@ -6050,7 +6052,7 @@ public:
6050
6052
}
6051
6053
6052
6054
// _Ret is [0, _Mask], _Index - 1 <= _Mask, return if unbiased
6053
- if (_Ret / _Index < _Mask / _Index || _Mask % _Index == _Udiff(_Index - 1)) {
6055
+ if (_Ret / _Index < _Mask / _Index || _Mask % _Index == static_cast< _Udiff> (_Index - 1)) {
6054
6056
return static_cast<_Diff>(_Ret % _Index);
6055
6057
}
6056
6058
}
@@ -6074,7 +6076,7 @@ public:
6074
6076
private:
6075
6077
_Udiff _Get_bits() { // return a random value within [0, _Bmask]
6076
6078
for (;;) { // repeat until random value is in range
6077
- _Udiff _Val = static_cast<_Udiff>(_Ref() - (_Urng::min)());
6079
+ const _Udiff _Val = static_cast<_Udiff>(_Ref() - (_Urng::min)());
6078
6080
6079
6081
if (_Val <= _Bmask) {
6080
6082
return _Val;
@@ -6087,6 +6089,128 @@ private:
6087
6089
_Udiff _Bmask; // 2^_Bits - 1
6088
6090
};
6089
6091
6092
+ template <class _Diff, class _Urng>
6093
+ class _Rng_from_urng_v2 { // wrap a URNG as an RNG
6094
+ public:
6095
+ using _Ty0 = make_unsigned_t<_Diff>;
6096
+ using _Ty1 = _Invoke_result_t<_Urng&>;
6097
+
6098
+ using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;
6099
+ static constexpr unsigned int _Udiff_bits = sizeof(_Udiff) * CHAR_BIT;
6100
+ using _Uprod = conditional_t<_Udiff_bits <= 16, uint32_t, conditional_t<_Udiff_bits <= 32, uint64_t, _Unsigned128>>;
6101
+
6102
+ explicit _Rng_from_urng_v2(_Urng& _Func) : _Ref(_Func) {}
6103
+
6104
+ _Diff operator()(_Diff _Index) { // adapt _Urng closed range to [0, _Index)
6105
+ // From Daniel Lemire, "Fast Random Integer Generation in an Interval", ACM Trans. Model. Comput. Simul. 29 (1),
6106
+ // 2019.
6107
+ //
6108
+ // Algorithm 5 <-> This Code:
6109
+ // m <-> _Product
6110
+ // l <-> _Rem
6111
+ // s <-> _Index
6112
+ // t <-> _Threshold
6113
+ // L <-> _Generated_bits
6114
+ // 2^L - 1 <-> _Mask
6115
+
6116
+ _Udiff _Mask = _Bmask;
6117
+ unsigned int _Niter = 1;
6118
+
6119
+ if constexpr (_Bits < _Udiff_bits) {
6120
+ while (_Mask < static_cast<_Udiff>(_Index - 1)) {
6121
+ _Mask <<= _Bits;
6122
+ _Mask |= _Bmask;
6123
+ ++_Niter;
6124
+ }
6125
+ }
6126
+
6127
+ // x <- random integer in [0, 2^L)
6128
+ // m <- x * s
6129
+ auto _Product = _Get_random_product(_Index, _Niter);
6130
+ // l <- m mod 2^L
6131
+ auto _Rem = static_cast<_Udiff>(_Product) & _Mask;
6132
+
6133
+ if (_Rem < _Index) {
6134
+ // t <- (2^L - s) mod s
6135
+ const auto _Threshold = (_Mask - _Index + 1) % _Index;
6136
+ while (_Rem < _Threshold) {
6137
+ _Product = _Get_random_product(_Index, _Niter);
6138
+ _Rem = static_cast<_Udiff>(_Product) & _Mask;
6139
+ }
6140
+ }
6141
+
6142
+ unsigned int _Generated_bits;
6143
+ if constexpr (_Bits < _Udiff_bits) {
6144
+ _Generated_bits = static_cast<unsigned int>(_Popcount(_Mask));
6145
+ } else {
6146
+ _Generated_bits = _Udiff_bits;
6147
+ }
6148
+
6149
+ // m / 2^L
6150
+ return static_cast<_Diff>(_Product >> _Generated_bits);
6151
+ }
6152
+
6153
+ _Udiff _Get_all_bits() {
6154
+ _Udiff _Ret = _Get_bits();
6155
+
6156
+ if constexpr (_Bits < _Udiff_bits) {
6157
+ for (unsigned int _Num = _Bits; _Num < _Udiff_bits; _Num += _Bits) { // don't mask away any bits
6158
+ _Ret <<= _Bits;
6159
+ _Ret |= _Get_bits();
6160
+ }
6161
+ }
6162
+
6163
+ return _Ret;
6164
+ }
6165
+
6166
+ _Rng_from_urng_v2(const _Rng_from_urng_v2&) = delete;
6167
+ _Rng_from_urng_v2& operator=(const _Rng_from_urng_v2&) = delete;
6168
+
6169
+ private:
6170
+ _Udiff _Get_bits() { // return a random value within [0, _Bmask]
6171
+ static constexpr auto _Urng_min = (_Urng::min)();
6172
+ for (;;) { // repeat until random value is in range
6173
+ const _Udiff _Val = _Ref() - _Urng_min;
6174
+
6175
+ if (_Val <= _Bmask) {
6176
+ return _Val;
6177
+ }
6178
+ }
6179
+ }
6180
+
6181
+ static constexpr size_t _Calc_bits() {
6182
+ auto _Bits_local = _Udiff_bits;
6183
+ auto _Bmask_local = static_cast<_Udiff>(-1);
6184
+ for (; (_Urng::max)() - (_Urng::min)() < _Bmask_local; _Bmask_local >>= 1) {
6185
+ --_Bits_local;
6186
+ }
6187
+
6188
+ return _Bits_local;
6189
+ }
6190
+
6191
+ _Uprod _Get_random_product(const _Diff _Index, unsigned int _Niter) {
6192
+ _Udiff _Ret = _Get_bits();
6193
+ if constexpr (_Bits < _Udiff_bits) {
6194
+ while (--_Niter > 0) {
6195
+ _Ret <<= _Bits;
6196
+ _Ret |= _Get_bits();
6197
+ }
6198
+ }
6199
+
6200
+ if constexpr (is_same_v<_Udiff, uint64_t>) {
6201
+ uint64_t _High;
6202
+ const auto _Low = _Base128::_UMul128(_Ret, static_cast<_Udiff>(_Index), _High);
6203
+ return _Uprod{_Low, _High};
6204
+ } else {
6205
+ return _Uprod{_Ret} * _Uprod{_Index};
6206
+ }
6207
+ }
6208
+
6209
+ _Urng& _Ref; // reference to URNG
6210
+ static constexpr size_t _Bits = _Calc_bits(); // number of random bits generated by _Get_bits()
6211
+ static constexpr _Udiff _Bmask = static_cast<_Udiff>(-1) >> (_Udiff_bits - _Bits); // 2^_Bits - 1
6212
+ };
6213
+
6090
6214
extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xbad_alloc();
6091
6215
extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xinvalid_argument(_In_z_ const char*);
6092
6216
extern "C++" [[noreturn]] _CRTIMP2_PURE void __CLRCALL_PURE_OR_CDECL _Xlength_error(_In_z_ const char*);
0 commit comments