Aleph-w 3.0
A C++ Library for Data Structures and Algorithms
Loading...
Searching...
No Matches
ntt.H
Go to the documentation of this file.
1/*
2 Aleph_w
3
4 Data structures & Algorithms
5 version 2.0.0b
6 https://github.com/lrleon/Aleph-w
7
8 This file is part of Aleph-w library
9
10 Copyright (c) 2002-2026 Leandro Rabindranath Leon
11
12 Permission is hereby granted, free of charge, to any person obtaining a copy
13 of this software and associated documentation files (the "Software"), to deal
14 in the Software without restriction, including without limitation the rights
15 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 copies of the Software, and to permit persons to whom the Software is
17 furnished to do so, subject to the following conditions:
18
19 The above copyright notice and this permission notice shall be included in all
20 copies or substantial portions of the Software.
21
22 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 SOFTWARE.
29*/
30
43# ifndef NTT_H
44# define NTT_H
45
46# if !defined(__SIZEOF_INT128__)
47# error "ntt.H requires compiler support for __uint128_t"
48# endif
49
50# include <algorithm>
51# include <cstdlib>
52# include <cstdint>
53# include <future>
54# include <limits>
55# include <memory>
56# include <string>
57# include <string_view>
58# include <type_traits>
59# include <utility>
60
61# if (defined(__GNUC__) or defined(__clang__)) \
62 and (defined(__x86_64__) or defined(__i386__) \
63 or defined(_M_X64) or defined(_M_IX86))
64# include <immintrin.h>
65# define ALEPH_NTT_HAS_X86_AVX2_DISPATCH 1
66# define ALEPH_NTT_AVX2_TARGET __attribute__((target("avx2")))
67# else
68# define ALEPH_NTT_HAS_X86_AVX2_DISPATCH 0
69# endif
70
71# if (defined(__GNUC__) or defined(__clang__)) \
72 and (defined(__aarch64__) or defined(_M_ARM64))
73# include <arm_neon.h>
74# define ALEPH_NTT_HAS_ARM_NEON_DISPATCH 1
75# else
76# define ALEPH_NTT_HAS_ARM_NEON_DISPATCH 0
77# endif
78
79# if ALEPH_NTT_HAS_ARM_NEON_DISPATCH and defined(__linux__)
80# include <sys/auxv.h>
81# include <asm/hwcap.h>
82# endif
83
84# include <ah-errors.H>
85# include <modular_arithmetic.H>
86# include <thread_pool.H>
87# include <tpl_array.H>
88
89namespace Aleph
90{
113 template <uint64_t MOD = 998244353ULL, uint64_t ROOT = 3ULL>
114 class NTT
115 {
116 static_assert(MOD > 1, "NTT requires MOD > 1");
117 static_assert((MOD & 1ULL) == 1ULL, "NTT requires an odd modulus");
118 static_assert(ROOT > 0 and ROOT < MOD, "NTT root must lie in (0, MOD)");
119
120 public:
// Names the butterfly implementation a transform can run on: the portable
// scalar path, or one of the two SIMD paths this header conditionally
// compiles (see ALEPH_NTT_HAS_X86_AVX2_DISPATCH and
// ALEPH_NTT_HAS_ARM_NEON_DISPATCH above).
122 enum class NTTSimdBackend
123 {
124 scalar,
125 avx2,
126 neon
127 };
128
129 private:
130 enum class SimdPreference
131 {
132 automatic,
134 avx2_only,
136 };
137
138 enum class Representation
139 {
140 standard,
142 };
143
145
// Returns true when n is an exact power of two (1, 2, 4, ...).
// n == 0 is rejected explicitly, because 0 & (0 - 1) is also 0.
146 [[nodiscard]] static constexpr bool
147 is_power_of_two(const size_t n) noexcept
148 {
// A power of two has exactly one bit set, so clearing the lowest set bit
// with n & (n - 1) must leave zero.
149 return n != 0 and (n & (n - 1)) == 0;
150 }
151
// Modular addition helper: returns lhs + rhs, with one subtraction of MOD
// when the sum reaches MOD.
// NOTE(review): listing line 153 (function name and the `lhs` parameter) is
// missing from this extract; the name is presumably add_mod, matching the
// two-argument add_mod(...) calls elsewhere in this file -- confirm.
// The result is a canonical residue only if both inputs are already < MOD
// (a single conditional subtraction is performed, not a full reduction).
152 [[nodiscard]] static constexpr uint64_t
154 const uint64_t rhs) noexcept
155 {
// Widen to 128 bits so the sum cannot wrap, even for moduli close to 2^64.
156 const __uint128_t sum = static_cast<__uint128_t>(lhs) + rhs;
157 return static_cast<uint64_t>(sum >= MOD ? sum - MOD : sum);
158 }
159
// Modular subtraction helper: returns lhs - rhs when that is non-negative,
// otherwise MOD - (rhs - lhs), so the unsigned arithmetic never underflows.
// NOTE(review): listing line 161 (function name and the `lhs` parameter) is
// missing from this extract; the name is presumably sub_mod, matching the
// sub_mod(...) calls elsewhere in this file -- confirm.
// The result is a canonical residue only if both inputs are already < MOD.
160 [[nodiscard]] static constexpr uint64_t
162 const uint64_t rhs) noexcept
163 {
164 return lhs >= rhs ? lhs - rhs : MOD - (rhs - lhs);
165 }
166
167 [[nodiscard]] static constexpr bool
172
// Computes base^exp modulo MOD by binary (square-and-multiply)
// exponentiation, usable in constant expressions.
// NOTE(review): listing line 174 (function name and the `base` parameter) is
// missing from this extract; the name pow_mod_constexpr is taken from the
// call pow_mod_constexpr(ROOT, (MOD - 1) / order) elsewhere in this file.
// `base` is modified in the body, so it is presumably taken by value.
173 [[nodiscard]] static constexpr uint64_t
175 uint64_t exp) noexcept
176 {
// Defensive guard only: the class-level static_assert(MOD > 1) makes this
// branch unreachable for any instantiation that compiles.
177 if (MOD == 1)
178 return 0;
179
180 uint64_t result = 1;
181 base %= MOD;
182 while (exp > 0)
183 {
// Multiply the current power into the result when this bit of exp is set.
// Products go through __uint128_t so they cannot overflow 64 bits.
184 if (exp & 1ULL)
185 result = static_cast<uint64_t>(
186 (static_cast<__uint128_t>(result) * base) % MOD);
// Square the base for the next (more significant) exponent bit.
187 base = static_cast<uint64_t>(
188 (static_cast<__uint128_t>(base) * base) % MOD);
189 exp >>= 1;
190 }
191
192 return result;
193 }
194
195 [[nodiscard]] static constexpr uint64_t
197 {
198 uint64_t value = MOD - 1;
199 uint64_t size = 1;
200 while ((value & 1ULL) == 0)
201 {
202 size <<= 1;
203 value >>= 1;
204 }
205 return size;
206 }
207
208 [[nodiscard]] static constexpr bool
209 supports_power_of_two_size(const size_t n) noexcept
210 {
211 return is_power_of_two(n)
213 }
214
// Returns true when `order` is non-zero and divides MOD - 1.
// That divisibility is what primitive_root_of_order relies on when it
// computes ROOT^((MOD - 1) / order) with an exact integer quotient.
215 [[nodiscard]] static constexpr bool
216 supports_root_order(const uint64_t order) noexcept
217 {
// The order != 0 test comes first so the modulo never divides by zero.
218 return order != 0 and (MOD - 1) % order == 0;
219 }
220
// Decides whether a transform of non-power-of-two length n can be served by
// the Bluestein path. Sizes <= 1 and exact powers of two always return
// false here.
// NOTE(review): the final return statement (listing line 246) is missing
// from this extract. Judging by the message built in the size validator
// ("an internal power-of-two convolution size <= max_transform_size()"), it
// presumably compares conv_size against the modulus capacity -- confirm.
221 [[nodiscard]] static constexpr bool
222 supports_bluestein_size(const size_t n) noexcept
223 {
224 if (n <= 1 or is_power_of_two(n))
225 return false;
226
// 2 * n must not exceed MOD - 1, otherwise no root of order 2n can exist.
227 if (n > static_cast<size_t>((MOD - 1) / 2))
228 return false;
229
// A root of order 2n is required, so 2n must divide MOD - 1.
230 const uint64_t order = static_cast<uint64_t>(n) * 2ULL;
231 if (not supports_root_order(order))
232 return false;
233
// Guard the size_t computation n * 2 - 1 below against overflow.
234 if (n > std::numeric_limits<size_t>::max() / 2)
235 return false;
236
// Find the smallest power of two >= 2n - 1 (the internal convolution
// length), bailing out if doubling would overflow size_t.
237 const size_t required = n * 2 - 1;
238 size_t conv_size = 1;
239 while (conv_size < required)
240 {
241 if (conv_size > std::numeric_limits<size_t>::max() / 2)
242 return false;
243 conv_size <<= 1;
244 }
245
247 }
248
// Reports an invalid-argument error, prefixed with `ctx`, when `order`
// cannot be used as a root order for this modulus.
// NOTE(review): listing line 250 (function name and the `order` parameter)
// and line 255 (the head of the second check) are missing from this
// extract. The name is taken from the call
// NTT::validate_root_order(order, "NTT::Plan"); the second check presumably
// tests that order divides MOD - 1, as its message says -- confirm.
249 static void
251 const char * const ctx)
252 {
253 ah_invalid_argument_if(order == 0)
254 << ctx << ": root order must be positive";
256 << ctx << ": order " << order
257 << " does not divide MOD - 1 (" << (MOD - 1) << ")";
258 }
259
// Returns ROOT^((MOD - 1) / order) modulo MOD.
// NOTE(review): listing line 261 (function name and the `order` parameter)
// is missing from this extract; the name is taken from the calls
// primitive_root_of_order(...) elsewhere in this file.
// Assumes order != 0 and order divides MOD - 1 (the integer division is
// otherwise inexact, or a division by zero). If ROOT is a primitive root
// modulo MOD, the result has multiplicative order exactly `order` --
// neither property is checked here.
260 [[nodiscard]] static constexpr uint64_t
262 {
263 return pow_mod_constexpr(ROOT, (MOD - 1) / order);
264 }
265
266 static void
268 const char * const ctx)
269 {
271 << ctx << ": size must be positive";
273 << ctx << ": size " << n
274 << " is not supported by MOD " << MOD
275 << " (power-of-two sizes require n | 2^k <= "
277 << "; Bluestein sizes require 2*n | (MOD - 1) and an internal "
278 << "power-of-two convolution size <= " << max_transform_size() << ")";
279 }
280
// Returns the smallest power of two that is >= n (1 for n <= 1).
// Raises an overflow error, prefixed with `ctx`, when that power of two
// does not fit in size_t.
// NOTE(review): listing line 282 (function name and the `n` parameter) is
// missing from this extract; the name is taken from the calls
// next_power_of_two(required, "NTT::multiply") elsewhere in this file.
281 [[nodiscard]] static size_t
283 const char * const ctx)
284 {
285 if (n <= 1)
286 return 1;
287
288 size_t value = 1;
289 while (value < n)
290 {
// Check before doubling: once value exceeds max / 2, the shift would wrap.
291 ah_overflow_error_if(value > std::numeric_limits<size_t>::max() / 2)
292 << ctx << ": next power of two overflows size_t for requested size "
293 << n;
294 value <<= 1;
295 }
296
297 return value;
298 }
299
300 [[nodiscard]] static Array<uint64_t>
302 const size_t n)
303 {
305 for (size_t i = 0; i < n; ++i)
306 output(i) = 0;
307
308 for (size_t i = 0; i < input.size(); ++i)
309 output(i) = input[i] % MOD;
310
311 return output;
312 }
313
// Returns a new array holding the first `length` elements of `input`,
// copied as-is (no reduction modulo MOD is applied here).
// Raises an invalid-argument error when `length` exceeds input.size().
// NOTE(review): listing line 315 (function name and the `input` parameter)
// and line 322 (the declaration of `output`) are missing from this extract;
// the name is taken from the "NTT::prefix_copy" message below.
314 [[nodiscard]] static Array<uint64_t>
316 const size_t length)
317 {
318 ah_invalid_argument_if(length > input.size())
319 << "NTT::prefix_copy: length " << length
320 << " exceeds input size " << input.size();
321
// Reserve once so the appends below do not reallocate.
323 output.reserve(length);
324 for (size_t i = 0; i < length; ++i)
325 output.append(input[i]);
326 return output;
327 }
328
329 public:
330 static constexpr uint64_t mod = MOD;
331 static constexpr uint64_t root = ROOT;
332
333 [[nodiscard]] static constexpr const char *
335 {
336 switch (backend)
337 {
339 return "avx2";
341 return "neon";
343 default:
344 return "scalar";
345 }
346 }
347
348 private:
349 [[nodiscard]] static constexpr const char *
351 {
352 switch (preference)
353 {
355 return "scalar";
357 return "avx2";
359 return "neon";
361 default:
362 return "auto";
363 }
364 }
365
366 [[nodiscard]] static SimdPreference
368 {
369 if (const char *mode = std::getenv("ALEPH_NTT_SIMD");
370 mode != nullptr and mode[0] != '\0')
371 {
372 const std::string_view value(mode);
373 if (value == "scalar")
375 if (value == "avx2")
377 if (value == "neon")
379 }
380
382 }
383
384 [[nodiscard]] static NTTSimdBackend
393
394 public:
396 [[nodiscard]] static bool
398 {
399# if ALEPH_NTT_HAS_X86_AVX2_DISPATCH
400 static const bool available = []() noexcept
401 {
403 return static_cast<bool>(__builtin_cpu_supports("avx2"));
404 }();
405 return available;
406# else
407 return false;
408# endif
409 }
410
412 [[nodiscard]] static bool
414 {
415# if ALEPH_NTT_HAS_ARM_NEON_DISPATCH
416# if defined(__linux__) and defined(HWCAP_ASIMD)
417 static const bool available = []() noexcept
418 {
419 return (getauxval(AT_HWCAP) & HWCAP_ASIMD) != 0;
420 }();
421 return available;
422# else
423 return true;
424# endif
425# else
426 return false;
427# endif
428 }
429
431 [[nodiscard]] static NTTSimdBackend
452
454 [[nodiscard]] static const char *
459
465 class Plan
466 {
467 enum class Strategy
468 {
471 };
472
474 size_t n_ = 0;
475 size_t log_n_ = 0;
481 size_t bluestein_size_ = 0;
486 std::shared_ptr<const Plan> bluestein_plan_;
487
// Calls fn(i) for every i in [0, count), either through the thread pool or
// in a plain sequential loop.
// NOTE(review): listing line 490 (function name and the `pool` parameter) is
// missing from this extract; the name is taken from the
// for_each_index(pool, ...) calls elsewhere in this file, and `pool` is
// used as a nullable pointer below.
// The pool is used only when it is non-null, has more than one thread and
// there is more than one index to process.
// chunk_size is forwarded unchanged to parallel_for_index; its meaning
// (including the default 0 passed by callers) is defined by that function,
// which is not part of this listing.
488 template <typename F>
489 static void
491 const size_t count,
492 F && fn,
493 const size_t chunk_size)
494 {
495 if (count == 0)
496 return;
497
498 if (pool != nullptr and pool->num_threads() > 1 and count > 1)
499 {
500 parallel_for_index(*pool, 0, count, std::forward<F>(fn),
501 chunk_size);
502 return;
503 }
504
// Sequential fallback: indices are visited in increasing order.
505 for (size_t i = 0; i < count; ++i)
506 fn(i);
507 }
508
509 void
512 const size_t base,
513 const size_t half,
514 const size_t offset,
515 const size_t begin,
516 const size_t end) const
517 {
518 for (size_t j = begin; j < end; ++j)
519 {
520 const uint64_t u = a[base + j];
521 const uint64_t v =
522 mont_mul(a[base + j + half], twiddles[offset + j], mctx_);
523 a(base + j) = add_mod(u, v);
524 a(base + j + half) = sub_mod(u, v);
525 }
526 }
527
// Decides whether the AVX2 butterfly path may be used for a transform that
// was given `pool` (possibly null).
// Always false when AVX2 dispatch is not compiled in.
// NOTE(review): listing lines 535 and 541 are missing from this extract:
// the condition guarding the first `return false` and the final return
// statement. Only the checks visible below are documented.
528 [[nodiscard]] bool
529 should_use_avx2(ThreadPool * const pool) const noexcept
530 {
531# if !ALEPH_NTT_HAS_X86_AVX2_DISPATCH
// Silence the unused-parameter warning on builds without AVX2 dispatch.
532 (void) pool;
533 return false;
534# else
536 return false;
537
// A pool with more than one thread selects the non-AVX2 path.
538 if (pool != nullptr and pool->num_threads() > 1)
539 return false;
540
542# endif
543 }
544
// Decides whether the NEON butterfly path may be used for a transform that
// was given `pool` (possibly null).
// Always false when NEON dispatch is not compiled in.
// NOTE(review): listing lines 552 and 558 are missing from this extract:
// the condition guarding the first `return false` and the final return
// statement. Only the checks visible below are documented.
545 [[nodiscard]] bool
546 should_use_neon(ThreadPool * const pool) const noexcept
547 {
548# if !ALEPH_NTT_HAS_ARM_NEON_DISPATCH
// Silence the unused-parameter warning on builds without NEON dispatch.
549 (void) pool;
550 return false;
551# else
553 return false;
554
// A pool with more than one thread selects the non-NEON path.
555 if (pool != nullptr and pool->num_threads() > 1)
556 return false;
557
559# endif
560 }
561
562# if ALEPH_NTT_HAS_X86_AVX2_DISPATCH
563 [[nodiscard]] static __m256i
565 const __m256i rhs) noexcept ALEPH_NTT_AVX2_TARGET
566 {
567 alignas(32) static const uint64_t sign_bits[4] = {
568 0x8000000000000000ULL,
569 0x8000000000000000ULL,
570 0x8000000000000000ULL,
571 0x8000000000000000ULL
572 };
573 const __m256i sign =
574 _mm256_load_si256(reinterpret_cast<const __m256i *>(sign_bits));
579 return _mm256_or_si256(gt, eq);
580 }
581
582 static void
584 uint64_t * const high,
586 {
587 alignas(32) uint64_t products[4];
588 for (size_t lane = 0; lane < 4; ++lane)
590
591 alignas(32) static const uint64_t mod_lanes[4] = {MOD, MOD, MOD, MOD};
592 const __m256i modv =
593 _mm256_load_si256(reinterpret_cast<const __m256i *>(mod_lanes));
594 const __m256i u =
595 _mm256_loadu_si256(reinterpret_cast<const __m256i *>(low));
596 const __m256i v =
597 _mm256_load_si256(reinterpret_cast<const __m256i *>(products));
598
599 const __m256i sum = _mm256_add_epi64(u, v);
601 const __m256i sum_mask = avx2_cmp_ge_u64(sum, modv);
603
604 const __m256i diff = _mm256_sub_epi64(u, v);
605 const __m256i vu = _mm256_sub_epi64(v, u);
607 const __m256i diff_mask = avx2_cmp_ge_u64(u, v);
608 const __m256i diff_result =
610
611 _mm256_storeu_si256(reinterpret_cast<__m256i *>(low), sum_result);
612 _mm256_storeu_si256(reinterpret_cast<__m256i *>(high), diff_result);
613 }
614
615 void
618 {
619 for (size_t stage = 0; stage < log_n_; ++stage)
620 {
621 const size_t half = static_cast<size_t>(1) << stage;
622 const size_t len = half << 1;
623 const size_t blocks = n_ / len;
624 const size_t offset = half - 1;
625
626 for (size_t block = 0; block < blocks; ++block)
627 {
628 const size_t base = block * len;
629 size_t j = 0;
630 for (; j + 4 <= half; j += 4)
631 avx2_apply_chunk(&a(base + j),
632 &a(base + j + half),
633 &twiddles[offset + j]);
635 j, half);
636 }
637 }
638 }
639# endif
640
641# if ALEPH_NTT_HAS_ARM_NEON_DISPATCH
642 static void
644 uint64_t * const high,
645 const uint64_t * const twiddle_ptr)
646 {
647 alignas(16) uint64_t products[2];
648 for (size_t lane = 0; lane < 2; ++lane)
650
652 const uint64x2_t u = vld1q_u64(low);
653 const uint64x2_t v = vld1q_u64(products);
654 const uint64x2_t sum = vaddq_u64(u, v);
656 const uint64x2_t sum_mask = vcgeq_u64(sum, modv);
658
659 const uint64x2_t diff = vsubq_u64(u, v);
660 const uint64x2_t vu = vsubq_u64(v, u);
662 const uint64x2_t diff_mask = vcgeq_u64(u, v);
664
666 vst1q_u64(high, diff_result);
667 }
668
669 void
671 const Array<uint64_t> & twiddles) const
672 {
673 for (size_t stage = 0; stage < log_n_; ++stage)
674 {
675 const size_t half = static_cast<size_t>(1) << stage;
676 const size_t len = half << 1;
677 const size_t blocks = n_ / len;
678 const size_t offset = half - 1;
679
680 for (size_t block = 0; block < blocks; ++block)
681 {
682 const size_t base = block * len;
683 size_t j = 0;
684 for (; j + 2 <= half; j += 2)
685 neon_apply_chunk(&a(base + j),
686 &a(base + j + half),
687 &twiddles[offset + j]);
689 j, half);
690 }
691 }
692 }
693# endif
694
695 void
697 {
699 bit_rev_(0) = 0;
700 for (size_t i = 1, j = 0; i < n_; ++i)
701 {
702 size_t bit = n_ >> 1;
703 for (; j & bit; bit >>= 1)
704 j ^= bit;
705 j ^= bit;
706 bit_rev_(i) = j;
707 }
708 }
709
710 void
712 {
713 if (n_ <= 1)
714 return;
715
718
719 const uint64_t mont_one = to_mont(1, mctx_);
720 for (size_t stage = 0; stage < log_n_; ++stage)
721 {
722 const size_t half = static_cast<size_t>(1) << stage;
723 const size_t len = half << 1;
724 const size_t offset = half - 1;
725
728
733
734 for (size_t j = 0; j < half; ++j)
735 {
740 }
741 }
742 }
743
744 void
746 {
748
749 for (size_t value = n_; value > 1; value >>= 1)
750 ++log_n_;
751
754 inv_n_ = to_mont(mod_inv(static_cast<uint64_t>(n_), MOD), mctx_);
755 }
756
757 void
759 {
761
762 ah_invalid_argument_if(n_ > std::numeric_limits<size_t>::max() / 2)
763 << "NTT::Plan: size " << n_
764 << " is too large for Bluestein convolution sizing";
765
766 const size_t required = n_ * 2 - 1;
769 << "NTT::Plan: Bluestein internal size " << bluestein_size_
770 << " exceeds the power-of-two capacity of MOD " << MOD;
771
772 const uint64_t order = static_cast<uint64_t>(n_) * 2ULL;
773 NTT::validate_root_order(order, "NTT::Plan");
774
775 bluestein_plan_ = std::make_shared<Plan>(bluestein_size_);
778
779 const uint64_t z = NTT::primitive_root_of_order(order);
780 const uint64_t z_inv = mod_inv(z, MOD);
781 for (size_t i = 0; i < n_; ++i)
782 {
783 const uint64_t exponent =
784 static_cast<uint64_t>((static_cast<__uint128_t>(i) * i) % order);
787 }
788
791 for (size_t i = 0; i < bluestein_size_; ++i)
792 {
795 }
796
799 for (size_t i = 1; i < n_; ++i)
800 {
804
808 }
809
812 }
813
814 void
816 {
817 for (size_t i = 0; i < n_; ++i)
818 if (i < bit_rev_(i))
819 std::swap(a(i), a(bit_rev_(i)));
820 }
821
// Rewrites each of the first n_ entries of `a` in place as
// to_mont(a[i] % MOD, mctx_): the value is reduced modulo MOD and then
// converted by to_mont (presumably into Montgomery form -- to_mont is
// declared outside this listing).
// NOTE(review): listing line 823 (function name and the `a` parameter) is
// missing from this extract; the name is taken from the call
// lift_input(a, pool, chunk_size) elsewhere in this file.
822 void
824 ThreadPool * const pool,
825 const size_t chunk_size) const
826 {
// mctx_ is used without capturing `this`, so it is presumably a static
// member; its declaration is not visible in this listing.
827 auto lift_one = [&a](const size_t i)
828 {
829 a(i) = to_mont(a[i] % MOD, mctx_);
830 };
// Each index is independent, so the work may be spread over the pool.
831 for_each_index(pool, n_, lift_one, chunk_size);
832 }
833
// Multiplies each of the first n_ entries of `a` in place by inv_n_ using
// mont_mul. inv_n_ is set elsewhere in this file to
// to_mont(mod_inv(n_, MOD), mctx_), i.e. the converted inverse of the
// transform length; the caller applies this step only for inverse
// transforms.
// NOTE(review): listing line 835 (function name and the `a` parameter) is
// missing from this extract; the name is taken from the call
// scale_inverse(a, pool, chunk_size) elsewhere in this file.
834 void
836 ThreadPool * const pool,
837 const size_t chunk_size) const
838 {
839 auto scale_one = [this, &a](const size_t i)
840 {
841 a(i) = mont_mul(a[i], inv_n_, mctx_);
842 };
843 for_each_index(pool, n_, scale_one, chunk_size);
844 }
845
// Rewrites each of the first n_ entries of `a` in place as
// from_mont(a[i], mctx_), the counterpart of the to_mont conversion applied
// on input (presumably back to the standard residue representation --
// from_mont is declared outside this listing).
// NOTE(review): listing line 847 (function name and the `a` parameter) is
// missing from this extract; the name is taken from the call
// lower_output(a, pool, chunk_size) elsewhere in this file.
846 void
848 ThreadPool * const pool,
849 const size_t chunk_size) const
850 {
851 auto lower_one = [&a](const size_t i)
852 {
853 a(i) = from_mont(a[i], mctx_);
854 };
855 for_each_index(pool, n_, lower_one, chunk_size);
856 }
857
// Runs the log_n_ butterfly stages of the transform with the scalar kernel.
// Stages run strictly one after another; within a stage the blocks are
// handed to for_each_index.
// NOTE(review): listing lines 859-860 (function name and the `a` and
// `twiddles` parameters) and line 875 (the head of the per-block call) are
// missing from this extract; the name is taken from the call
// apply_butterflies_scalar(a, twiddles, pool, chunk_size).
858 void
861 ThreadPool * const pool,
862 const size_t chunk_size) const
863 {
864 for (size_t stage = 0; stage < log_n_; ++stage)
865 {
// Stage s pairs elements `half` = 2^s apart inside blocks of length 2^(s+1).
866 const size_t half = static_cast<size_t>(1) << stage;
867 const size_t len = half << 1;
868 const size_t blocks = n_ / len;
// half - 1 equals 1 + 2 + ... + half / 2, which is consistent with the
// per-stage twiddle tables being stored back to back -- the table layout
// itself is built elsewhere; confirm there.
869 const size_t offset = half - 1;
870
871 auto butterfly_block = [this, &a, &twiddles, half, len, offset]
872 (const size_t block)
873 {
874 const size_t base = block * len;
876 0, half);
877 };
878
879 for_each_index(pool, blocks, butterfly_block, chunk_size);
880 }
881 }
882
// Runs the butterfly stages on `a`, dispatching to the AVX2 or NEON
// implementation when the corresponding should_use_* predicate allows it
// and falling back to the scalar implementation otherwise.
// NOTE(review): several listing lines are missing from this extract: 884
// (function name and the `a` parameter), 890 (the expression selecting the
// twiddle table, presumably based on `invert`), and 895 / 903 (the calls
// made inside the two SIMD branches). The name is taken from the call
// apply_butterflies(a, invert, pool, chunk_size).
883 void
885 const bool invert,
886 ThreadPool * const pool,
887 const size_t chunk_size) const
888 {
889 const Array<uint64_t> & twiddles =
891
892# if ALEPH_NTT_HAS_X86_AVX2_DISPATCH
893 if (should_use_avx2(pool))
894 {
896 return;
897 }
898# endif
899
900# if ALEPH_NTT_HAS_ARM_NEON_DISPATCH
901 if (should_use_neon(pool))
902 {
904 return;
905 }
906# endif
907
// Portable path; also the only path that receives the thread pool.
908 apply_butterflies_scalar(a, twiddles, pool, chunk_size);
909 }
910
// Computes a length-n_ transform of `a` in place via the Bluestein path:
// pre-multiply by a chirp, convolve with a kernel using the internal
// power-of-two plan (bluestein_plan_), then post-multiply by a chirp.
// All values are handled with mod_mul(..., MOD), i.e. in the standard
// representation.
// NOTE(review): several listing lines are missing from this extract: 912
// (function name and the `a` parameter), 917 (the head of the "missing
// internal plan" check), 920 (declaration of `work`), 924-927 (presumably
// the definitions of `input_chirp` and `output_chirp`) and 929 (the kernel
// selection). How `invert` selects chirps and kernel is therefore not
// visible here.
911 void
913 const bool invert,
914 ThreadPool * const pool,
915 const size_t chunk_size) const
916 {
918 << "NTT::Plan::apply_bluestein_transform: missing internal plan";
919
// Zero the whole convolution buffer; only the first n_ slots get data.
921 for (size_t i = 0; i < bluestein_size_; ++i)
922 work(i) = 0;
923
928 const Array<uint64_t> & kernel =
930
// Step 1: work[i] = a[i] * input_chirp[i] for i < n_ (bounds-checked
// against a.size()).
931 auto initialize = [&a, &work, &input_chirp](const size_t i)
932 {
933 if (i < a.size())
934 work(i) = mod_mul(a[i] % MOD, input_chirp[i], MOD);
935 };
936 for_each_index(pool, n_, initialize, chunk_size);
937
// Step 2: forward transform of the padded buffer with the internal plan,
// parallel only when a multi-threaded pool was supplied.
938 if (pool != nullptr and pool->num_threads() > 1)
939 bluestein_plan_->ptransform(*pool, work, false, chunk_size);
940 else
941 bluestein_plan_->transform(work, false);
942
// Step 3: pointwise product with the kernel. Its entries are multiplied
// directly against transformed data, so the kernel is presumably stored
// already transformed -- confirm where it is built.
943 auto pointwise = [&work, &kernel](const size_t i)
944 {
945 work(i) = mod_mul(work[i], kernel[i], MOD);
946 };
947 for_each_index(pool, bluestein_size_, pointwise, chunk_size);
948
// Step 4: inverse transform back from the internal plan's domain.
949 if (pool != nullptr and pool->num_threads() > 1)
950 bluestein_plan_->ptransform(*pool, work, true, chunk_size);
951 else
952 bluestein_plan_->transform(work, true);
953
// Step 5: a[i] = work[i] * output_chirp[i], additionally scaled by
// inv_n_std_ (mod_inv(n_, MOD), set in the Plan constructor) for inverse
// transforms.
954 auto finalize = [this, &a, &work, &output_chirp, invert](const size_t i)
955 {
956 uint64_t value = mod_mul(work[i], output_chirp[i], MOD);
957 if (invert)
958 value = mod_mul(value, inv_n_std_, MOD);
959 a(i) = value;
960 };
961 for_each_index(pool, n_, finalize, chunk_size);
962 }
963
964 void
966 const bool invert,
969 ThreadPool * const pool,
970 const size_t chunk_size) const
971 {
973 << "NTT::Plan::transform: input size " << a.size()
974 << " does not match plan size " << n_;
975
976 switch (strategy_)
977 {
980 lift_input(a, pool, chunk_size);
981
982 if (n_ > 1)
983 {
985 apply_butterflies(a, invert, pool, chunk_size);
986 }
987
988 if (invert)
989 scale_inverse(a, pool, chunk_size);
990
992 lower_output(a, pool, chunk_size);
993 break;
994
997 << "NTT::Plan::apply_transform: Bluestein path expects standard "
998 << "input representation";
1000 << "NTT::Plan::apply_transform: Bluestein path returns standard "
1001 << "output representation";
1002 apply_bluestein_transform(a, invert, pool, chunk_size);
1003 break;
1004 }
1005 }
1006
1009 const Array<uint64_t> & b,
1010 ThreadPool * const pool,
1011 const size_t chunk_size) const
1012 {
1013 if (a.is_empty() or b.is_empty())
1014 return {};
1015
1017 std::numeric_limits<size_t>::max()
1018 - b.size() + 1)
1019 << "NTT::Plan::multiply: product size exceeds size_t capacity";
1020
1021 const size_t required = a.size() + b.size() - 1;
1023 << "NTT::Plan::multiply: product size " << required
1024 << " exceeds plan size " << n_;
1025
1028
1030 {
1031 apply_transform(fa, false,
1034 pool, chunk_size);
1035 apply_transform(fb, false,
1038 pool, chunk_size);
1039
1040 auto pointwise_product = [&fa, &fb](const size_t i)
1041 {
1042 fa(i) = mont_mul(fa[i], fb[i], mctx_);
1043 };
1044 for_each_index(pool, n_, pointwise_product, chunk_size);
1045
1046 apply_transform(fa, true,
1049 pool, chunk_size);
1050 }
1051 else
1052 {
1053 apply_transform(fa, false,
1056 pool, chunk_size);
1057 apply_transform(fb, false,
1060 pool, chunk_size);
1061
1062 auto pointwise_product = [&fa, &fb](const size_t i)
1063 {
1064 fa(i) = mod_mul(fa[i], fb[i], MOD);
1065 };
1066 for_each_index(pool, n_, pointwise_product, chunk_size);
1067
1068 apply_transform(fa, true,
1071 pool, chunk_size);
1072 }
1073
1074 return NTT::prefix_copy(fa, required);
1075 }
1076
1077 public:
1085 explicit Plan(const size_t n) : n_(n)
1086 {
1087 NTT::validate_supported_size(n_, "NTT::Plan");
1088 inv_n_std_ = mod_inv(static_cast<uint64_t>(n_), MOD);
1089
1092 else
1094 }
1095
1098 {
1099 return n_;
1100 }
1101
1107 void
1109 const bool invert) const
1110 {
1114 nullptr, 0);
1115 }
1116
1125 const bool invert = false) const
1126 {
1129 return output;
1130 }
1131
1142 const Array<uint64_t> & b) const
1143 {
1144 return multiply_impl(a, b, nullptr, 0);
1145 }
1146
1155 void
1157 Array<uint64_t> & a,
1158 const bool invert,
1159 const size_t chunk_size = 0) const
1160 {
1164 &pool, chunk_size);
1165 }
1166
1178 const Array<uint64_t> & input,
1179 const bool invert = false,
1180 const size_t chunk_size = 0) const
1181 {
1183 ptransform(pool, output, invert, chunk_size);
1184 return output;
1185 }
1186
1199 const Array<uint64_t> & a,
1200 const Array<uint64_t> & b,
1201 const size_t chunk_size = 0) const
1202 {
1203 return multiply_impl(a, b, &pool, chunk_size);
1204 }
1205
1211 void
1213 const bool invert) const
1214 {
1215 for (size_t i = 0; i < batch.size(); ++i)
1216 {
1218 << "NTT::Plan::transform_batch: batch item " << i
1219 << " has size " << batch[i].size()
1220 << " but plan size is " << n_;
1224 nullptr, 0);
1225 }
1226 }
1227
1239 void
1242 const bool invert,
1243 const size_t chunk_size = 0) const
1244 {
1245 for (size_t i = 0; i < batch.size(); ++i)
1247 << "NTT::Plan::ptransform_batch: batch item " << i
1248 << " has size " << batch[i].size()
1249 << " but plan size is " << n_;
1250
1251 if (batch.is_empty())
1252 return;
1253
1254 if (batch.size() == 1 or pool.num_threads() <= 1)
1255 {
1256 for (size_t i = 0; i < batch.size(); ++i)
1260 &pool, chunk_size);
1261 return;
1262 }
1263
1264 auto transform_one = [this, &batch, invert](const size_t i)
1265 {
1269 nullptr, 0);
1270 };
1271 parallel_for_index(pool, 0, batch.size(), transform_one, chunk_size);
1272 }
1273
1282 const bool invert = false) const
1283 {
1286 return output;
1287 }
1288
1300 const Array<Array<uint64_t>> & input,
1301 const bool invert = false,
1302 const size_t chunk_size = 0) const
1303 {
1305 ptransform_batch(pool, output, invert, chunk_size);
1306 return output;
1307 }
1308 };
1309
1310 private:
1311 static void
1313 {
1314 while (not poly.is_empty() and poly.get_last() % MOD == 0)
1315 static_cast<void>(poly.remove_last());
1316 }
1317
1318 [[nodiscard]] static Array<uint64_t>
1320 {
1322 output.reserve(input.size());
1323 for (size_t i = 0; i < input.size(); ++i)
1324 output.append(input[i] % MOD);
1326 return output;
1327 }
1328
// Returns an array whose first n entries are set to 0.
// NOTE(review): listing line 1332 (the declaration of `output`) is missing
// from this extract; it presumably creates `output` with n elements, since
// the loop below writes output(i) for every i < n -- confirm.
1329 [[nodiscard]] static Array<uint64_t>
1330 zero_series(const size_t n)
1331 {
1333 for (size_t i = 0; i < n; ++i)
1334 output(i) = 0;
1335 return output;
1336 }
1337
1338 [[nodiscard]] static Array<uint64_t>
1340 const size_t n)
1341 {
1343 const size_t limit = std::min(input.size(), n);
1344 for (size_t i = 0; i < limit; ++i)
1345 output(i) = input[i] % MOD;
1346 return output;
1347 }
1348
1349 [[nodiscard]] static Array<uint64_t>
1351 const size_t n)
1352 {
1354 output.reserve(std::min(input.size(), n));
1355 for (size_t i = 0; i < input.size() and i < n; ++i)
1356 output.append(input[i] % MOD);
1357 return output;
1358 }
1359
1360 [[nodiscard]] static Array<uint64_t>
1362 {
1364 output.reserve(input.size());
1365 for (size_t i = input.size(); i > 0; --i)
1366 output.append(input[i - 1] % MOD);
1367 return output;
1368 }
1369
1370 [[nodiscard]] static Array<uint64_t>
1372 const Array<uint64_t> & rhs,
1373 const size_t n)
1374 {
1376 for (size_t i = 0; i < n; ++i)
1377 {
1378 const uint64_t a = i < lhs.size() ? lhs[i] % MOD : 0;
1379 const uint64_t b = i < rhs.size() ? rhs[i] % MOD : 0;
1380 output(i) = add_mod(a, b);
1381 }
1382 return output;
1383 }
1384
1385 [[nodiscard]] static Array<uint64_t>
1387 const Array<uint64_t> & rhs,
1388 const size_t n)
1389 {
1391 for (size_t i = 0; i < n; ++i)
1392 {
1393 const uint64_t a = i < lhs.size() ? lhs[i] % MOD : 0;
1394 const uint64_t b = i < rhs.size() ? rhs[i] % MOD : 0;
1395 output(i) = sub_mod(a, b);
1396 }
1397 return output;
1398 }
1399
1400 [[nodiscard]] static Array<uint64_t>
1402 const Array<uint64_t> & rhs)
1403 {
1404 const size_t n = std::max(lhs.size(), rhs.size());
1407 return output;
1408 }
1409
1410 [[nodiscard]] static Array<uint64_t>
1412 const Array<uint64_t> & rhs)
1413 {
1414 const size_t n = std::max(lhs.size(), rhs.size());
1417 return output;
1418 }
1419
1420 [[nodiscard]] static Array<uint64_t>
1422 const uint64_t scalar,
1423 const size_t n)
1424 {
1426 const uint64_t factor = scalar % MOD;
1427 for (size_t i = 0; i < input.size() and i < n; ++i)
1428 output(i) = mod_mul(input[i] % MOD, factor, MOD);
1429 return output;
1430 }
1431
1432 [[nodiscard]] static Array<uint64_t>
1434 const Array<uint64_t> & rhs,
1435 const size_t n)
1436 {
1437 if (n == 0)
1438 return {};
1439
1440 const Array<uint64_t> left = truncate_poly(lhs, n);
1441 const Array<uint64_t> right = truncate_poly(rhs, n);
1442 if (left.is_empty() or right.is_empty())
1443 return zero_series(n);
1444
1445 return series_prefix(multiply(left, right), n);
1446 }
1447
1448 [[nodiscard]] static Array<uint64_t>
1450 {
1451 if (coeffs.size() <= 1)
1452 return {};
1453
1455 for (size_t i = 1; i < coeffs.size(); ++i)
1456 output(i - 1) = mod_mul(coeffs[i] % MOD,
1457 static_cast<uint64_t>(i) % MOD,
1458 MOD);
1459 return output;
1460 }
1461
1462 [[nodiscard]] static Array<uint64_t>
1464 {
1466 output(0) = 0;
1467 for (size_t i = 0; i < coeffs.size(); ++i)
1468 {
1469 const uint64_t inv = mod_inv(static_cast<uint64_t>(i + 1), MOD);
1470 output(i + 1) = mod_mul(coeffs[i] % MOD, inv, MOD);
1471 }
1472 return output;
1473 }
1474
1475 [[nodiscard]] static uint64_t
1477 const char * const ctx)
1478 {
1479 const uint64_t a = value % MOD;
1480 if (a == 0)
1481 return 0;
1482
1483 const uint64_t legendre = mod_exp(a, (MOD - 1) / 2, MOD);
1485 << ctx << ": constant term " << a
1486 << " is not a quadratic residue modulo " << MOD;
1487
1488 if (MOD % 4 == 3)
1489 {
1490 const uint64_t root = mod_exp(a, (MOD + 1) / 4, MOD);
1491 return std::min(root, root == 0 ? 0 : MOD - root);
1492 }
1493
1494 uint64_t q = MOD - 1;
1495 size_t s = 0;
1496 while ((q & 1ULL) == 0)
1497 {
1498 q >>= 1;
1499 ++s;
1500 }
1501
1502 uint64_t z = 2;
1503 while (mod_exp(z, (MOD - 1) / 2, MOD) != MOD - 1)
1504 ++z;
1505
1506 uint64_t c = mod_exp(z, q, MOD);
1507 uint64_t r = mod_exp(a, (q + 1) / 2, MOD);
1508 uint64_t t = mod_exp(a, q, MOD);
1509 size_t m = s;
1510
1511 while (t != 1)
1512 {
1513 size_t i = 1;
1514 uint64_t t2i = mod_mul(t, t, MOD);
1515 while (i < m and t2i != 1)
1516 {
1517 t2i = mod_mul(t2i, t2i, MOD);
1518 ++i;
1519 }
1520
1522 << ctx << ": Tonelli-Shanks failed to converge";
1523
1524 uint64_t b = c;
1525 for (size_t j = 0; j + i + 1 < m; ++j)
1526 b = mod_mul(b, b, MOD);
1527
1528 r = mod_mul(r, b, MOD);
1529 const uint64_t bb = mod_mul(b, b, MOD);
1530 t = mod_mul(t, bb, MOD);
1531 c = bb;
1532 m = i;
1533 }
1534
1535 return std::min(r, r == 0 ? 0 : MOD - r);
1536 }
1537
1540 {
1542 const size_t capacity = count == 0 ? 0 : count * 4 + 4;
1543 tree.reserve(capacity);
1544 for (size_t i = 0; i < capacity; ++i)
1545 tree.append(Array<uint64_t>());
1546 return tree;
1547 }
1548
// Recursively fills `tree` with the subproduct tree of the evaluation
// points in the half-open index range [left, right).
// Node k keeps its children at 2k and 2k + 1 (heap-style numbering).
// NOTE(review): listing line 1550 (function name and the `tree` parameter)
// is missing from this extract; the name is visible in the recursive calls.
// Requires left < right: with an empty range the leaf test never fires and
// the recursion would not terminate.
1549 static void
1551 const Array<uint64_t> & points,
1552 const size_t node,
1553 const size_t left,
1554 const size_t right)
1555 {
1556 if (left + 1 == right)
1557 {
// Leaf: the two coefficients {-points[left] mod MOD, 1}. Presumably stored
// low-order first, which makes this the linear factor (x - points[left]).
1558 tree(node) = {
1559 sub_mod(0, points[left] % MOD),
1560 1
1561 };
1562 return;
1563 }
1564
// Split without computing left + right, which could overflow.
1565 const size_t mid = left + (right - left) / 2;
1566 build_product_tree(tree, points, node << 1, left, mid);
1567 build_product_tree(tree, points, (node << 1) | 1, mid, right);
// Internal node: product of the two child polynomials.
1568 tree(node) = multiply(tree[node << 1], tree[(node << 1) | 1]);
1569 }
1570
1571 static void
1573 const char * const ctx)
1574 {
1575 for (size_t i = 0; i < points.size(); ++i)
1576 for (size_t j = i + 1; j < points.size(); ++j)
1577 ah_invalid_argument_if(points[i] % MOD == points[j] % MOD)
1578 << ctx << ": points[" << i << "] and points[" << j
1579 << "] collide modulo " << MOD;
1580 }
1581
// Returns the remainder of `dividend` divided by `divisor`, taken as the
// second element of poly_divmod's result.
// NOTE(review): listing line 1583 (function name and the `dividend`
// parameter) is missing from this extract; the name is taken from the
// poly_mod(poly, tree[...]) calls elsewhere in this file.
// Short-circuit cases, handled without calling poly_divmod:
//  - empty dividend: returns an empty array;
//  - empty divisor: returns the dividend unchanged (no error is raised);
//  - dividend shorter than divisor: returns the dividend unchanged.
1582 [[nodiscard]] static Array<uint64_t>
1584 const Array<uint64_t> & divisor)
1585 {
1586 if (dividend.is_empty())
1587 return {};
1588
1589 if (divisor.is_empty())
1590 return dividend;
1591
1592 if (dividend.size() < divisor.size())
1593 return dividend;
1594
1595 return poly_divmod(dividend, divisor).second;
1596 }
1597
1598 static void
1600 const Array<uint64_t> & poly,
1602 const size_t node,
1603 const size_t left,
1604 const size_t right)
1605 {
1606 if (left + 1 == right)
1607 {
1608 output(left) = poly.is_empty() ? 0 : poly[0] % MOD;
1609 return;
1610 }
1611
1612 const size_t mid = left + (right - left) / 2;
1614 poly.size() < tree[node << 1].size() ?
1615 poly :
1616 poly_mod(poly, tree[node << 1]);
1618 poly.size() < tree[(node << 1) | 1].size() ?
1619 poly :
1620 poly_mod(poly, tree[(node << 1) | 1]);
1621
1623 node << 1, left, mid);
1625 (node << 1) | 1, mid, right);
1626 }
1627
// Recursive combination step of interpolation over the subproduct tree,
// for the half-open index range [left, right) rooted at `node`.
// NOTE(review): listing lines 1629-1630 (function name and the `tree` and
// `scaled_values` parameters) and lines 1639 / 1641 (the declaration heads
// of `left_poly` and `right_poly`) are missing from this extract; the name
// is visible in the recursive calls. What the caller stores in
// `scaled_values` is not visible here.
// Requires left < right, otherwise the recursion would not terminate.
1628 [[nodiscard]] static Array<uint64_t>
1631 const size_t node,
1632 const size_t left,
1633 const size_t right)
1634 {
// Leaf: a constant polynomial holding the (reduced) value for this point.
1635 if (left + 1 == right)
1636 return Array<uint64_t>({scaled_values[left] % MOD});
1637
1638 const size_t mid = left + (right - left) / 2;
1640 interpolate_recursive(tree, scaled_values, node << 1, left, mid);
1642 interpolate_recursive(tree, scaled_values, (node << 1) | 1, mid, right);
1643
// Cross-multiply: each half's result is multiplied by the OTHER half's
// tree polynomial before the two products are added.
1644 return poly_add_normalized(
1645 multiply(left_poly, tree[(node << 1) | 1]),
1646 multiply(right_poly, tree[node << 1]));
1647 }
1648
1649 public:
1650
1655 [[nodiscard]] static constexpr uint64_t
1660
// Reports whether a transform of length n is supported for this modulus.
// n == 0 is never supported.
// NOTE(review): listing lines 1671-1672 (the rest of the return expression)
// are missing from this extract; they presumably combine
// supports_power_of_two_size(n) and supports_bluestein_size(n) -- confirm.
1667 [[nodiscard]] static constexpr bool
1668 supports_size(const size_t n) noexcept
1669 {
1670 return n > 0
1673 }
1674
// Returns the root of unity used for transforms of length n, computed as
// primitive_root_of_order(n); returns 1 for n <= 1.
// NOTE(review): listing line 1682 (function name and the `n` parameter) is
// missing from this extract; the name is taken from the validation context
// string below and from the call primitive_root_of_unity(n << 1).
1681 [[nodiscard]] static constexpr uint64_t
1683 {
// Validation can raise an error, so it is skipped during constant
// evaluation. In that case an unsupported n is NOT diagnosed here.
1684 if (not std::is_constant_evaluated())
1685 validate_supported_size(n, "NTT::primitive_root_of_unity");
1686
1687 if (n <= 1)
1688 return 1;
1689
1690 return primitive_root_of_order(static_cast<uint64_t>(n));
1691 }
1692
1698 static void
1700 const bool invert)
1701 {
1702 Plan(a.size()).transform(a, invert);
1703 }
1704
1711 [[nodiscard]] static Array<uint64_t>
1713 const bool invert = false)
1714 {
1717 return output;
1718 }
1719
// Multiplies two coefficient arrays with a one-shot plan and returns the
// a.size() + b.size() - 1 product coefficients (Plan::multiply truncates
// to that length). Returns an empty array when either input is empty.
// NOTE(review): listing line 1729 (function name and the `a` parameter),
// line 1735 (the head of the overflow check) and line 1742 (the condition
// choosing between `required` and the next power of two) are missing from
// this extract; the name is taken from the "NTT::multiply" context strings.
1728 [[nodiscard]] static Array<uint64_t>
1730 const Array<uint64_t> & b)
1731 {
1732 if (a.is_empty() or b.is_empty())
1733 return {};
1734
// Guards a.size() + b.size() - 1 below against wrapping around size_t.
1736 std::numeric_limits<size_t>::max()
1737 - b.size() + 1)
1738 << "NTT::multiply: product size exceeds size_t capacity";
1739
1740 const size_t required = a.size() + b.size() - 1;
// Transform length: either exactly `required` or the next power of two.
1741 const size_t n =
1743 required :
1744 next_power_of_two(required, "NTT::multiply");
1745 validate_supported_size(n, "NTT::multiply");
1746
// The plan is a temporary: it is built for this call and then discarded.
1747 return Plan(n).multiply(a, b);
1748 }
1749
1755 static void
1757 const Array<uint64_t> & b)
1758 {
1759 a = multiply(a, b);
1760 }
1761
// Multiplies two equal-length, power-of-two-sized coefficient arrays using
// the twisting construction: scale coefficient i by twist^i, run a
// length-n cyclic product through a Plan, then scale coefficient i by
// inv_twist^i. With twist a root of order 2n this is the standard way to
// obtain the product modulo x^n + 1 (the "negacyclic" product named in the
// messages below).
// NOTE(review): several listing lines are missing from this extract: 1777
// (function name and the `a` parameter), 1780 and 1787 (heads of the
// "positive size" and "power of two" checks), 1797 (definition of
// `inv_twist`, presumably mod_inv(twist, MOD)) and 1800-1801 (declarations
// of `lhs` and `rhs`). The name is taken from the message strings.
1776 [[nodiscard]] static Array<uint64_t>
1778 const Array<uint64_t> & b)
1779 {
1781 << "NTT::negacyclic_multiply: inputs must have positive size";
1782 ah_invalid_argument_if(a.size() != b.size())
1783 << "NTT::negacyclic_multiply: lhs size " << a.size()
1784 << " does not match rhs size " << b.size();
1785
1786 const size_t n = a.size();
1788 << "NTT::negacyclic_multiply: size " << n
1789 << " is not a power of two";
// A root of order 2n is needed, so n may be at most half of the largest
// power-of-two order the modulus supports.
1790 ah_invalid_argument_if(n > static_cast<size_t>(max_transform_size() / 2))
1791 << "NTT::negacyclic_multiply: size " << n
1792 << " requires a primitive root of order " << (n << 1)
1793 << ", but the largest supported power-of-two order is "
1794 << max_transform_size();
1795
1796 const uint64_t twist = primitive_root_of_unity(n << 1);
1798 Plan plan(n);
1799
// Pre-twist: coefficient i of both inputs is multiplied by twist^i.
1802 uint64_t power = 1;
1803 for (size_t i = 0; i < n; ++i)
1804 {
1805 lhs(i) = mod_mul(a[i] % MOD, power, MOD);
1806 rhs(i) = mod_mul(b[i] % MOD, power, MOD);
1807 power = mod_mul(power, twist, MOD);
1808 }
1809
// Cyclic product of length n: forward, pointwise multiply, inverse.
1810 plan.transform(lhs, false);
1811 plan.transform(rhs, false);
1812 for (size_t i = 0; i < n; ++i)
1813 lhs(i) = mod_mul(lhs[i], rhs[i], MOD);
1814 plan.transform(lhs, true);
1815
// Post-twist: coefficient i is multiplied by inv_twist^i.
1816 power = 1;
1817 for (size_t i = 0; i < n; ++i)
1818 {
1819 lhs(i) = mod_mul(lhs[i], power, MOD);
1820 power = mod_mul(power, inv_twist, MOD);
1821 }
1822
1823 return lhs;
1824 }
1825
1832 static void
1834 const bool invert)
1835 {
1836 if (batch.is_empty())
1837 return;
1838
1840 }
1841
1851 const bool invert = false)
1852 {
1853 if (batch.is_empty())
1854 return {};
1855
1857 }
1858
1867 static void
1869 Array<uint64_t> & a,
1870 const bool invert,
1871 const size_t chunk_size = 0)
1872 {
1873 Plan(a.size()).ptransform(pool, a, invert, chunk_size);
1874 }
1875
1885 [[nodiscard]] static Array<uint64_t>
1887 const Array<uint64_t> & input,
1888 const bool invert = false,
1889 const size_t chunk_size = 0)
1890 {
1892 ptransform(pool, output, invert, chunk_size);
1893 return output;
1894 }
1895
1906 [[nodiscard]] static Array<uint64_t>
1908 const Array<uint64_t> & a,
1909 const Array<uint64_t> & b,
1910 const size_t chunk_size = 0)
1911 {
1912 if (a.is_empty() or b.is_empty())
1913 return {};
1914
1916 std::numeric_limits<size_t>::max()
1917 - b.size() + 1)
1918 << "NTT::pmultiply: product size exceeds size_t capacity";
1919
1920 const size_t required = a.size() + b.size() - 1;
1921 const size_t n =
1923 required :
1924 next_power_of_two(required, "NTT::pmultiply");
1925 validate_supported_size(n, "NTT::pmultiply");
1926
1927 return Plan(n).pmultiply(pool, a, b, chunk_size);
1928 }
1929
1939 static void
1942 const bool invert,
1943 const size_t chunk_size = 0)
1944 {
1945 if (batch.is_empty())
1946 return;
1947
1948 Plan(batch[0].size()).ptransform_batch(pool, batch, invert, chunk_size);
1949 }
1950
1957 [[nodiscard]] static uint64_t
1959 const uint64_t x)
1960 {
1961 uint64_t value = 0;
1962 const uint64_t x_mod = x % MOD;
1963 for (size_t i = coeffs.size(); i > 0; --i)
1964 {
1965 value = mod_mul(value, x_mod, MOD);
1966 value = add_mod(value, coeffs[i - 1] % MOD);
1967 }
1968 return value;
1969 }
1970
1978 [[nodiscard]] static Array<uint64_t>
1980 const size_t n)
1981 {
1982 if (n == 0)
1983 return {};
1984
1986 << "NTT::poly_inverse: input polynomial must be non-empty";
1987
1988 const uint64_t c0 = coeffs[0] % MOD;
1990 << "NTT::poly_inverse: constant term must be invertible modulo " << MOD;
1991
1993 inverse(0) = mod_inv(c0, MOD);
1994
1995 size_t m = 1;
1996 while (m < n)
1997 {
1998 const size_t m2 = std::min(n, m << 1);
1999 const Array<uint64_t> fg =
2000 poly_mul_trunc(truncate_poly(coeffs, m2), inverse, m2);
2002 correction(0) = sub_mod(2 % MOD, fg[0]);
2003 for (size_t i = 1; i < m2; ++i)
2004 correction(i) = fg[i] == 0 ? 0 : MOD - fg[i];
2005 inverse = poly_mul_trunc(inverse, correction, m2);
2006 m = m2;
2007 }
2008
2009 return series_prefix(inverse, n);
2010 }
2011
2019 [[nodiscard]] static std::pair<Array<uint64_t>, Array<uint64_t>>
2021 const Array<uint64_t> & divisor)
2022 {
2025
2027 << "NTT::poly_divmod: divisor cannot be the zero polynomial";
2028
2029 if (a.is_empty() or a.size() < b.size())
2030 return {{}, a};
2031
2032 const size_t quotient_size = a.size() - b.size() + 1;
2034 reverse_poly(a),
2039
2042 if (remainder.size() >= b.size())
2045 return {quotient, remainder};
2046 }
2047
2055 [[nodiscard]] static Array<uint64_t>
2057 const size_t n)
2058 {
2059 if (n == 0)
2060 return {};
2061
2062 const uint64_t c0 = coeffs.is_empty() ? 0 : coeffs[0] % MOD;
2064 << "NTT::poly_log: constant term must be 1 modulo " << MOD;
2065
2066 if (n == 1)
2067 return zero_series(1);
2068
2069 const Array<uint64_t> derivative =
2070 poly_derivative(truncate_poly(coeffs, n));
2071 const Array<uint64_t> inverse = poly_inverse(coeffs, n - 1);
2073 poly_mul_trunc(derivative, inverse, n - 1)), n);
2074 }
2075
2083 [[nodiscard]] static Array<uint64_t>
2085 const size_t n)
2086 {
2087 if (n == 0)
2088 return {};
2089
2090 const uint64_t c0 = coeffs.is_empty() ? 0 : coeffs[0] % MOD;
2092 << "NTT::poly_exp: constant term must be 0 modulo " << MOD;
2093
2095 result(0) = 1;
2096
2097 size_t m = 1;
2098 while (m < n)
2099 {
2100 const size_t m2 = std::min(n, m << 1);
2101 Array<uint64_t> delta =
2103 poly_log(result, m2), m2);
2104 delta(0) = add_mod(delta[0], 1);
2105 result = poly_mul_trunc(result, delta, m2);
2106 m = m2;
2107 }
2108
2109 return series_prefix(result, n);
2110 }
2111
2120 [[nodiscard]] static Array<uint64_t>
2122 const size_t n)
2123 {
2124 if (n == 0)
2125 return {};
2126
2127 const Array<uint64_t> input = truncate_poly(coeffs, n);
2128 size_t lead = 0;
2129 while (lead < input.size() and input[lead] % MOD == 0)
2130 ++lead;
2131
2132 if (lead == input.size())
2133 return zero_series(n);
2134
2135 ah_invalid_argument_if((lead & 1U) != 0)
2136 << "NTT::poly_sqrt: first non-zero term appears at odd degree "
2137 << lead;
2138
2139 if (lead > 0)
2140 {
2141 const size_t shift = lead / 2;
2142 Array<uint64_t> tail;
2143 tail.reserve(input.size() - lead);
2144 for (size_t i = lead; i < input.size(); ++i)
2145 tail.append(input[i] % MOD);
2146
2147 const Array<uint64_t> rooted = poly_sqrt(tail, n - shift);
2149 for (size_t i = 0; i < rooted.size() and i + shift < n; ++i)
2150 output(i + shift) = rooted[i];
2151 return output;
2152 }
2153
2155 result(0) = tonelli_shanks(input[0], "NTT::poly_sqrt");
2156
2157 const uint64_t inv_two = mod_inv(2, MOD);
2158 size_t m = 1;
2159 while (m < n)
2160 {
2161 const size_t m2 = std::min(n, m << 1);
2164 poly_inverse(result, m2), m2);
2165 result = poly_scalar_mul_series(
2166 poly_add_series(result, quotient, m2), inv_two, m2);
2167 m = m2;
2168 }
2169
2170 return series_prefix(result, n);
2171 }
2172
2180 [[nodiscard]] static Array<uint64_t>
2182 const uint64_t k,
2183 const size_t n)
2184 {
2185 if (n == 0)
2186 return {};
2187
2188 if (k == 0)
2189 {
2191 output(0) = 1;
2192 return output;
2193 }
2194
2195 const Array<uint64_t> input = truncate_poly(coeffs, n);
2196 size_t lead = 0;
2197 while (lead < input.size() and input[lead] % MOD == 0)
2198 ++lead;
2199
2200 if (lead == input.size())
2201 return zero_series(n);
2202
2203 const __uint128_t shift128 =
2204 static_cast<__uint128_t>(lead) * static_cast<__uint128_t>(k);
2205 if (shift128 >= n)
2206 return zero_series(n);
2207
2208 const size_t shift = static_cast<size_t>(shift128);
2209 const size_t target = n - shift;
2210 const uint64_t lead_coeff = input[lead] % MOD;
2212
2214 normalized.reserve(input.size() - lead);
2215 for (size_t i = lead; i < input.size(); ++i)
2216 normalized.append(mod_mul(input[i] % MOD, inv_lead, MOD));
2217
2219 for (size_t i = 0; i < scaled_log.size(); ++i)
2221 static_cast<uint64_t>(k % MOD), MOD);
2222
2225 target);
2226
2228 for (size_t i = 0; i < powered.size() and i + shift < n; ++i)
2229 output(i + shift) = powered[i];
2230 return output;
2231 }
2232
2239 [[nodiscard]] static Array<uint64_t>
2241 const Array<uint64_t> & points)
2242 {
2243 if (points.is_empty())
2244 return {};
2245
2247 for (size_t i = 0; i < points.size(); ++i)
2248 reduced_points(i) = points[i] % MOD;
2249
2252
2254 for (size_t i = 0; i < output.size(); ++i)
2255 output(i) = 0;
2256
2258 1, 0, reduced_points.size());
2259 return output;
2260 }
2261
2271 [[nodiscard]] static Array<uint64_t>
2273 const Array<uint64_t> & values)
2274 {
2275 ah_invalid_argument_if(points.size() != values.size())
2276 << "NTT::interpolate: points size " << points.size()
2277 << " does not match values size " << values.size();
2278
2279 if (points.is_empty())
2280 return {};
2281
2284 for (size_t i = 0; i < points.size(); ++i)
2285 {
2286 reduced_points(i) = points[i] % MOD;
2287 reduced_values(i) = values[i] % MOD;
2288 }
2289
2290 validate_distinct_points(reduced_points, "NTT::interpolate");
2291
2294
2296 const Array<uint64_t> weights =
2298
2300 for (size_t i = 0; i < values.size(); ++i)
2301 {
2303 << "NTT::interpolate: derivative vanished at point index " << i;
2305 mod_inv(weights[i], MOD), MOD);
2306 }
2307
2309 1, 0, reduced_points.size()));
2310 }
2311
/// Multiply two non-negative integers represented as base-`Base` digits.
///
/// Declaration only; the out-of-class definition appears later in this file.
/// As that definition shows, every digit of `a` and `b` must be strictly
/// below `Base` (otherwise std::invalid_argument is raised), zero digits at
/// the end of each array are dropped before multiplying, and carries are
/// propagated from index 0 upwards, i.e. index 0 holds the least
/// significant digit.
///
/// @tparam Base Radix of the digits (defaults to 2^15; must be >= 2).
/// @param a Digits of the left operand.
/// @param b Digits of the right operand.
/// @return Digits of the product. A zero product appears to be returned as
///         a single 0 digit -- TODO confirm against `zero_digits`.
2333 template <uint64_t Base = (1ULL << 15)>
2334 [[nodiscard]] static Array<uint64_t>
2335 bigint_multiply(const Array<uint64_t> & a,
2336 const Array<uint64_t> & b);
2337
/// ThreadPool-based counterpart of bigint_multiply.
///
/// Declaration only; the out-of-class definition appears later in this file.
/// It validates and normalizes the digits exactly like bigint_multiply and
/// forwards `pool` and `chunk_size` to pmultiply, or to NTTExact::pmultiply
/// when the single-prime coefficient bound does not fit below MOD.
///
/// @tparam Base Radix of the digits (defaults to 2^15; must be >= 2).
/// @param pool Thread pool handed to the underlying parallel multiply.
/// @param a Digits of the left operand, each strictly below `Base`.
/// @param b Digits of the right operand, each strictly below `Base`.
/// @param chunk_size Forwarded unchanged to the parallel multiply
///        (presumably 0 selects that routine's own default -- TODO confirm).
/// @return Digits of the product, index 0 being the least significant.
2354 template <uint64_t Base = (1ULL << 15)>
2355 [[nodiscard]] static Array<uint64_t>
2356 pbigint_multiply(ThreadPool & pool,
2357 const Array<uint64_t> & a,
2358 const Array<uint64_t> & b,
2359 const size_t chunk_size = 0);
2360 };
2361
/// Description of one NTT-friendly prime modulus.
///
/// Plain aggregate; every field defaults to 0 so a value-initialized
/// instance is recognizably "unset".
struct NTTPrime
{
  /// The prime modulus itself.
  uint64_t mod = 0;

  /// Generator paired with `mod` (the same value that is passed as the
  /// ROOT template argument of the matching NTT instantiation).
  uint64_t root = 0;

  /// Exponent k of the power of two associated with `mod`. In the table
  /// kept by NTTExact this is the k for which 2^k divides mod - 1
  /// (e.g. 998244353 = 119 * 2^23 + 1 is stored with k = 23).
  uint64_t max_power_of_two = 0;
};
2373
2389 class NTTExact
2390 {
2391 public:
2392 using coeff_type = __uint128_t;
2393
2394 private:
2395 using Prime0NTT = NTT<998244353ULL, 3ULL>;
2396 using Prime1NTT = NTT<469762049ULL, 3ULL>;
2397 using Prime2NTT = NTT<1004535809ULL, 3ULL>;
2398
2399 struct CoefficientStats
2400 {
2401 uint64_t max_value = 0;
2402 size_t non_zero = 0;
2403 coeff_type sum = 0;
2404 };
2405
2406 static constexpr NTTPrime primes_[] = {
2407 {998244353ULL, 3ULL, 23ULL},
2408 {469762049ULL, 3ULL, 26ULL},
2409 {1004535809ULL, 3ULL, 21ULL}
2410 };
2411
2412 [[nodiscard]] static constexpr coeff_type
2413 exact_modulus_product_impl() noexcept
2414 {
2415 return static_cast<coeff_type>(primes_[0].mod)
2416 * static_cast<coeff_type>(primes_[1].mod)
2417 * static_cast<coeff_type>(primes_[2].mod);
2418 }
2419
/// Modular subtraction (lhs - rhs) mod `mod` using unsigned arithmetic only.
///
/// @param lhs Minuend; expected to be already reduced (lhs < mod).
/// @param rhs Subtrahend; expected to be already reduced (rhs < mod).
/// @param mod Modulus.
/// @return The difference in [0, mod) when both operands are reduced.
///
/// No reduction happens here: if rhs - lhs exceeds `mod` the unsigned
/// subtraction wraps around, so callers must reduce the operands first.
[[nodiscard]] static constexpr uint64_t
sub_mod(const uint64_t lhs,
        const uint64_t rhs,
        const uint64_t mod) noexcept
{
  if (lhs < rhs)
    return mod - (rhs - lhs); // borrow one modulus instead of going negative

  return lhs - rhs;
}
2427
2428 [[nodiscard]] static constexpr coeff_type
2429 add_capped(const coeff_type lhs,
2430 const coeff_type rhs,
2431 const coeff_type cap) noexcept
2432 {
2433 return lhs >= cap - rhs ? cap : lhs + rhs;
2434 }
2435
2436 [[nodiscard]] static constexpr coeff_type
2438 const coeff_type rhs,
2439 const coeff_type cap) noexcept
2440 {
2441 if (lhs == 0 or rhs == 0)
2442 return 0;
2443 return lhs > cap / rhs ? cap : lhs * rhs;
2444 }
2445
/// Smallest power of two that is greater than or equal to `n`.
///
/// @param n Requested minimum size.
/// @return 1 for n <= 1; otherwise the smallest power of two >= n, or 0
///         when that power of two is not representable in size_t.
[[nodiscard]] static constexpr size_t
next_power_of_two(const size_t n) noexcept
{
  if (n <= 1)
    return 1;

  // Doubling any value above this threshold would overflow size_t.
  constexpr size_t largest_doublable = std::numeric_limits<size_t>::max() / 2;

  size_t power = 1;
  while (power < n)
    {
      if (power > largest_doublable)
        return 0; // sentinel: no representable power of two is >= n
      power <<= 1;
    }

  return power;
}
2462
2463 template <typename PrimeNTT>
2464 [[nodiscard]] static constexpr bool
2466 {
2467 if (required == 0)
2468 return false;
2469
2470 if (PrimeNTT::supports_size(required))
2471 return true;
2472
2473 const size_t n = next_power_of_two(required);
2474 return n != 0 and PrimeNTT::supports_size(n);
2475 }
2476
2477 [[nodiscard]] static std::string
2479 {
2480 if (value == 0)
2481 return "0";
2482
2483 std::string digits;
2484 while (value > 0)
2485 {
2486 const auto digit = static_cast<unsigned>(value % 10);
2487 digits.push_back(static_cast<char>('0' + digit));
2488 value /= 10;
2489 }
2490
2491 std::reverse(digits.begin(), digits.end());
2492 return digits;
2493 }
2494
2495 [[nodiscard]] static CoefficientStats
2497 {
2498 CoefficientStats stats;
2499 const coeff_type cap = exact_modulus_product();
2500 for (size_t i = 0; i < input.size(); ++i)
2501 {
2502 const uint64_t value = input[i];
2503 if (value == 0)
2504 continue;
2505
2506 ++stats.non_zero;
2507 if (value > stats.max_value)
2508 stats.max_value = value;
2509 stats.sum = add_capped(stats.sum, static_cast<coeff_type>(value), cap);
2510 }
2511
2512 return stats;
2513 }
2514
2515 [[nodiscard]] static coeff_type
2517 const Array<uint64_t> & b)
2518 {
2519 if (a.is_empty() or b.is_empty())
2520 return 0;
2521
2522 const coeff_type cap = exact_modulus_product();
2523 const CoefficientStats lhs = analyze_coefficients(a);
2524 const CoefficientStats rhs = analyze_coefficients(b);
2525
2526 if (lhs.non_zero == 0 or rhs.non_zero == 0)
2527 return 0;
2528
2529 coeff_type bound = cap;
2530
2531 const coeff_type max_product =
2532 mul_capped(static_cast<coeff_type>(lhs.max_value),
2533 static_cast<coeff_type>(rhs.max_value), cap);
2534
2536 mul_capped(static_cast<coeff_type>(std::min(lhs.non_zero, rhs.non_zero)),
2537 max_product, cap);
2538 if (overlap_bound < bound)
2539 bound = overlap_bound;
2540
2542 mul_capped(lhs.sum, static_cast<coeff_type>(rhs.max_value), cap);
2543 if (lhs_sum_bound < bound)
2544 bound = lhs_sum_bound;
2545
2547 mul_capped(rhs.sum, static_cast<coeff_type>(lhs.max_value), cap);
2548 if (rhs_sum_bound < bound)
2549 bound = rhs_sum_bound;
2550
2551 return bound;
2552 }
2553
2554 static void
2556 const Array<uint64_t> & b,
2557 const char * const ctx)
2558 {
2560 std::numeric_limits<size_t>::max()
2561 - b.size() + 1)
2562 << ctx << ": product size exceeds size_t capacity";
2563
2564 const size_t required = a.size() + b.size() - 1;
2565 ah_invalid_argument_if(not supports_product_size(required))
2566 << ctx << ": required product size " << required
2567 << " is not supported by the three-prime CRT pack";
2568
2569 const coeff_type bound = conservative_bound(a, b);
2570 ah_invalid_argument_if(bound >= exact_modulus_product())
2571 << ctx << ": cannot guarantee exact reconstruction inside CRT range "
2572 << coeff_to_string(exact_modulus_product())
2573 << " with conservative coefficient bound "
2574 << coeff_to_string(bound);
2575 }
2576
2577 [[nodiscard]] static coeff_type
2579 const uint64_t r1,
2580 const uint64_t r2)
2581 {
2582 static const uint64_t m0 = primes_[0].mod;
2583 static const uint64_t m1 = primes_[1].mod;
2584 static const uint64_t m2 = primes_[2].mod;
2585 static const coeff_type m0m1 =
2586 static_cast<coeff_type>(m0) * static_cast<coeff_type>(m1);
2587 static const uint64_t m0_inv_mod_m1 = mod_inv(m0 % m1, m1);
2588 static const uint64_t m0m1_inv_mod_m2 =
2589 mod_inv(static_cast<uint64_t>(m0m1 % m2), m2);
2590
2591 const uint64_t t1 =
2592 mod_mul(sub_mod(r1, r0 % m1, m1), m0_inv_mod_m1, m1);
2593 const coeff_type x01 =
2594 static_cast<coeff_type>(r0)
2595 + static_cast<coeff_type>(t1) * static_cast<coeff_type>(m0);
2596
2597 const uint64_t x01_mod_m2 = static_cast<uint64_t>(x01 % m2);
2598 const uint64_t t2 =
2599 mod_mul(sub_mod(r2, x01_mod_m2, m2), m0m1_inv_mod_m2, m2);
2600
2601 return x01 + static_cast<coeff_type>(t2) * m0m1;
2602 }
2603
2604 [[nodiscard]] static Array<coeff_type>
2606 const Array<uint64_t> & c1,
2607 const Array<uint64_t> & c2,
2608 ThreadPool * const pool,
2609 const size_t chunk_size)
2610 {
2611 ah_runtime_error_unless(c0.size() == c1.size() and c1.size() == c2.size())
2612 << "NTTExact::reconstruct_product: inconsistent residue sizes";
2613
2615
2616 auto reconstruct_one = [&output, &c0, &c1, &c2](const size_t i)
2617 {
2618 output(i) = reconstruct_coefficient(c0[i], c1[i], c2[i]);
2619 };
2620
2621 if (pool != nullptr and pool->num_threads() > 1 and c0.size() > 1)
2622 parallel_for_index(*pool, 0, c0.size(), reconstruct_one, chunk_size);
2623 else
2624 for (size_t i = 0; i < c0.size(); ++i)
2625 reconstruct_one(i);
2626
2627 return output;
2628 }
2629
2630 public:
2632 [[nodiscard]] static constexpr size_t
2634 {
2635 return sizeof(primes_) / sizeof(primes_[0]);
2636 }
2637
2643 [[nodiscard]] static constexpr coeff_type
2645 {
2646 return exact_modulus_product_impl();
2647 }
2648
2655 [[nodiscard]] static constexpr bool
2662
2672 [[nodiscard]] static Array<coeff_type>
2674 const Array<uint64_t> & b)
2675 {
2676 if (a.is_empty() or b.is_empty())
2677 return {};
2678
2679 validate_inputs(a, b, "NTTExact::multiply");
2680
2681 const Array<uint64_t> c0 = Prime0NTT::multiply(a, b);
2682 const Array<uint64_t> c1 = Prime1NTT::multiply(a, b);
2683 const Array<uint64_t> c2 = Prime2NTT::multiply(a, b);
2684 return reconstruct_product(c0, c1, c2, nullptr, 0);
2685 }
2686
2703 [[nodiscard]] static Array<coeff_type>
2705 const Array<uint64_t> & a,
2706 const Array<uint64_t> & b,
2707 const size_t chunk_size = 0)
2708 {
2709 if (a.is_empty() or b.is_empty())
2710 return {};
2711
2712 validate_inputs(a, b, "NTTExact::pmultiply");
2713
2714 if (pool.num_threads() <= 1)
2715 return multiply(a, b);
2716
2717 auto f0 = pool.enqueue([&a, &b]()
2718 {
2719 return Prime0NTT::multiply(a, b);
2720 });
2721 auto f1 = pool.enqueue([&a, &b]()
2722 {
2723 return Prime1NTT::multiply(a, b);
2724 });
2725 auto f2 = pool.enqueue([&a, &b]()
2726 {
2727 return Prime2NTT::multiply(a, b);
2728 });
2729
2730 const Array<uint64_t> c0 = f0.get();
2731 const Array<uint64_t> c1 = f1.get();
2732 const Array<uint64_t> c2 = f2.get();
2733 return reconstruct_product(c0, c1, c2, &pool, chunk_size);
2734 }
2735 };
2736
2737 template <uint64_t MOD, uint64_t ROOT>
2738 template <uint64_t Base>
2741 const Array<uint64_t> & b)
2742 {
2743 static_assert(Base > 1, "NTT::bigint_multiply requires Base >= 2");
2745
2746 auto zero_digits = []()
2747 {
2749 output(0) = 0;
2750 return output;
2751 };
2752
2753 auto validate_digits = [](const Array<uint64_t> & digits,
2754 const char * const name,
2755 const char * const ctx)
2756 {
2757 for (size_t i = 0; i < digits.size(); ++i)
2758 ah_invalid_argument_if(digits[i] >= Base)
2759 << ctx << ": " << name << "[" << i << "] = " << digits[i]
2760 << " is not in [0, " << Base << ")";
2761 };
2762
2763 auto normalize_digits = [](const Array<uint64_t> & input)
2764 {
2766 output.reserve(input.size());
2767 for (size_t i = 0; i < input.size(); ++i)
2768 output.append(input[i]);
2769
2770 while (not output.is_empty() and output.get_last() == 0)
2771 static_cast<void>(output.remove_last());
2772 return output;
2773 };
2774
2775 auto propagate_carries = [&zero_digits](const Array<ExactCoeff> & coeffs)
2776 {
2777 if (coeffs.is_empty())
2778 return zero_digits();
2779
2781 output.reserve(coeffs.size() + 2);
2782
2783 ExactCoeff carry = 0;
2784 for (size_t i = 0; i < coeffs.size(); ++i)
2785 {
2786 const ExactCoeff total = coeffs[i] + carry;
2787 output.append(static_cast<uint64_t>(total % Base));
2788 carry = total / Base;
2789 }
2790
2791 while (carry > 0)
2792 {
2793 output.append(static_cast<uint64_t>(carry % Base));
2794 carry /= Base;
2795 }
2796
2797 while (output.size() > 1 and output.get_last() == 0)
2798 static_cast<void>(output.remove_last());
2799
2800 return output.is_empty() ? zero_digits() : output;
2801 };
2802
2803 validate_digits(a, "a", "NTT::bigint_multiply");
2804 validate_digits(b, "b", "NTT::bigint_multiply");
2805
2808 if (lhs.is_empty() or rhs.is_empty())
2809 return zero_digits();
2810
2812 std::numeric_limits<size_t>::max()
2813 - rhs.size() + 1)
2814 << "NTT::bigint_multiply: product size exceeds size_t capacity";
2815
2816 const size_t required = lhs.size() + rhs.size() - 1;
2817 const size_t internal_size =
2818 supports_size(required) ?
2819 required :
2820 next_power_of_two(required, "NTT::bigint_multiply");
2821
2822 const ExactCoeff digit_max = static_cast<ExactCoeff>(Base - 1);
2824 static_cast<ExactCoeff>(std::min(lhs.size(), rhs.size()))
2825 * digit_max * digit_max;
2826
2827 if (single_prime_bound < static_cast<ExactCoeff>(MOD)
2828 and supports_size(internal_size))
2829 {
2830 const Array<uint64_t> coeffs = multiply(lhs, rhs);
2832 for (size_t i = 0; i < coeffs.size(); ++i)
2833 exact(i) = static_cast<ExactCoeff>(coeffs[i]);
2834 return propagate_carries(exact);
2835 }
2836
2838 }
2839
2840 template <uint64_t MOD, uint64_t ROOT>
2841 template <uint64_t Base>
2844 const Array<uint64_t> & a,
2845 const Array<uint64_t> & b,
2846 const size_t chunk_size)
2847 {
2848 static_assert(Base > 1, "NTT::pbigint_multiply requires Base >= 2");
2850
2851 auto zero_digits = []()
2852 {
2854 output(0) = 0;
2855 return output;
2856 };
2857
2858 auto validate_digits = [](const Array<uint64_t> & digits,
2859 const char * const name,
2860 const char * const ctx)
2861 {
2862 for (size_t i = 0; i < digits.size(); ++i)
2863 ah_invalid_argument_if(digits[i] >= Base)
2864 << ctx << ": " << name << "[" << i << "] = " << digits[i]
2865 << " is not in [0, " << Base << ")";
2866 };
2867
2868 auto normalize_digits = [](const Array<uint64_t> & input)
2869 {
2871 output.reserve(input.size());
2872 for (size_t i = 0; i < input.size(); ++i)
2873 output.append(input[i]);
2874
2875 while (not output.is_empty() and output.get_last() == 0)
2876 static_cast<void>(output.remove_last());
2877 return output;
2878 };
2879
2880 auto propagate_carries = [&zero_digits](const Array<ExactCoeff> & coeffs)
2881 {
2882 if (coeffs.is_empty())
2883 return zero_digits();
2884
2886 output.reserve(coeffs.size() + 2);
2887
2888 ExactCoeff carry = 0;
2889 for (size_t i = 0; i < coeffs.size(); ++i)
2890 {
2891 const ExactCoeff total = coeffs[i] + carry;
2892 output.append(static_cast<uint64_t>(total % Base));
2893 carry = total / Base;
2894 }
2895
2896 while (carry > 0)
2897 {
2898 output.append(static_cast<uint64_t>(carry % Base));
2899 carry /= Base;
2900 }
2901
2902 while (output.size() > 1 and output.get_last() == 0)
2903 static_cast<void>(output.remove_last());
2904
2905 return output.is_empty() ? zero_digits() : output;
2906 };
2907
2908 validate_digits(a, "a", "NTT::pbigint_multiply");
2909 validate_digits(b, "b", "NTT::pbigint_multiply");
2910
2913 if (lhs.is_empty() or rhs.is_empty())
2914 return zero_digits();
2915
2917 std::numeric_limits<size_t>::max()
2918 - rhs.size() + 1)
2919 << "NTT::pbigint_multiply: product size exceeds size_t capacity";
2920
2921 const size_t required = lhs.size() + rhs.size() - 1;
2922 const size_t internal_size =
2923 supports_size(required) ?
2924 required :
2925 next_power_of_two(required, "NTT::pbigint_multiply");
2926
2927 const ExactCoeff digit_max = static_cast<ExactCoeff>(Base - 1);
2929 static_cast<ExactCoeff>(std::min(lhs.size(), rhs.size()))
2930 * digit_max * digit_max;
2931
2932 if (single_prime_bound < static_cast<ExactCoeff>(MOD)
2933 and supports_size(internal_size))
2934 {
2935 const Array<uint64_t> coeffs = pmultiply(pool, lhs, rhs, chunk_size);
2937 for (size_t i = 0; i < coeffs.size(); ++i)
2938 exact(i) = static_cast<ExactCoeff>(coeffs[i]);
2939 return propagate_carries(exact);
2940 }
2941
2942 return propagate_carries(NTTExact::pmultiply(pool, lhs, rhs, chunk_size));
2943 }
2944} // namespace Aleph
2945
2946# endif // NTT_H
Exception handling system with formatted messages for Aleph-w.
#define ah_runtime_error_unless(C)
Throws std::runtime_error if condition does NOT hold.
Definition ah-errors.H:250
#define ah_overflow_error_if(C)
Throws std::overflow_error if condition holds.
Definition ah-errors.H:463
#define ah_runtime_error_if(C)
Throws std::runtime_error if condition holds.
Definition ah-errors.H:266
#define ah_invalid_argument_if(C)
Throws std::invalid_argument if condition holds.
Definition ah-errors.H:639
Simple dynamic array with automatic resizing and functional operations.
Definition tpl_array.H:139
static Array create(size_t n)
Create an array with n logical elements.
Definition tpl_array.H:194
constexpr size_t size() const noexcept
Return the number of elements stored in the array.
Definition tpl_array.H:351
constexpr bool is_empty() const noexcept
Checks if the container is empty.
Definition tpl_array.H:348
T & append(const T &data)
Append a copy of data
Definition tpl_array.H:245
T & get_last() noexcept
Return a modifiable reference to the last element.
Definition tpl_array.H:366
void reserve(size_t cap)
Reserves cap cells into the array.
Definition tpl_array.H:315
static coeff_type conservative_bound(const Array< uint64_t > &a, const Array< uint64_t > &b)
Definition ntt.H:2516
static constexpr coeff_type mul_capped(const coeff_type lhs, const coeff_type rhs, const coeff_type cap) noexcept
Definition ntt.H:2437
static Array< coeff_type > reconstruct_product(const Array< uint64_t > &c0, const Array< uint64_t > &c1, const Array< uint64_t > &c2, ThreadPool *const pool, const size_t chunk_size)
Definition ntt.H:2605
static constexpr coeff_type exact_modulus_product() noexcept
Product of the three CRT moduli.
Definition ntt.H:2644
static CoefficientStats analyze_coefficients(const Array< uint64_t > &input)
Definition ntt.H:2496
static coeff_type reconstruct_coefficient(const uint64_t r0, const uint64_t r1, const uint64_t r2)
Definition ntt.H:2578
__uint128_t coeff_type
Definition ntt.H:2392
static Array< coeff_type > pmultiply(ThreadPool &pool, const Array< uint64_t > &a, const Array< uint64_t > &b, const size_t chunk_size=0)
Exact parallel polynomial multiplication.
Definition ntt.H:2704
static constexpr bool prime_supports_product_size(const size_t required) noexcept
Definition ntt.H:2465
static constexpr bool supports_product_size(const size_t required) noexcept
Check whether a target product length is supported.
Definition ntt.H:2656
static constexpr size_t next_power_of_two(const size_t n) noexcept
Definition ntt.H:2447
static std::string coeff_to_string(coeff_type value)
Definition ntt.H:2478
static void validate_inputs(const Array< uint64_t > &a, const Array< uint64_t > &b, const char *const ctx)
Definition ntt.H:2555
static Array< coeff_type > multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Exact sequential polynomial multiplication.
Definition ntt.H:2673
static constexpr size_t prime_count() noexcept
Number of CRT primes in the exact multiplier.
Definition ntt.H:2633
Precomputed plans for NTT transforms.
Definition ntt.H:466
Array< Array< uint64_t > > transformed_batch(const Array< Array< uint64_t > > &input, const bool invert=false) const
Return a transformed copy of an entire batch.
Definition ntt.H:1281
Array< uint64_t > multiply(const Array< uint64_t > &a, const Array< uint64_t > &b) const
Multiply two polynomials using this plan size.
Definition ntt.H:1141
void ptransform(ThreadPool &pool, Array< uint64_t > &a, const bool invert, const size_t chunk_size=0) const
Parallel in-place transform using a ThreadPool.
Definition ntt.H:1156
Array< uint64_t > multiply_impl(const Array< uint64_t > &a, const Array< uint64_t > &b, ThreadPool *const pool, const size_t chunk_size) const
Definition ntt.H:1008
size_t size() const noexcept
Return the transform size bound to the plan.
Definition ntt.H:1097
void initialize_twiddles()
Definition ntt.H:711
Array< uint64_t > bluestein_kernel_forward_
Definition ntt.H:484
void apply_bit_reversal(Array< uint64_t > &a) const noexcept
Definition ntt.H:815
Array< uint64_t > transformed(const Array< uint64_t > &input, const bool invert=false) const
Transform an input array and return the result.
Definition ntt.H:1124
std::shared_ptr< const Plan > bluestein_plan_
Definition ntt.H:486
uint64_t inv_n_std_
Definition ntt.H:480
Array< uint64_t > pmultiply(ThreadPool &pool, const Array< uint64_t > &a, const Array< uint64_t > &b, const size_t chunk_size=0) const
Parallel polynomial multiplication using this plan size.
Definition ntt.H:1198
Strategy strategy_
Definition ntt.H:473
Plan(const size_t n)
Construct a reusable plan for a fixed transform size.
Definition ntt.H:1085
Array< uint64_t > bluestein_chirp_inv_
Definition ntt.H:483
void transform(Array< uint64_t > &a, const bool invert) const
In-place forward or inverse transform.
Definition ntt.H:1108
void ptransform_batch(ThreadPool &pool, Array< Array< uint64_t > > &batch, const bool invert, const size_t chunk_size=0) const
Parallel batch transform for equal-sized inputs.
Definition ntt.H:1240
void lift_input(Array< uint64_t > &a, ThreadPool *const pool, const size_t chunk_size) const
Definition ntt.H:823
Array< size_t > bit_rev_
Definition ntt.H:476
void initialize_power_of_two_plan()
Definition ntt.H:745
Array< uint64_t > twiddles_inv_
Definition ntt.H:478
static void for_each_index(ThreadPool *const pool, const size_t count, F &&fn, const size_t chunk_size)
Definition ntt.H:490
Array< Array< uint64_t > > ptransformed_batch(ThreadPool &pool, const Array< Array< uint64_t > > &input, const bool invert=false, const size_t chunk_size=0) const
Return a parallel-transformed copy of an entire batch.
Definition ntt.H:1299
Array< uint64_t > bluestein_chirp_fwd_
Definition ntt.H:482
size_t n_
Definition ntt.H:474
void apply_transform(Array< uint64_t > &a, const bool invert, const Representation input_repr, const Representation output_repr, ThreadPool *const pool, const size_t chunk_size) const
Definition ntt.H:965
bool should_use_avx2(ThreadPool *const pool) const noexcept
Definition ntt.H:529
void apply_butterflies_scalar(Array< uint64_t > &a, const Array< uint64_t > &twiddles, ThreadPool *const pool, const size_t chunk_size) const
Definition ntt.H:859
void apply_butterflies(Array< uint64_t > &a, const bool invert, ThreadPool *const pool, const size_t chunk_size) const
Definition ntt.H:884
void apply_bluestein_transform(Array< uint64_t > &a, const bool invert, ThreadPool *const pool, const size_t chunk_size) const
Definition ntt.H:912
bool should_use_neon(ThreadPool *const pool) const noexcept
Definition ntt.H:546
void initialize_bit_reversal()
Definition ntt.H:696
size_t log_n_
Definition ntt.H:475
void scale_inverse(Array< uint64_t > &a, ThreadPool *const pool, const size_t chunk_size) const
Definition ntt.H:835
void apply_scalar_butterfly_range(Array< uint64_t > &a, const Array< uint64_t > &twiddles, const size_t base, const size_t half, const size_t offset, const size_t begin, const size_t end) const
Definition ntt.H:510
void transform_batch(Array< Array< uint64_t > > &batch, const bool invert) const
Sequential batch transform for equal-sized inputs.
Definition ntt.H:1212
Array< uint64_t > twiddles_fwd_
Definition ntt.H:477
void lower_output(Array< uint64_t > &a, ThreadPool *const pool, const size_t chunk_size) const
Definition ntt.H:847
size_t bluestein_size_
Definition ntt.H:481
void initialize_bluestein_plan()
Definition ntt.H:758
Array< uint64_t > bluestein_kernel_inverse_
Definition ntt.H:485
Array< uint64_t > ptransformed(ThreadPool &pool, const Array< uint64_t > &input, const bool invert=false, const size_t chunk_size=0) const
Parallel transform returning a new array.
Definition ntt.H:1177
uint64_t inv_n_
Definition ntt.H:479
Number Theoretic Transform over Z / MOD Z.
Definition ntt.H:115
static Array< uint64_t > interpolate_recursive(const Array< Array< uint64_t > > &tree, const Array< uint64_t > &scaled_values, const size_t node, const size_t left, const size_t right)
Definition ntt.H:1629
static Array< uint64_t > poly_sqrt(const Array< uint64_t > &coeffs, const size_t n)
Formal polynomial square root modulo x^n.
Definition ntt.H:2121
static uint64_t tonelli_shanks(const uint64_t value, const char *const ctx)
Definition ntt.H:1476
static bool avx2_dispatch_available() noexcept
Returns whether AVX2 dispatch is available at runtime.
Definition ntt.H:397
static constexpr const char * simd_preference_name(const SimdPreference preference) noexcept
Definition ntt.H:350
static void validate_distinct_points(const Array< uint64_t > &points, const char *const ctx)
Definition ntt.H:1572
static constexpr bool supports_bluestein_size(const size_t n) noexcept
Definition ntt.H:222
static Array< uint64_t > bigint_multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Multiply two non-negative integers represented as base-Base digits.
Definition ntt.H:2740
static Array< uint64_t > poly_exp(const Array< uint64_t > &coeffs, const size_t n)
Formal polynomial exponential modulo x^n.
Definition ntt.H:2084
static Array< Array< uint64_t > > transformed_batch(const Array< Array< uint64_t > > &batch, const bool invert=false)
Return a transformed copy of an entire batch.
Definition ntt.H:1850
static constexpr uint64_t add_mod(const uint64_t lhs, const uint64_t rhs) noexcept
Definition ntt.H:153
static void transform_batch(Array< Array< uint64_t > > &batch, const bool invert)
Sequential batch transform for equal-sized inputs.
Definition ntt.H:1833
static Array< uint64_t > poly_sub_series(const Array< uint64_t > &lhs, const Array< uint64_t > &rhs, const size_t n)
Definition ntt.H:1386
static Array< uint64_t > negacyclic_multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Negacyclic convolution modulo x^N + 1.
Definition ntt.H:1777
static Array< uint64_t > poly_integral(const Array< uint64_t > &coeffs)
Definition ntt.H:1463
static constexpr bool supports_size(const size_t n) noexcept
Check whether a transform size is supported.
Definition ntt.H:1668
static constexpr uint64_t sub_mod(const uint64_t lhs, const uint64_t rhs) noexcept
Definition ntt.H:161
static void multipoint_eval_recursive(const Array< Array< uint64_t > > &tree, const Array< uint64_t > &poly, Array< uint64_t > &output, const size_t node, const size_t left, const size_t right)
Definition ntt.H:1599
static Array< Array< uint64_t > > make_product_tree_storage(const size_t count)
Definition ntt.H:1539
NTTSimdBackend
SIMD backends available to the NTT butterfly core.
Definition ntt.H:123
@ avx2
x86-64 AVX2 grouped butterfly path.
@ scalar
Portable scalar implementation.
@ neon
AArch64 NEON grouped butterfly path.
static Array< uint64_t > poly_power(const Array< uint64_t > &coeffs, const uint64_t k, const size_t n)
Formal polynomial power modulo x^n.
Definition ntt.H:2181
static constexpr bool supports_root_order(const uint64_t order) noexcept
Definition ntt.H:216
static size_t next_power_of_two(size_t n, const char *const ctx)
Definition ntt.H:282
static Array< uint64_t > multipoint_eval(const Array< uint64_t > &coeffs, const Array< uint64_t > &points)
Evaluate a polynomial on multiple points modulo MOD.
Definition ntt.H:2240
static Array< uint64_t > poly_inverse(const Array< uint64_t > &coeffs, const size_t n)
Formal polynomial inverse modulo x^n.
Definition ntt.H:1979
static Array< uint64_t > poly_sub_normalized(const Array< uint64_t > &lhs, const Array< uint64_t > &rhs)
Definition ntt.H:1411
static Array< uint64_t > multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Multiply two polynomials modulo MOD.
Definition ntt.H:1729
static Array< uint64_t > poly_add_series(const Array< uint64_t > &lhs, const Array< uint64_t > &rhs, const size_t n)
Definition ntt.H:1371
static Array< uint64_t > reverse_poly(const Array< uint64_t > &input)
Definition ntt.H:1361
static Array< uint64_t > poly_derivative(const Array< uint64_t > &coeffs)
Definition ntt.H:1449
static void validate_root_order(const uint64_t order, const char *const ctx)
Definition ntt.H:250
static Array< uint64_t > pmultiply(ThreadPool &pool, const Array< uint64_t > &a, const Array< uint64_t > &b, const size_t chunk_size=0)
Parallel polynomial multiplication modulo MOD.
Definition ntt.H:1907
static Array< uint64_t > truncate_poly(const Array< uint64_t > &input, const size_t n)
Definition ntt.H:1350
static void ptransform(ThreadPool &pool, Array< uint64_t > &a, const bool invert, const size_t chunk_size=0)
Parallel in-place transform using a ThreadPool.
Definition ntt.H:1868
static Array< uint64_t > prefix_copy(const Array< uint64_t > &input, const size_t length)
Definition ntt.H:315
static constexpr uint64_t max_transform_size() noexcept
Maximum supported power-of-two transform size.
Definition ntt.H:1656
static constexpr const char * simd_backend_name(const NTTSimdBackend backend) noexcept
Definition ntt.H:334
static Array< uint64_t > interpolate(const Array< uint64_t > &points, const Array< uint64_t > &values)
Interpolate a polynomial from point-value samples modulo MOD.
Definition ntt.H:2272
static void multiply_inplace(Array< uint64_t > &a, const Array< uint64_t > &b)
Replace a by the product a * b.
Definition ntt.H:1756
static constexpr uint64_t mod
Definition ntt.H:330
static Array< uint64_t > ptransformed(ThreadPool &pool, const Array< uint64_t > &input, const bool invert=false, const size_t chunk_size=0)
Parallel transform returning a new array.
Definition ntt.H:1886
static SimdPreference simd_preference() noexcept
Definition ntt.H:367
static uint64_t poly_eval(const Array< uint64_t > &coeffs, const uint64_t x)
Evaluate a polynomial at a single point modulo MOD.
Definition ntt.H:1958
static Array< uint64_t > transformed(const Array< uint64_t > &input, const bool invert=false)
Transform an input array and return the result.
Definition ntt.H:1712
static Array< uint64_t > series_prefix(const Array< uint64_t > &input, const size_t n)
Definition ntt.H:1339
static void ptransform_batch(ThreadPool &pool, Array< Array< uint64_t > > &batch, const bool invert, const size_t chunk_size=0)
Parallel batch transform for equal-sized inputs.
Definition ntt.H:1940
static NTTSimdBackend detected_simd_backend() noexcept
Definition ntt.H:385
static constexpr bool supports_power_of_two_size(const size_t n) noexcept
Definition ntt.H:209
static constexpr MontgomeryCtx mctx_
Definition ntt.H:144
static constexpr bool simd_mod_supported() noexcept
Definition ntt.H:168
static constexpr uint64_t pow_mod_constexpr(uint64_t base, uint64_t exp) noexcept
Definition ntt.H:174
static Array< uint64_t > poly_mul_trunc(const Array< uint64_t > &lhs, const Array< uint64_t > &rhs, const size_t n)
Definition ntt.H:1433
static constexpr uint64_t root
Definition ntt.H:331
static bool neon_dispatch_available() noexcept
Returns whether AArch64 NEON dispatch is available at runtime.
Definition ntt.H:413
static Array< uint64_t > poly_add_normalized(const Array< uint64_t > &lhs, const Array< uint64_t > &rhs)
Definition ntt.H:1401
static std::pair< Array< uint64_t >, Array< uint64_t > > poly_divmod(const Array< uint64_t > &dividend, const Array< uint64_t > &divisor)
Polynomial division with remainder modulo MOD.
Definition ntt.H:2020
static void transform(Array< uint64_t > &a, const bool invert)
In-place forward or inverse transform.
Definition ntt.H:1699
static Array< uint64_t > poly_mod(const Array< uint64_t > &dividend, const Array< uint64_t > &divisor)
Definition ntt.H:1583
static constexpr uint64_t primitive_root_of_unity(const size_t n)
Return an n-th primitive root of unity modulo MOD.
Definition ntt.H:1682
static Array< uint64_t > normalize_poly(const Array< uint64_t > &input)
Definition ntt.H:1319
static NTTSimdBackend simd_backend() noexcept
Returns the SIMD backend selected under ALEPH_NTT_SIMD.
Definition ntt.H:432
static void validate_supported_size(const size_t n, const char *const ctx)
Definition ntt.H:267
static Array< uint64_t > poly_scalar_mul_series(const Array< uint64_t > &input, const uint64_t scalar, const size_t n)
Definition ntt.H:1421
static Array< uint64_t > pbigint_multiply(ThreadPool &pool, const Array< uint64_t > &a, const Array< uint64_t > &b, const size_t chunk_size=0)
Parallel big-integer multiplication in base Base.
Definition ntt.H:2843
static void trim_trailing_zeros(Array< uint64_t > &poly)
Definition ntt.H:1312
SimdPreference
Definition ntt.H:131
static constexpr bool is_power_of_two(const size_t n) noexcept
Definition ntt.H:147
static Array< uint64_t > padded_copy(const Array< uint64_t > &input, const size_t n)
Definition ntt.H:301
Representation
Definition ntt.H:139
static Array< uint64_t > zero_series(const size_t n)
Definition ntt.H:1330
static constexpr uint64_t primitive_root_of_order(const uint64_t order)
Definition ntt.H:261
static const char * simd_backend_name() noexcept
Returns the active SIMD backend name.
Definition ntt.H:455
static void build_product_tree(Array< Array< uint64_t > > &tree, const Array< uint64_t > &points, const size_t node, const size_t left, const size_t right)
Definition ntt.H:1550
static Array< uint64_t > poly_log(const Array< uint64_t > &coeffs, const size_t n)
Formal polynomial logarithm modulo x^n.
Definition ntt.H:2056
static constexpr uint64_t max_transform_size_impl() noexcept
Definition ntt.H:196
A reusable thread pool for efficient parallel task execution.
size_t num_threads() const noexcept
Get the number of worker threads.
auto enqueue(F &&f, Args &&... args) -> std::future< std::invoke_result_t< F, Args... > >
Submit a task for execution and get a future for the result.
__gmp_expr< T, __gmp_unary_expr< __gmp_expr< T, U >, __gmp_exp_function > > exp(const __gmp_expr< T, U > &expr)
Definition gmpfrxx.h:4066
__gmp_expr< typename __gmp_resolve_expr< T, V >::value_type, __gmp_binary_expr< __gmp_expr< T, U >, __gmp_expr< V, W >, __gmp_remainder_function > > remainder(const __gmp_expr< T, U > &expr1, const __gmp_expr< V, W > &expr2)
Definition gmpfrxx.h:4115
const long double offset[]
Offset values indexed by symbol string length (bounded by MAX_OFFSET_INDEX)
Safe modular arithmetic, extended Euclidean algorithm, and Chinese Remainder Theorem.
Main namespace for Aleph-w library functions.
Definition ah-arena.H:89
and
Check uniqueness with explicit hash + equality functors.
uint64_t mod_inv(const uint64_t a, const uint64_t m)
Modular Inverse.
bool eq(const C1 &c1, const C2 &c2, Eq e=Eq())
Check equality of two containers using a predicate.
size_t size(Node *root) noexcept
static long & low(typename GT::Node *p)
Internal helper: low-link value stored directly in NODE_COOKIE(p).
void parallel_for_index(ThreadPool &pool, size_t start, size_t end, F &&f, size_t chunk_size=0)
Apply a function to each element in parallel (index-based).
Divide_Conquer_DP_Result< Cost > divide_and_conquer_partition_dp(const size_t groups, const size_t n, Transition_Cost_Fn transition_cost, const Cost inf=dp_optimization_detail::default_inf< Cost >())
Optimize partition DP using divide-and-conquer optimization.
uint64_t mod_exp(uint64_t base, uint64_t exp, const uint64_t m)
Modular exponentiation.
bool diff(const C1 &c1, const C2 &c2, Eq e=Eq())
Check if two containers differ.
uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t m)
Safe 64-bit modular multiplication.
auto mode(const Container &data) -> std::decay_t< decltype(*std::begin(data))>
Compute the mode (most frequent value).
Definition stat_utils.H:456
Itor::difference_type count(const Itor &beg, const Itor &end, const T &value)
Count elements equal to a value.
Definition ahAlgo.H:127
T sum(const Container &container, const T &init=T{})
Compute sum of all elements.
#define MOD(p)
Definition ntreepic.C:346
@ ROOT
Definition ntreepic.C:221
FooMap m(5, fst_unit_pair_hash, snd_unit_pair_hash)
static int * k
gsl_rng * r
A modern, efficient thread pool for parallel task execution.
Dynamic array container with automatic resizing.
ofstream output
Definition writeHeap.C:215