Aleph-w 3.0
A C++ Library for Data Structures and Algorithms
Loading...
Searching...
No Matches
modular_arithmetic.H
Go to the documentation of this file.
1/*
2 Aleph_w
3
4 Data structures & Algorithms
5 version 2.0.0b
6 https://github.com/lrleon/Aleph-w
7
8 This file is part of Aleph-w library
9
10 Copyright (c) 2002-2026 Leandro Rabindranath Leon
11
12 Permission is hereby granted, free of charge, to any person obtaining a copy
13 of this software and associated documentation files (the "Software"), to deal
14 in the Software without restriction, including without limitation the rights
15 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 copies of the Software, and to permit persons to whom the Software is
17 furnished to do so, subject to the following conditions:
18
19 The above copyright notice and this permission notice shall be included in all
20 copies or substantial portions of the Software.
21
22 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 SOFTWARE.
29*/
30
41# ifndef MODULAR_ARITHMETIC_H
42# define MODULAR_ARITHMETIC_H
43
44# include <cstdint>
45# include <type_traits>
46
47# include <ah-errors.H>
48# include <tpl_array.H>
49
50namespace Aleph
51{
64 {
65 ah_invalid_argument_if(m == 0) << "mod_mul: modulus must be > 0";
66
67# if defined(__SIZEOF_INT128__)
68 return static_cast<uint64_t>((static_cast<__uint128_t>(a) * b) % m);
69# else
70 uint64_t res = 0;
71 a %= m;
72 while (b)
73 {
74 if (b & 1)
75 {
76 if (m - res <= a) res = a - (m - res);
77 else res += a;
78 }
79 if (m - a <= a) a = a - (m - a);
80 else a <<= 1;
81 b >>= 1;
82 }
83 return res;
84# endif
85 }
86
98 {
99 ah_invalid_argument_if(m == 0) << "mod_exp: modulus must be > 0";
100 if (m == 1) return 0;
101 uint64_t res = 1;
102 base %= m;
103 while (exp > 0)
104 {
105 if (exp & 1)
106 res = mod_mul(res, base, m);
107 base = mod_mul(base, base, m);
108 exp >>= 1;
109 }
110 return res;
111 }
112
123 template <typename T>
124 requires (std::is_integral_v<T> and std::is_signed_v<T>)
125 [[nodiscard]] T ext_gcd(T a, T b, T & x, T & y) noexcept
126 {
127 if (b == 0)
128 {
129 x = 1;
130 y = 0;
131 return a;
132 }
133 T x1, y1;
134 T d = ext_gcd(b, a % b, x1, y1);
135 x = y1;
136 y = x1 - static_cast<T>(a / b) * y1;
137 return d;
138 }
139
156 [[nodiscard]] inline uint64_t mod_inv(const uint64_t a, const uint64_t m)
157 {
158 ah_domain_error_if(m == 0) << "mod_inv: modulus cannot be 0";
159
160 if (m == 1)
161 return 0;
162
163 const uint64_t a_mod = a % m;
165 << "Modular inverse does not exist (" << a << " is 0 mod " << m << ")";
166
167 // Iterative extended Euclidean algorithm using unsigned arithmetic.
168 // We track t0 such that a * t0 ≡ r0 (mod m) at each step,
169 // keeping t0 in [0, m) to avoid signed overflow.
170 uint64_t r0 = a_mod, r1 = m;
171 uint64_t t0 = 1, t1 = 0;
172
173 while (r1 != 0)
174 {
175 const uint64_t q = r0 / r1;
176
177 const uint64_t tmp_r = r0 - q * r1;
178 r0 = r1;
179 r1 = tmp_r;
180
181 // t_new = t0 - q * t1 (mod m), computed in unsigned
182 const uint64_t qt1 = mod_mul(q, t1, m);
183 const uint64_t tmp_t = (t0 >= qt1) ? (t0 - qt1) : (m - (qt1 - t0));
184 t0 = t1;
185 t1 = tmp_t;
186 }
187
189 << "Modular inverse does not exist (numbers " << a
190 << " and " << m << " are not coprime)";
191
192 return t0;
193 }
194
195# if defined(__SIZEOF_INT128__)
// Forward declaration so that the factory below can name the type before
// the struct itself is defined.
struct MontgomeryCtx;

namespace detail
{
// Factory that builds a MontgomeryCtx from a modulus.  It is declared ahead
// of the struct because MontgomeryCtx names it as a friend.
[[nodiscard]] constexpr MontgomeryCtx
montgomery_ctx_unchecked(const uint64_t mod) noexcept;
}
205
221 struct MontgomeryCtx
222 {
223 public:
224 MontgomeryCtx() = delete;
225
229 [[nodiscard]] constexpr uint64_t mod() const noexcept { return mod_; }
230
234 [[nodiscard]] constexpr uint64_t mod2() const noexcept { return mod2_; }
235
239 [[nodiscard]] constexpr uint64_t r() const noexcept { return r_; }
240
244 [[nodiscard]] constexpr uint64_t r2() const noexcept { return r2_; }
245
250 {
251 return mod_inv_neg_;
252 }
253
254 private:
255 friend constexpr MontgomeryCtx
256 detail::montgomery_ctx_unchecked(const uint64_t mod) noexcept;
257
259 constexpr MontgomeryCtx(const uint64_t mod,
260 const uint64_t mod2,
261 const uint64_t r,
262 const uint64_t r2,
264 : mod_(mod), mod2_(mod2), r_(r), r2_(r2), mod_inv_neg_(mod_inv_neg)
265 {
266 /* empty */
267 }
268
269 uint64_t mod_;
271 uint64_t r_;
272 uint64_t r2_;
274 };
275
276 namespace detail
277 {
/// Returns the two's-complement negation of the inverse of `mod` modulo
/// 2^64.  The inverse is obtained with six Newton steps `x <- x * (2 - mod*x)`
/// starting from 1; each step doubles the number of correct low bits, so the
/// result is only meaningful for odd `mod`.
[[nodiscard]] constexpr uint64_t
montgomery_neg_inverse(const uint64_t mod) noexcept
{
  uint64_t x = 1; // correct to one bit for any odd modulus
  int remaining = 6; // 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64 correct bits
  while (remaining-- > 0)
    x = x * (2 - mod * x);
  return static_cast<uint64_t>(0) - x; // negate modulo 2^64
}
286
287 [[nodiscard]] constexpr MontgomeryCtx
288 montgomery_ctx_unchecked(const uint64_t mod) noexcept
289 {
290 const __uint128_t r128 = static_cast<__uint128_t>(1) << 64;
291 const auto r = static_cast<uint64_t>(r128 % mod);
292 const auto r2 = static_cast<uint64_t>((static_cast<__uint128_t>(r) * r) % mod);
293
294 return MontgomeryCtx(mod,
295 mod <= UINT64_MAX - mod ? mod + mod : 0,
296 r,
297 r2,
299 }
300 }
301
/** Builds a Montgomery context for a run-time modulus.
 *
 * @param[in] mod the modulus; must be odd and greater than 1
 * @return the context produced by `detail::montgomery_ctx_unchecked(mod)`
 * @throw std::invalid_argument if `mod <= 1` or `mod` is even
 */
[[nodiscard]] inline MontgomeryCtx
montgomery_ctx(const uint64_t mod)
{
  ah_invalid_argument_if(mod <= 1)
    << "montgomery_ctx: modulus must be > 1";
  // An even modulus is rejected before the unchecked factory is called.
  ah_invalid_argument_if((mod & 1ULL) == 0)
    << "montgomery_ctx: modulus " << mod << " must be odd";
  return detail::montgomery_ctx_unchecked(mod);
}
317
323 template <uint64_t Mod>
324 [[nodiscard]] consteval MontgomeryCtx
326 {
327 static_assert(Mod > 1, "montgomery_ctx_for_mod: modulus must be > 1");
328 static_assert((Mod & 1ULL) == 1ULL,
329 "montgomery_ctx_for_mod: modulus must be odd");
330 return detail::montgomery_ctx_unchecked(Mod);
331 }
332
356 [[nodiscard]] constexpr uint64_t
357 mont_reduce(const __uint128_t x,
358 const MontgomeryCtx & ctx) noexcept
359 {
360 const uint64_t q = static_cast<uint64_t>(x) * ctx.mod_inv_neg();
361 const auto x_lo = static_cast<uint64_t>(x);
362 const auto x_hi = static_cast<uint64_t>(x >> 64);
363 const __uint128_t qmod = static_cast<__uint128_t>(q) * ctx.mod();
364 const auto qmod_lo = static_cast<uint64_t>(qmod);
365 const auto qmod_hi = static_cast<uint64_t>(qmod >> 64);
366 const uint64_t carry = qmod_lo > UINT64_MAX - x_lo ? 1ULL : 0ULL;
367
368 // t = (x + q·p) / R; REDC guarantees 0 ≤ t < 2p.
369 const __uint128_t t = static_cast<__uint128_t>(x_hi) + qmod_hi + carry;
370
371 // Fast path: 2p < 2^64 so t fits in uint64_t; one conditional subtraction
372 // reduces to [0, p) using only additions and comparisons.
373 if (ctx.mod2() != 0)
374 {
375 const uint64_t t64 = static_cast<uint64_t>(t);
376 return t64 >= ctx.mod() ? t64 - ctx.mod() : t64;
377 }
378
379 // Fallback for large primes (p ≥ 2^63): t may exceed UINT64_MAX, so keep
380 // it in __uint128_t and use a single 128-bit division to finish.
381 return static_cast<uint64_t>(t % ctx.mod());
382 }
383
393 [[nodiscard]] constexpr uint64_t
394 mont_mul(const uint64_t a,
395 const uint64_t b,
396 const MontgomeryCtx & ctx) noexcept
397 {
398 return mont_reduce(static_cast<__uint128_t>(a) * b, ctx);
399 }
400
407 [[nodiscard]] constexpr uint64_t
408 to_mont(const uint64_t a,
409 const MontgomeryCtx & ctx) noexcept
410 {
411 return mont_mul(a % ctx.mod(), ctx.r2(), ctx);
412 }
413
420 [[nodiscard]] constexpr uint64_t
421 from_mont(const uint64_t a,
422 const MontgomeryCtx & ctx) noexcept
423 {
424 return mont_reduce(a, ctx);
425 }
426
436 [[nodiscard]] constexpr uint64_t
437 mont_exp(uint64_t base,
439 const MontgomeryCtx & ctx) noexcept
440 {
441 uint64_t result = to_mont(1, ctx);
442 while (exp > 0)
443 {
444 if (exp & 1ULL)
445 result = mont_mul(result, base, ctx);
446 base = mont_mul(base, base, ctx);
447 exp >>= 1;
448 }
449 return result;
450 }
451# endif
452
466 const Array<uint64_t> & mod)
467 {
468 ah_invalid_argument_if(rem.size() != mod.size())
469 << "crt: arrays must have the same size (got " << rem.size()
470 << " vs " << mod.size() << ")";
471
472 const size_t n = rem.size();
473 if (n == 0)
474 return 0;
475
476 // Compute product of all moduli with overflow detection
477 uint64_t prod = 1;
478 for (size_t i = 0; i < n; ++i)
479 {
480 ah_invalid_argument_if(mod[i] <= 1)
481 << "crt: all moduli must be > 1 (got " << mod[i] << " at index " << i << ")";
482
484 << "crt: product of moduli overflows uint64_t at index " << i;
485 prod *= mod[i];
486 }
487
488 uint64_t result = 0;
489 for (size_t i = 0; i < n; ++i)
490 {
491 const uint64_t p = prod / mod[i];
492 const uint64_t inv = mod_inv(p, mod[i]);
493 uint64_t term = mod_mul(rem[i], p, prod);
494 term = mod_mul(term, inv, prod);
495 if (result >= prod - term)
496 result -= (prod - term);
497 else
498 result += term;
499 }
500
501 return result;
502 }
503} // namespace Aleph
504
505# endif // MODULAR_ARITHMETIC_H
Exception handling system with formatted messages for Aleph-w.
#define ah_overflow_error_if(C)
Throws std::overflow_error if condition holds.
Definition ah-errors.H:463
#define ah_domain_error_if(C)
Throws std::domain_error if condition holds.
Definition ah-errors.H:522
#define ah_invalid_argument_if(C)
Throws std::invalid_argument if condition holds.
Definition ah-errors.H:639
Simple dynamic array with automatic resizing and functional operations.
Definition tpl_array.H:139
__gmp_expr< T, __gmp_unary_expr< __gmp_expr< T, U >, __gmp_y1_function > > y1(const __gmp_expr< T, U > &expr)
Definition gmpfrxx.h:4103
__gmp_expr< T, __gmp_unary_expr< __gmp_expr< T, U >, __gmp_exp_function > > exp(const __gmp_expr< T, U > &expr)
Definition gmpfrxx.h:4066
static mpfr_t y
Definition mpfr_mul_d.c:3
Main namespace for Aleph-w library functions.
Definition ah-arena.H:89
and
Check uniqueness with explicit hash + equality functors.
uint64_t mod_inv(const uint64_t a, const uint64_t m)
Modular Inverse.
T ext_gcd(T a, T b, T &x, T &y) noexcept
Extended Euclidean Algorithm.
Divide_Conquer_DP_Result< Cost > divide_and_conquer_partition_dp(const size_t groups, const size_t n, Transition_Cost_Fn transition_cost, const Cost inf=dp_optimization_detail::default_inf< Cost >())
Optimize partition DP using divide-and-conquer optimization.
std::decay_t< typename HeadC::Item_Type > T
Definition ah-zip.H:105
uint64_t mod_exp(uint64_t base, uint64_t exp, const uint64_t m)
Modular exponentiation.
uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t m)
Safe 64-bit modular multiplication.
uint64_t crt(const Array< uint64_t > &rem, const Array< uint64_t > &mod)
Chinese Remainder Theorem (CRT).
double mod(double a, double b)
FooMap m(5, fst_unit_pair_hash, snd_unit_pair_hash)
gsl_rng * r
Dynamic array container with automatic resizing.