Aleph-w 3.0
A C++ Library for Data Structures and Algorithms
Loading...
Searching...
No Matches
ntt_test.cc
Go to the documentation of this file.
1/*
2 Aleph_w
3
4 Data structures & Algorithms
5 version 2.0.0b
6 https://github.com/lrleon/Aleph-w
7
8 This file is part of Aleph-w library
9
10 Copyright (c) 2002-2026 Leandro Rabindranath Leon
11
12 Permission is hereby granted, free of charge, to any person obtaining a copy
13 of this software and associated documentation files (the "Software"), to deal
14 in the Software without restriction, including without limitation the rights
15 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 copies of the Software, and to permit persons to whom the Software is
17 furnished to do so, subject to the following conditions:
18
19 The above copyright notice and this permission notice shall be included in all
20 copies or substantial portions of the Software.
21
22 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 SOFTWARE.
29*/
30
36# include <gtest/gtest.h>
37
38# include <algorithm>
39# include <chrono>
40# include <cstdlib>
41# include <random>
42# include <string>
43
44# include <modular_arithmetic.H>
45# include <ntt.H>
46# include <thread_pool.H>
47
48using namespace Aleph;
49
50namespace
51{
// Sets the process environment variable `name` to `value`, replacing any
// value it already has (the POSIX branch passes overwrite = 1).
//
// The listing's fused line numbers have been removed so the block is valid
// C++ again; the logic is unchanged.
//
// @param name   NUL-terminated variable name.
// @param value  NUL-terminated value to store.
//
// Failure is treated as unrecoverable for the test process: a non-zero
// status from the platform call ends it through std::abort().
void
set_env_var(const char * const name,
            const char * const value)
{
# if defined(_WIN32)
  const int rc = _putenv_s(name, value);
# else
  const int rc = setenv(name, value, 1);
# endif
  if (rc != 0)
    std::abort();
}
64
// Removes the environment variable `name` from the process environment.
//
// The listing's fused line numbers have been removed so the block is valid
// C++ again; the logic is unchanged.
//
// @param name  NUL-terminated variable name.
//
// The Windows branch assigns an empty string through _putenv_s; the POSIX
// branch calls unsetenv. A non-zero status from either call ends the
// process through std::abort().
void
unset_env_var(const char * const name)
{
# if defined(_WIN32)
  const int rc = _putenv_s(name, "");
# else
  const int rc = unsetenv(name);
# endif
  if (rc != 0)
    std::abort();
}
76
77 class ScopedEnvVar
78 {
79 std::string name_;
80 bool had_old_ = false;
81 std::string old_value_;
82
83 public:
84 ScopedEnvVar(const char * const name,
85 const char * const value)
86 : name_(name)
87 {
88 if (const char *current = std::getenv(name); current != nullptr)
89 {
90 had_old_ = true;
91 old_value_ = current;
92 }
93
94 set_env_var(name, value);
95 }
96
98 {
99 if (had_old_)
100 set_env_var(name_.c_str(), old_value_.c_str());
101 else
102 unset_env_var(name_.c_str());
103 }
104 };
105
106 template <typename NTTType>
109 {
111 for (size_t i = 0; i < input.size(); ++i)
112 output(i) = input[i] % NTTType::mod;
113 return output;
114 }
115
116 template <typename NTTType>
119 const Array<uint64_t> & b)
120 {
121 if (a.is_empty() or b.is_empty())
122 return {};
123
125 for (size_t i = 0; i < output.size(); ++i)
126 output(i) = 0;
127
128 for (size_t i = 0; i < a.size(); ++i)
129 for (size_t j = 0; j < b.size(); ++j)
130 {
131 const uint64_t term =
132 mod_mul(a[i] % NTTType::mod, b[j] % NTTType::mod, NTTType::mod);
133 output(i + j) = static_cast<uint64_t>(
134 (static_cast<__uint128_t>(output[i + j]) + term) % NTTType::mod);
135 }
136
137 return output;
138 }
139
140 template <typename NTTType>
143 const bool invert)
144 {
145 if (input.is_empty())
146 return {};
147
148 const size_t n = input.size();
149 const uint64_t mod = NTTType::mod;
150 const uint64_t root =
151 invert ?
152 mod_inv(NTTType::primitive_root_of_unity(n), mod) :
153 NTTType::primitive_root_of_unity(n);
154 const uint64_t inv_n = invert ? mod_inv(static_cast<uint64_t>(n), mod) : 1;
155
157 for (size_t k = 0; k < n; ++k)
158 {
159 uint64_t sum = 0;
160 for (size_t j = 0; j < n; ++j)
161 {
162 const uint64_t exponent = static_cast<uint64_t>(
163 (static_cast<__uint128_t>(j) * k) % n);
164 const uint64_t twiddle = mod_exp(root, exponent, mod);
165 const uint64_t term = mod_mul(input[j] % mod, twiddle, mod);
166 sum = static_cast<uint64_t>(
167 (static_cast<__uint128_t>(sum) + term) % mod);
168 }
169 output(k) = invert ? mod_mul(sum, inv_n, mod) : sum;
170 }
171
172 return output;
173 }
174
// Formats an unsigned 128-bit integer as a base-10 string.
//
// The signature line was missing from the listing and is restored here; the
// name and argument type follow from the calls on Array<__uint128_t>
// elements in expect_equal_arrays_u128.
//
// @param value  number to format (taken by value; consumed while formatting).
// @return       decimal representation without sign or leading zeros;
//               "0" for zero.
std::string
u128_to_string(__uint128_t value)
{
  if (value == 0)
    return "0";

  // Digits are produced least-significant first and reversed at the end.
  std::string digits;
  while (value > 0)
    {
      const auto digit = static_cast<unsigned>(value % 10);
      digits.push_back(static_cast<char>('0' + digit));
      value /= 10;
    }

  std::reverse(digits.begin(), digits.end());
  return digits;
}
192
195 const Array<uint64_t> & b)
196 {
197 if (a.is_empty() or b.is_empty())
198 return {};
199
201 for (size_t i = 0; i < output.size(); ++i)
202 output(i) = 0;
203
204 for (size_t i = 0; i < a.size(); ++i)
205 for (size_t j = 0; j < b.size(); ++j)
206 output(i + j) += static_cast<__uint128_t>(a[i]) * b[j];
207
208 return output;
209 }
210
211 template <typename NTTType>
214 const Array<uint64_t> & b)
215 {
216 if (a.size() != b.size())
217 throw std::invalid_argument("naive_negacyclic: mismatched sizes");
218
220 for (size_t i = 0; i < output.size(); ++i)
221 output(i) = 0;
222
223 for (size_t i = 0; i < a.size(); ++i)
224 for (size_t j = 0; j < b.size(); ++j)
225 {
226 const uint64_t term =
227 mod_mul(a[i] % NTTType::mod, b[j] % NTTType::mod, NTTType::mod);
228 const size_t pos = i + j;
229 const size_t slot = pos < a.size() ? pos : pos - a.size();
230 if (pos < a.size())
231 output(slot) = static_cast<uint64_t>(
232 (static_cast<__uint128_t>(output[slot]) + term) % NTTType::mod);
233 else
234 output(slot) = output[slot] >= term ?
235 output[slot] - term :
236 NTTType::mod - (term - output[slot]);
237 }
238
239 return output;
240 }
241
242 template <uint64_t Base>
245 {
247 output.reserve(input.size());
248 for (size_t i = 0; i < input.size(); ++i)
249 output.append(input[i]);
250
251 while (not output.is_empty() and output.get_last() == 0)
252 static_cast<void>(output.remove_last());
253
254 if (output.is_empty())
255 {
257 output(0) = 0;
258 }
259
260 return output;
261 }
262
// Parses a string of ASCII decimal digits into an array of digit values.
// The string is walked from its last character to its first, so digits are
// appended least-significant first.
//
// An empty string yields the single-digit array {0}.
// Throws std::invalid_argument("decimal_digits_from_string: non-digit") when
// a character outside '0'..'9' is encountered.
//
// NOTE(review): the return-type line, the declaration of `digits`
// (listing line 269) and the return statement (listing line 279) are absent
// from this listing — confirm against the original file.
264 decimal_digits_from_string(const std::string & value)
265 {
266 if (value.empty())
267 return Array<uint64_t>({0});
268
270 digits.reserve(value.size());
// i runs from size() down to 1 so that value[i - 1] never underflows.
271 for (size_t i = value.size(); i > 0; --i)
272 {
273 const char c = value[i - 1];
274 if (c < '0' or c > '9')
275 throw std::invalid_argument("decimal_digits_from_string: non-digit");
276 digits.append(static_cast<uint64_t>(c - '0'));
277 }
278
280 }
281
282 template <uint64_t Base>
285 const Array<uint64_t> & b)
286 {
289
290 if ((lhs.size() == 1 and lhs[0] == 0)
291 or (rhs.size() == 1 and rhs[0] == 0))
292 return Array<uint64_t>({0});
293
294 Array<__uint128_t> coeffs = Array<__uint128_t>::create(lhs.size() + rhs.size() - 1);
295 for (size_t i = 0; i < coeffs.size(); ++i)
296 coeffs(i) = 0;
297
298 for (size_t i = 0; i < lhs.size(); ++i)
299 for (size_t j = 0; j < rhs.size(); ++j)
300 coeffs(i + j) += static_cast<__uint128_t>(lhs[i]) * rhs[j];
301
303 output.reserve(coeffs.size() + 2);
304 __uint128_t carry = 0;
305 for (size_t i = 0; i < coeffs.size(); ++i)
306 {
307 const __uint128_t total = coeffs[i] + carry;
308 output.append(static_cast<uint64_t>(total % Base));
309 carry = total / Base;
310 }
311
312 while (carry > 0)
313 {
314 output.append(static_cast<uint64_t>(carry % Base));
315 carry /= Base;
316 }
317
319 }
320
// Builds a digit pattern directly, without multiplying: slot 0 is 1,
// slots 1..n-1 are 0, slot n is 8, and every remaining slot is 9.
// For n == 0 the result is the single-digit array {0}.
//
// NOTE(review): presumably this is the least-significant-first decimal
// expansion of (10^n - 1)^2, i.e. the square of a number made of n nines —
// confirm. The return-type line and the allocation of `output`
// (listing line 327) are absent from this listing, so the output length
// cannot be read off here.
322 all_nines_square_digits(const size_t n)
323 {
324 if (n == 0)
325 return Array<uint64_t>({0});
326
328 output(0) = 1;
329 for (size_t i = 1; i < n; ++i)
330 output(i) = 0;
331 output(n) = 8;
332 for (size_t i = n + 1; i < output.size(); ++i)
333 output(i) = 9;
334 return output;
335 }
336
337 template <typename NTTType>
338 void
340 {
341 while (not poly.is_empty() and poly.get_last() % NTTType::mod == 0)
342 static_cast<void>(poly.remove_last());
343 }
344
345 template <typename NTTType>
348 {
350 output.reserve(input.size());
351 for (size_t i = 0; i < input.size(); ++i)
352 output.append(input[i] % NTTType::mod);
354 return output;
355 }
356
357 template <typename NTTType>
360 const size_t n)
361 {
363 for (size_t i = 0; i < n; ++i)
364 output(i) = i < input.size() ? input[i] % NTTType::mod : 0;
365 return output;
366 }
367
368 template <typename NTTType>
371 {
372 if (coeffs.size() <= 1)
373 return {};
374
376 for (size_t i = 1; i < coeffs.size(); ++i)
377 output(i - 1) = mod_mul(coeffs[i] % NTTType::mod,
378 static_cast<uint64_t>(i) % NTTType::mod,
379 NTTType::mod);
380 return output;
381 }
382
// Reference (term-by-term) antiderivative of a coefficient array modulo
// NTTType::mod, limited to n output slots.
//
// All n slots are first zeroed, so slot 0 (the constant term) stays 0.
// Then output[i + 1] = coeffs[i] * (i + 1)^-1 (mod NTTType::mod) for every
// i that is both inside coeffs and satisfies i + 1 < n.
//
// NOTE(review): the return-type line and the declaration of `output`
// (listing line 388) are absent from this listing.
383 template <typename NTTType>
385 poly_integral_ref(const Array<uint64_t> & coeffs,
386 const size_t n)
387 {
389 for (size_t i = 0; i < n; ++i)
390 output(i) = 0;
391
// The second loop condition drops terms that would land at index >= n.
392 for (size_t i = 0; i < coeffs.size() and i + 1 < n; ++i)
393 output(i + 1) = mod_mul(coeffs[i] % NTTType::mod,
394 mod_inv(static_cast<uint64_t>(i + 1),
395 NTTType::mod),
396 NTTType::mod);
397 return output;
398 }
399
400 template <typename NTTType>
403 const Array<uint64_t> & b,
404 const size_t n)
405 {
407 }
408
// Quadratic-time reference for the first n coefficients of a series inverse
// modulo NTTType::mod. Returns an empty array when n == 0.
//
// Recurrence implemented below, with inv0 = f[0]^-1:
//   g[0] = inv0
//   g[i] = -(sum_{j=1..i} f[j] * g[i - j]) * inv0      for 1 <= i < n
//
// NOTE(review): the return-type line and the declarations of `f` and `g`
// (listing lines 417-418) are absent from this listing; `f` is presumably a
// length-n, mod-reduced prefix of `coeffs` — confirm.
409 template <typename NTTType>
411 naive_poly_inverse(const Array<uint64_t> & coeffs,
412 const size_t n)
413 {
414 if (n == 0)
415 return {};
416
419 for (size_t i = 0; i < n; ++i)
420 g(i) = 0;
421
422 const uint64_t inv0 = mod_inv(f[0], NTTType::mod);
423 g(0) = inv0;
424 for (size_t i = 1; i < n; ++i)
425 {
426 uint64_t sum = 0;
427 for (size_t j = 1; j <= i; ++j)
428 {
429 const uint64_t term =
430 mod_mul(f[j], g[i - j], NTTType::mod);
// The addition is widened to 128 bits before the reduction.
431 sum = static_cast<uint64_t>(
432 (static_cast<__uint128_t>(sum) + term) % NTTType::mod);
433 }
434
// Negate sum modulo mod; 0 is special-cased so the result stays below mod.
435 g(i) = mod_mul(sum == 0 ? 0 : NTTType::mod - sum,
436 inv0, NTTType::mod);
437 }
438
439 return g;
440 }
441
442 template <typename NTTType>
443 std::pair<Array<uint64_t>, Array<uint64_t>>
445 const Array<uint64_t> & divisor)
446 {
449
450 if (normalized_divisor.is_empty())
451 throw std::invalid_argument("naive_poly_divmod: zero divisor");
452
453 if (remainder.is_empty() or remainder.size() < normalized_divisor.size())
454 return {{}, remainder};
455
457 remainder.size() - normalized_divisor.size() + 1);
458 for (size_t i = 0; i < quotient.size(); ++i)
459 quotient(i) = 0;
460
461 const uint64_t inv_lead =
462 mod_inv(normalized_divisor.get_last(), NTTType::mod);
463 while (not remainder.is_empty()
464 and remainder.size() >= normalized_divisor.size())
465 {
466 const size_t shift = remainder.size() - normalized_divisor.size();
467 const uint64_t factor =
468 mod_mul(remainder.get_last(), inv_lead, NTTType::mod);
469 quotient(shift) = factor;
470
471 for (size_t i = 0; i < normalized_divisor.size(); ++i)
472 {
473 const size_t pos = shift + i;
474 const uint64_t term = mod_mul(factor,
476 NTTType::mod);
477 remainder(pos) =
478 remainder[pos] >= term ?
479 remainder[pos] - term :
480 NTTType::mod - (term - remainder[pos]);
481 }
482
484 }
485
487 }
488
// Reference series logarithm built from other reference helpers.
// Returns an empty array when n == 0.
//
// Visible steps: a `derivative` array is computed, `coeffs` is inverted to
// n - 1 terms with naive_poly_inverse, the two are multiplied and truncated
// to n - 1 terms with naive_multiply_trunc, and the product is passed with
// n to a final helper whose result is returned.
//
// NOTE(review): listing lines 490, 498 and 500 are absent from this listing,
// so the derivative call and the outer call are not visible; presumably the
// outer call is the term-by-term integral (log f = integral of f'/f) —
// confirm against the original file.
489 template <typename NTTType>
491 naive_poly_log(const Array<uint64_t> & coeffs,
492 const size_t n)
493 {
494 if (n == 0)
495 return {};
496
497 const Array<uint64_t> derivative =
499 const Array<uint64_t> inverse = naive_poly_inverse<NTTType>(coeffs, n - 1);
501 naive_multiply_trunc<NTTType>(derivative, inverse, n - 1), n);
502 }
503
// Quadratic-time reference for the first n coefficients of a series
// exponential modulo NTTType::mod. Returns an empty array when n == 0.
//
// With f = series_prefix_mod<NTTType>(coeffs, n), the loop implements
//   g[0] = 1
//   g[m] = m^-1 * sum_{i=1..m} i * f[i] * g[m - i]     for 1 <= m < n
// f[0] is never read by this recurrence.
//
// NOTE(review): the return-type line and the declaration of `g`
// (listing line 513) are absent from this listing.
504 template <typename NTTType>
506 naive_poly_exp(const Array<uint64_t> & coeffs,
507 const size_t n)
508 {
509 if (n == 0)
510 return {};
511
512 const Array<uint64_t> f = series_prefix_mod<NTTType>(coeffs, n);
514 for (size_t i = 0; i < n; ++i)
515 g(i) = 0;
516 g(0) = 1;
517
518 for (size_t m = 1; m < n; ++m)
519 {
520 uint64_t sum = 0;
521 for (size_t i = 1; i <= m; ++i)
522 {
// weighted = i * f[i] (mod), then multiplied by g[m - i].
523 const uint64_t weighted =
524 mod_mul(static_cast<uint64_t>(i) % NTTType::mod,
525 f[i], NTTType::mod);
526 const uint64_t term =
527 mod_mul(weighted, g[m - i], NTTType::mod);
528 sum = static_cast<uint64_t>(
529 (static_cast<__uint128_t>(sum) + term) % NTTType::mod);
530 }
531
532 g(m) = mod_mul(sum,
533 mod_inv(static_cast<uint64_t>(m), NTTType::mod),
534 NTTType::mod);
535 }
536
537 return g;
538 }
539
// Reference k-th power of a series truncated to n coefficients, computed by
// k successive truncated multiplications (cost grows linearly with k).
//
// Returns an empty array when n == 0. For k == 0 the result is the constant
// series 1 (slot 0 is 1, all other slots 0). Otherwise the accumulator
// starts as that same constant series and is multiplied k times by
// series_prefix_mod<NTTType>(coeffs, n) through naive_multiply_trunc.
//
// NOTE(review): the return-type line and the declarations of `output`
// (listing line 551) and `result` (listing line 558) are absent from this
// listing.
540 template <typename NTTType>
542 naive_poly_power(const Array<uint64_t> & coeffs,
543 const uint64_t k,
544 const size_t n)
545 {
546 if (n == 0)
547 return {};
548
549 if (k == 0)
550 {
552 for (size_t i = 0; i < n; ++i)
553 output(i) = 0;
554 output(0) = 1;
555 return output;
556 }
557
559 for (size_t i = 0; i < n; ++i)
560 result(i) = 0;
561 result(0) = 1;
562
563 const Array<uint64_t> base = series_prefix_mod<NTTType>(coeffs, n);
564 for (uint64_t iter = 0; iter < k; ++iter)
565 result = naive_multiply_trunc<NTTType>(result, base, n);
566
567 return result;
568 }
569
570 template <typename NTTType>
573 const Array<uint64_t> & points)
574 {
576 for (size_t i = 0; i < points.size(); ++i)
577 {
578 uint64_t value = 0;
579 const uint64_t x = points[i] % NTTType::mod;
580 for (size_t j = coeffs.size(); j > 0; --j)
581 {
582 value = mod_mul(value, x, NTTType::mod);
583 value = static_cast<uint64_t>(
584 (static_cast<__uint128_t>(value)
585 + coeffs[j - 1] % NTTType::mod) % NTTType::mod);
586 }
587 output(i) = value;
588 }
589 return output;
590 }
591
// Reference interpolation: accumulates, for every index i, a scaled `basis`
// polynomial into `result` and returns the accumulated polynomial.
//
// Visible per-i steps:
//   * for each j != i a linear factor {-x_j, 1} (mod NTTType::mod) is
//     combined with `basis`, and the difference (x_i - x_j) mod
//     NTTType::mod is combined into `denom`;
//   * scale = values[i] * denom^-1 (mod NTTType::mod);
//   * `scaled` = basis * scale, coefficient by coefficient;
//   * result = normalize_mod_poly(result + scaled), where the shorter
//     operand is treated as zero-padded.
//
// NOTE(review): listing lines 593, 601, 609, 612, 621 and 627 are absent
// from this listing, so the declarations of `basis`, `scaled` and `sum` and
// the calls that update `basis` and `denom` are not visible; the shape
// looks like Lagrange interpolation — confirm against the original file.
592 template <typename NTTType>
594 naive_interpolate(const Array<uint64_t> & points,
595 const Array<uint64_t> & values)
596 {
597 Array<uint64_t> result;
598
599 for (size_t i = 0; i < points.size(); ++i)
600 {
602 uint64_t denom = 1;
603 const uint64_t xi = points[i] % NTTType::mod;
604
605 for (size_t j = 0; j < points.size(); ++j)
606 if (i != j)
607 {
608 const uint64_t xj = points[j] % NTTType::mod;
610 basis,
// Linear factor (x - x_j): constant term is -x_j mod NTTType::mod.
611 Array<uint64_t>({xj == 0 ? 0 : NTTType::mod - xj, 1}));
// (x_i - x_j) mod NTTType::mod, computed without unsigned underflow.
613 xi >= xj ? xi - xj : NTTType::mod - (xj - xi),
614 NTTType::mod);
615 }
616
617 const uint64_t scale =
618 mod_mul(values[i] % NTTType::mod,
619 mod_inv(denom, NTTType::mod),
620 NTTType::mod);
622 for (size_t k = 0; k < basis.size(); ++k)
623 scaled(k) = mod_mul(basis[k] % NTTType::mod,
624 scale, NTTType::mod);
625
626 const size_t out_size = std::max(result.size(), scaled.size());
628 for (size_t k = 0; k < out_size; ++k)
629 {
630 const uint64_t lhs = k < result.size() ? result[k] : 0;
631 const uint64_t rhs = k < scaled.size() ? scaled[k] : 0;
632 sum(k) = static_cast<uint64_t>(
633 (static_cast<__uint128_t>(lhs) + rhs) % NTTType::mod);
634 }
635 result = normalize_mod_poly<NTTType>(sum);
636 }
637
638 return result;
639 }
640
641 void
643 const Array<uint64_t> & rhs,
644 const char * const ctx)
645 {
646 ASSERT_EQ(lhs.size(), rhs.size()) << ctx;
647 for (size_t i = 0; i < lhs.size(); ++i)
648 EXPECT_EQ(lhs[i], rhs[i]) << ctx << " index=" << i;
649 }
650
651 void
653 const Array<__uint128_t> & rhs,
654 const char * const ctx)
655 {
656 ASSERT_EQ(lhs.size(), rhs.size()) << ctx;
657 for (size_t i = 0; i < lhs.size(); ++i)
658 EXPECT_TRUE(lhs[i] == rhs[i])
659 << ctx << " index=" << i
660 << " lhs=" << u128_to_string(lhs[i])
661 << " rhs=" << u128_to_string(rhs[i]);
662 }
663}
664
// GoogleTest fixture holding a 64-bit Mersenne Twister seeded with the
// fixed constant 123456789.
//
// NOTE(review): listing lines 668-669, the return-type line of random_poly
// and the declaration of `output` (listing line 676) are absent from this
// listing.
665class NTTIndustrialTest : public ::testing::Test
666{
667protected:
670
671 std::mt19937_64 rng_{123456789ULL};
672
// Fills slots 0..n-1 of `output` with successive raw rng_() values; the
// values are not reduced modulo anything here.
674 random_poly(const size_t n)
675 {
677 for (size_t i = 0; i < n; ++i)
678 output(i) = rng_();
679 return output;
680 }
681};
682
684{
685 EXPECT_EQ(DefaultNTT::max_transform_size(), 1ULL << 23);
686 EXPECT_TRUE(DefaultNTT::supports_size(1));
687 EXPECT_TRUE(DefaultNTT::supports_size(8));
688 EXPECT_TRUE(DefaultNTT::supports_size(7));
689 EXPECT_TRUE(DefaultNTT::supports_size(14));
690 EXPECT_TRUE(DefaultNTT::supports_size(17));
691 EXPECT_TRUE(DefaultNTT::supports_size(1ULL << 23));
692 EXPECT_FALSE(DefaultNTT::supports_size(0));
693 EXPECT_FALSE(DefaultNTT::supports_size(3));
694 EXPECT_FALSE(DefaultNTT::supports_size(10));
695 EXPECT_FALSE(DefaultNTT::supports_size((1ULL << 23) + 2));
696
697 const uint64_t root8 = DefaultNTT::primitive_root_of_unity(8);
698 EXPECT_EQ(mod_exp(root8, 8, DefaultNTT::mod), 1ULL);
699 EXPECT_NE(mod_exp(root8, 4, DefaultNTT::mod), 1ULL);
700
701 const uint64_t root7 = DefaultNTT::primitive_root_of_unity(7);
702 EXPECT_EQ(mod_exp(root7, 7, DefaultNTT::mod), 1ULL);
703 EXPECT_NE(root7, 1ULL);
704 EXPECT_EQ(DefaultNTT::primitive_root_of_unity(1), 1ULL);
705}
706
708{
709 ScopedEnvVar auto_mode("ALEPH_NTT_SIMD", "auto");
710
711 const std::string backend = DefaultNTT::simd_backend_name();
712 EXPECT_TRUE(backend == "scalar" or backend == "avx2" or backend == "neon");
713 EXPECT_FALSE(DefaultNTT::avx2_dispatch_available()
714 and DefaultNTT::neon_dispatch_available());
715}
716
718{
719 const Array<size_t> sizes = {1, 2, 4, 8, 16, 64, 256, 1024};
720
721 for (size_t i = 0; i < sizes.size(); ++i)
722 for (size_t sample = 0; sample < 6; ++sample)
723 {
724 Array<uint64_t> values = random_poly(sizes[i]);
726
727 DefaultNTT::transform(values, false);
728 DefaultNTT::transform(values, true);
729
730 expect_equal_arrays(values, expected, "Default sequential round trip");
731 }
732}
733
735{
736 const Array<uint64_t> input = random_poly(1 << 10);
737 const Array<uint64_t> lhs = random_poly(300);
738 const Array<uint64_t> rhs = random_poly(280);
739
741 Array<uint64_t> scalar_product;
742 {
743 ScopedEnvVar scalar_mode("ALEPH_NTT_SIMD", "scalar");
744 EXPECT_EQ(std::string(DefaultNTT::simd_backend_name()), "scalar");
745 DefaultNTT::transform(scalar_transform, false);
746 scalar_product = DefaultNTT::multiply(lhs, rhs);
747 }
748
749 {
750 ScopedEnvVar avx2_mode("ALEPH_NTT_SIMD", "avx2");
752 DefaultNTT::transform(avx2_transform, false);
753 const Array<uint64_t> avx2_product = DefaultNTT::multiply(lhs, rhs);
754
755 if (DefaultNTT::avx2_dispatch_available())
756 EXPECT_EQ(std::string(DefaultNTT::simd_backend_name()), "avx2");
757 else
758 EXPECT_EQ(std::string(DefaultNTT::simd_backend_name()), "scalar");
759
760 expect_equal_arrays(avx2_transform, scalar_transform, "AVX2 forced transform");
761 expect_equal_arrays(avx2_product, scalar_product, "AVX2 forced multiply");
762 }
763
764 {
765 ScopedEnvVar neon_mode("ALEPH_NTT_SIMD", "neon");
767 DefaultNTT::transform(neon_transform, false);
768 const Array<uint64_t> neon_product = DefaultNTT::multiply(lhs, rhs);
769
770 if (DefaultNTT::neon_dispatch_available())
771 EXPECT_EQ(std::string(DefaultNTT::simd_backend_name()), "neon");
772 else
773 EXPECT_EQ(std::string(DefaultNTT::simd_backend_name()), "scalar");
774
775 expect_equal_arrays(neon_transform, scalar_transform, "NEON forced transform");
776 expect_equal_arrays(neon_product, scalar_product, "NEON forced multiply");
777 }
778}
779
781{
782 const Array<uint64_t> poly = {5, 3, 7, 11};
783 const Array<uint64_t> points = {0, 1, 9, DefaultNTT::mod + 5ULL};
787 DefaultNTT::multipoint_eval(poly, points);
788 expect_equal_arrays(actual_eval, expected_eval, "Multipoint eval");
789 EXPECT_EQ(DefaultNTT::poly_eval(poly, 9), expected_eval[2]);
790
791 const Array<uint64_t> dividend = {3, 4, 5, 6, 7};
792 const Array<uint64_t> divisor = {1, 2, 1};
794 const auto actual = DefaultNTT::poly_divmod(dividend, divisor);
795 expect_equal_arrays(actual.first, expected.first, "poly_divmod quotient");
796 expect_equal_arrays(actual.second, expected.second, "poly_divmod remainder");
797}
798
800{
801 const Array<uint64_t> coeffs = {3, 4, 5, 6};
802 const size_t n = 8;
803
805 const Array<uint64_t> actual = DefaultNTT::poly_inverse(coeffs, n);
806 expect_equal_arrays(actual, expected, "poly_inverse");
807
808 const Array<uint64_t> product = DefaultNTT::multiply(coeffs, actual);
811 for (size_t i = 0; i < n; ++i)
812 identity(i) = 0;
813 identity(0) = 1;
814 expect_equal_arrays(prefix, identity, "inverse identity");
815}
816
818{
819 const Array<uint64_t> unit = {1, 4, 7, 2, 3};
820 const size_t n = 7;
821
823 const Array<uint64_t> actual_log = DefaultNTT::poly_log(unit, n);
825
826 const Array<uint64_t> recovered = DefaultNTT::poly_exp(actual_log, n);
828 "exp(log(f))");
829
830 const Array<uint64_t> expo_input = {0, 5, 3, 1, 4};
832 const Array<uint64_t> actual_exp = DefaultNTT::poly_exp(expo_input, n);
834
838 DefaultNTT::poly_power(unit, 3, n);
840}
841
843{
844 const size_t n = 8;
845
847 Array<uint64_t>({5, 7, 3, 2}), n);
848 const Array<uint64_t> square =
850 const Array<uint64_t> recovered = DefaultNTT::poly_sqrt(square, n);
851 expect_equal_arrays(recovered, root, "poly_sqrt non-zero constant");
852
854 Array<uint64_t>({0, 5, 4, 1}), n);
858 DefaultNTT::poly_sqrt(shifted_square, n);
860 "poly_sqrt shifted root");
861}
862
864{
865 const Array<uint64_t> coeffs = {9, 7, 5, 3, 1};
866 const Array<uint64_t> points = {0, 1, 2, 5, 11};
867 const Array<uint64_t> values =
868 naive_multipoint_eval<DefaultNTT>(coeffs, points);
869
871 DefaultNTT::multipoint_eval(coeffs, points);
872 expect_equal_arrays(actual_values, values, "multipoint_eval");
873
875 naive_interpolate<DefaultNTT>(points, values);
877 DefaultNTT::interpolate(points, values);
879
881 DefaultNTT::multipoint_eval(actual_poly, points);
882 expect_equal_arrays(round_trip, values, "interpolate/eval round trip");
883}
884
886{
888 decimal_digits_from_string("12345678901234567890");
890 decimal_digits_from_string("98765432109876543210");
892 decimal_digits_from_string("1219326311370217952237463801111263526900");
894 DefaultNTT::bigint_multiply<10>(decimal_lhs, decimal_rhs);
895 expect_equal_arrays(decimal_actual, decimal_expected, "decimal bigint product");
896
897 const Array<uint64_t> base1000_lhs = {999, 123, 456, 7};
898 const Array<uint64_t> base1000_rhs = {777, 1, 5};
899 expect_equal_arrays(DefaultNTT::bigint_multiply<1000>(base1000_lhs, base1000_rhs),
901 "base-1000 bigint product");
902
903 const Array<uint64_t> base2p15_lhs = {32767, 5, 1024, 77};
904 const Array<uint64_t> base2p15_rhs = {32767, 7, 4096};
905 expect_equal_arrays(DefaultNTT::bigint_multiply<(1ULL << 15)>(base2p15_lhs,
909 "base-2^15 bigint product");
910}
911
913{
914 ThreadPool pool(4);
915
916 const Array<uint64_t> lhs = {9, 9, 0, 0};
917 const Array<uint64_t> rhs = {9, 9, 0};
918 const Array<uint64_t> expected = {1, 0, 8, 9};
919
920 const Array<uint64_t> sequential = DefaultNTT::bigint_multiply<10>(lhs, rhs);
921 const Array<uint64_t> parallel =
922 DefaultNTT::pbigint_multiply<10>(pool, lhs, rhs);
923 expect_equal_arrays(sequential, expected, "normalized bigint product");
924 expect_equal_arrays(parallel, sequential, "parallel bigint product");
925
927 DefaultNTT::bigint_multiply<10>(Array<uint64_t>(), rhs);
929 DefaultNTT::bigint_multiply<10>(lhs, Array<uint64_t>({0, 0, 0}));
930 expect_equal_arrays(zero_left, Array<uint64_t>({0}), "empty bigint is zero");
931 expect_equal_arrays(zero_right, Array<uint64_t>({0}), "zero bigint product");
932}
933
935{
936 const size_t digits = 4096;
938 for (size_t i = 0; i < digits; ++i)
939 nines(i) = 9;
940
941 const Array<uint64_t> actual = DefaultNTT::bigint_multiply<10>(nines, nines);
943 expect_equal_arrays(actual, expected, "all-9 decimal square");
944}
945
947{
948 const Array<size_t> sizes = {2, 4, 8, 1024};
949
950 for (size_t i = 0; i < sizes.size(); ++i)
951 {
952 Array<uint64_t> lhs = random_poly(sizes[i]);
953 Array<uint64_t> rhs = random_poly(sizes[i]);
955 const Array<uint64_t> actual = DefaultNTT::negacyclic_multiply(lhs, rhs);
956 expect_equal_arrays(actual, expected, "negacyclic multiply");
957 }
958
959 const Array<uint64_t> demo_lhs = {1, 1, 1, 1};
960 const Array<uint64_t> demo_rhs = {1, 1, 0, 0};
961 const Array<uint64_t> expected_demo = {0, 2, 2, 2};
962 expect_equal_arrays(DefaultNTT::negacyclic_multiply(demo_lhs, demo_rhs),
964 "negacyclic demo product");
965}
966
968{
969 if (const char * env = std::getenv("ENABLE_PERF_TESTS");
970 env == nullptr or std::string(env) != "1")
971 {
972 GTEST_SKIP() << "Set ENABLE_PERF_TESTS=1 to run million-digit bigint checks";
973 }
974
975 const size_t digits = 1000000;
977 for (size_t i = 0; i < digits; ++i)
978 nines(i) = 9;
979
980 const Array<uint64_t> product = DefaultNTT::bigint_multiply<10>(nines, nines);
981 ASSERT_EQ(product.size(), 2 * digits);
982 EXPECT_EQ(product[0], 1ULL);
983 EXPECT_EQ(product[digits - 1], 0ULL);
984 EXPECT_EQ(product[digits], 8ULL);
985 EXPECT_EQ(product[product.size() - 1], 9ULL);
986}
987
989{
990 const Array<size_t> sizes = {1, 2, 4, 32, 128};
991
992 for (size_t i = 0; i < sizes.size(); ++i)
993 for (size_t sample = 0; sample < 4; ++sample)
994 {
995 Array<uint64_t> values = random_poly(sizes[i]);
997
998 AlternateNTT::transform(values, false);
999 AlternateNTT::transform(values, true);
1000
1001 expect_equal_arrays(values, expected, "Alternate sequential round trip");
1002 }
1003}
1004
1006{
1007 const Array<size_t> sizes = {7, 14, 17, 119};
1008
1009 for (size_t i = 0; i < sizes.size(); ++i)
1010 for (size_t sample = 0; sample < 3; ++sample)
1011 {
1012 Array<uint64_t> input = random_poly(sizes[i]);
1013
1018
1020 DefaultNTT::transform(actual_forward, false);
1022 "Bluestein forward transform");
1023
1024 DefaultNTT::transform(actual_forward, true);
1026 "Bluestein inverse round trip");
1028 "Naive inverse round trip");
1029 }
1030}
1031
1033{
1034 const Array<size_t> lhs_sizes = {1, 2, 3, 5, 8, 17, 32};
1035 const Array<size_t> rhs_sizes = {1, 2, 4, 7, 9, 17, 24};
1036
1037 for (size_t i = 0; i < lhs_sizes.size(); ++i)
1038 for (size_t j = 0; j < rhs_sizes.size(); ++j)
1039 {
1040 Array<uint64_t> lhs = random_poly(lhs_sizes[i]);
1041 Array<uint64_t> rhs = random_poly(rhs_sizes[j]);
1043 const Array<uint64_t> got = DefaultNTT::multiply(lhs, rhs);
1044
1045 expect_equal_arrays(got, expected, "Static multiply");
1046 }
1047}
1048
1050{
1051 typename DefaultNTT::Plan plan(64);
1052
1053 Array<uint64_t> input = random_poly(64);
1056
1057 DefaultNTT::transform(static_result, false);
1058 plan.transform(plan_result, false);
1059 expect_equal_arrays(plan_result, static_result, "Plan forward transform");
1060
1061 DefaultNTT::transform(static_result, true);
1062 plan.transform(plan_result, true);
1063 expect_equal_arrays(plan_result, static_result, "Plan inverse transform");
1064
1065 Array<uint64_t> lhs = random_poly(20);
1066 Array<uint64_t> rhs = random_poly(17);
1067 const Array<uint64_t> expected = DefaultNTT::multiply(lhs, rhs);
1068 const Array<uint64_t> got = plan.multiply(lhs, rhs);
1069 expect_equal_arrays(got, expected, "Plan multiply");
1070}
1071
1073{
1074 typename DefaultNTT::Plan plan(16);
1077
1078 for (size_t i = 0; i < 4; ++i)
1079 {
1080 Array<uint64_t> item = random_poly(16);
1081 batch.append(item);
1082 original.append(item);
1083 }
1084
1086 for (size_t i = 0; i < expected.size(); ++i)
1087 DefaultNTT::transform(expected(i), false);
1088
1089 plan.transform_batch(batch, false);
1090 for (size_t i = 0; i < batch.size(); ++i)
1091 expect_equal_arrays(batch[i], expected[i], "Batch forward transform");
1092
1093 plan.transform_batch(batch, true);
1094 for (size_t i = 0; i < batch.size(); ++i)
1096 "Batch inverse transform");
1097}
1098
1100{
1101 typename DefaultNTT::Plan plan(7);
1102
1103 Array<uint64_t> input = random_poly(7);
1106 plan.transform(actual, false);
1107 expect_equal_arrays(actual, expected, "Bluestein plan forward transform");
1108
1109 plan.transform(actual, true);
1111 "Bluestein plan inverse round trip");
1112
1114 Array<Array<uint64_t>> reference;
1115 for (size_t i = 0; i < 3; ++i)
1116 {
1117 Array<uint64_t> item = random_poly(7);
1118 batch.append(item);
1119 Array<uint64_t> transformed = item;
1120 DefaultNTT::transform(transformed, false);
1121 reference.append(transformed);
1122 }
1123
1124 plan.transform_batch(batch, false);
1125 for (size_t i = 0; i < batch.size(); ++i)
1126 expect_equal_arrays(batch[i], reference[i], "Bluestein batch transform");
1127}
1128
1130{
1131 ThreadPool pool(4);
1132
1133 Array<uint64_t> input = random_poly(1 << 12);
1134 Array<uint64_t> sequential = input;
1135 Array<uint64_t> parallel = input;
1136
1137 DefaultNTT::transform(sequential, false);
1138 DefaultNTT::ptransform(pool, parallel, false);
1139 expect_equal_arrays(parallel, sequential, "Parallel forward transform");
1140
1141 DefaultNTT::transform(sequential, true);
1142 DefaultNTT::ptransform(pool, parallel, true);
1143 expect_equal_arrays(parallel, sequential, "Parallel inverse transform");
1144
1145 Array<uint64_t> lhs = random_poly(300);
1146 Array<uint64_t> rhs = random_poly(280);
1147 const Array<uint64_t> expected = DefaultNTT::multiply(lhs, rhs);
1148 const Array<uint64_t> got = DefaultNTT::pmultiply(pool, lhs, rhs);
1149 expect_equal_arrays(got, expected, "Parallel multiply");
1150}
1151
1153{
1154 ThreadPool pool(4);
1155 typename DefaultNTT::Plan plan(256);
1156
1158 for (size_t i = 0; i < 6; ++i)
1159 batch.append(random_poly(256));
1160
1161 Array<Array<uint64_t>> sequential = batch;
1162 Array<Array<uint64_t>> parallel = batch;
1163
1164 plan.transform_batch(sequential, false);
1165 plan.ptransform_batch(pool, parallel, false);
1166
1167 for (size_t i = 0; i < sequential.size(); ++i)
1168 expect_equal_arrays(parallel[i], sequential[i], "Parallel batch transform");
1169}
1170
1172{
1173 Array<uint64_t> lhs = {1, 2, 3};
1174 const Array<uint64_t> rhs = {4, 5, 6};
1175 const Array<uint64_t> expected = DefaultNTT::multiply(lhs, rhs);
1176
1177 DefaultNTT::multiply_inplace(lhs, rhs);
1178 expect_equal_arrays(lhs, expected, "multiply_inplace");
1179}
1180
1190
1192{
1193 const Array<uint64_t> lhs = {
1194 1000000000000ULL,
1195 DefaultNTT::mod + 77ULL,
1196 1234567890123ULL
1197 };
1198 const Array<uint64_t> rhs = {
1199 777777777777ULL,
1200 42ULL,
1201 DefaultNTT::mod + 11ULL
1202 };
1203
1206 expect_equal_arrays_u128(got, expected, "Exact multiply");
1207}
1208
1210{
1211 const Array<uint64_t> lhs = {
1212 (1ULL << 63) - 25ULL,
1213 0ULL,
1214 (1ULL << 63) - 17ULL
1215 };
1216 const Array<uint64_t> rhs = {7ULL, 1ULL};
1217
1220 expect_equal_arrays_u128(got, expected, "Near-2^63 exact multiply");
1221}
1222
1224{
1225 ThreadPool pool(4);
1226
1229 for (size_t i = 0; i < lhs.size(); ++i)
1230 lhs(i) = DefaultNTT::mod + 1000ULL + (rng_() % 1000000ULL);
1231 for (size_t i = 0; i < rhs.size(); ++i)
1232 rhs(i) = DefaultNTT::mod + 500ULL + (rng_() % 2000000ULL);
1233
1234 const Array<__uint128_t> sequential = NTTExact::multiply(lhs, rhs);
1235 const Array<__uint128_t> parallel = NTTExact::pmultiply(pool, lhs, rhs);
1236 expect_equal_arrays_u128(parallel, sequential, "Parallel exact multiply");
1237
1241
1242 ASSERT_EQ(sequential.size(), mod0.size());
1243 for (size_t i = 0; i < sequential.size(); ++i)
1244 {
1245 EXPECT_EQ(static_cast<uint64_t>(sequential[i] % 998244353ULL), mod0[i])
1246 << "Residue mod 998244353 at index " << i;
1247 EXPECT_EQ(static_cast<uint64_t>(sequential[i] % 469762049ULL), mod1[i])
1248 << "Residue mod 469762049 at index " << i;
1249 EXPECT_EQ(static_cast<uint64_t>(sequential[i] % 1004535809ULL), mod2[i])
1250 << "Residue mod 1004535809 at index " << i;
1251 }
1252}
1253
1255{
1256 const Array<uint64_t> lhs = {(1ULL << 63) - 1ULL};
1257 const Array<uint64_t> rhs = {(1ULL << 63) - 1ULL};
1258 EXPECT_THROW(static_cast<void>(NTTExact::multiply(lhs, rhs)),
1259 std::invalid_argument);
1260}
1261
1263{
1264 Array<uint64_t> empty;
1265 Array<uint64_t> bad = {1, 2, 3};
1266 Array<uint64_t> unsupported = {1, 2, 3, 4, 5, 6};
1267 Array<Array<uint64_t>> mismatched_batch = {{1, 2, 3, 4}, {1, 2}};
1268
1269 EXPECT_THROW(DefaultNTT::transform(empty, false), std::invalid_argument);
1270 EXPECT_THROW(DefaultNTT::transform(bad, false), std::invalid_argument);
1271 EXPECT_THROW(DefaultNTT::transform(unsupported, false),
1272 std::invalid_argument);
1273 EXPECT_THROW(static_cast<void>(DefaultNTT::primitive_root_of_unity(3)),
1274 std::invalid_argument);
1275 EXPECT_THROW(static_cast<void>(DefaultNTT::primitive_root_of_unity(10)),
1276 std::invalid_argument);
1277 EXPECT_THROW(([]()
1278 {
1279 typename DefaultNTT::Plan invalid_plan(0);
1281 }()),
1282 std::invalid_argument);
1283
1284 typename DefaultNTT::Plan small_plan(4);
1285 EXPECT_THROW(small_plan.multiply(Array<uint64_t>({1, 2, 3}),
1286 Array<uint64_t>({4, 5, 6})),
1287 std::invalid_argument);
1288 EXPECT_THROW(small_plan.transform_batch(mismatched_batch, false),
1289 std::invalid_argument);
1290
1291 ThreadPool pool(2);
1292 EXPECT_THROW(small_plan.ptransform_batch(pool, mismatched_batch, false),
1293 std::invalid_argument);
1294
1295 EXPECT_THROW(static_cast<void>(DefaultNTT::poly_inverse(Array<uint64_t>({0, 1}), 4)),
1296 std::invalid_argument);
1297 EXPECT_THROW(static_cast<void>(DefaultNTT::poly_divmod(Array<uint64_t>({1, 2}),
1298 Array<uint64_t>())),
1299 std::invalid_argument);
1300 EXPECT_THROW(static_cast<void>(DefaultNTT::poly_log(Array<uint64_t>({2, 1}), 4)),
1301 std::invalid_argument);
1302 EXPECT_THROW(static_cast<void>(DefaultNTT::poly_exp(Array<uint64_t>({1, 1}), 4)),
1303 std::invalid_argument);
1304 EXPECT_THROW(static_cast<void>(DefaultNTT::poly_sqrt(Array<uint64_t>({3}), 4)),
1305 std::invalid_argument);
1306 EXPECT_THROW(static_cast<void>(DefaultNTT::interpolate(Array<uint64_t>({1, 1}),
1307 Array<uint64_t>({2, 3}))),
1308 std::invalid_argument);
1309 EXPECT_THROW(static_cast<void>(DefaultNTT::negacyclic_multiply(Array<uint64_t>(),
1310 Array<uint64_t>())),
1311 std::invalid_argument);
1312 EXPECT_THROW(static_cast<void>(DefaultNTT::negacyclic_multiply(Array<uint64_t>({1, 2}),
1313 Array<uint64_t>({1, 2, 3, 4}))),
1314 std::invalid_argument);
1315 EXPECT_THROW(static_cast<void>(DefaultNTT::negacyclic_multiply(Array<uint64_t>({1, 2, 3}),
1316 Array<uint64_t>({4, 5, 6}))),
1317 std::invalid_argument);
1318 EXPECT_THROW(([]()
1319 {
1322 for (size_t i = 0; i < 16; ++i)
1323 {
1324 lhs(i) = static_cast<uint64_t>(i & 1U);
1325 rhs(i) = static_cast<uint64_t>((i + 1) & 1U);
1326 }
1327 static_cast<void>(NTT<17ULL, 3ULL>::negacyclic_multiply(lhs, rhs));
1328 }()),
1329 std::invalid_argument);
1330 EXPECT_THROW(static_cast<void>(DefaultNTT::bigint_multiply<10>(Array<uint64_t>({10}),
1331 Array<uint64_t>({1}))),
1332 std::invalid_argument);
1333 EXPECT_THROW(static_cast<void>(DefaultNTT::bigint_multiply<(1ULL << 48)>(
1334 Array<uint64_t>({(1ULL << 48) - 1,
1335 (1ULL << 48) - 1,
1336 (1ULL << 48) - 1}),
1337 Array<uint64_t>({(1ULL << 48) - 1,
1338 (1ULL << 48) - 1,
1339 (1ULL << 48) - 1}))),
1340 std::invalid_argument);
1341
1342 EXPECT_THROW(static_cast<void>(NTTExact::multiply(Array<uint64_t>({1ULL << 63,
1343 1ULL << 63}),
1344 Array<uint64_t>({1ULL << 63,
1345 1ULL << 63}))),
1346 std::invalid_argument);
1347}
1348
1350{
1351 if (std::getenv("ENABLE_PERF_TESTS") == nullptr)
1352 GTEST_SKIP() << "Set ENABLE_PERF_TESTS=1 to run NTT performance checks";
1353
1354 ThreadPool pool(4);
1355 const size_t n = 1 << 20;
1358
1359 for (size_t i = 0; i < n; ++i)
1360 {
1361 lhs(i) = static_cast<uint64_t>(i);
1362 rhs(i) = static_cast<uint64_t>(n - i);
1363 }
1364
1365 const auto start = std::chrono::steady_clock::now();
1366 const Array<uint64_t> product = DefaultNTT::pmultiply(pool, lhs, rhs);
1367 const auto end = std::chrono::steady_clock::now();
1368
1369 ASSERT_EQ(product.size(), 2 * n - 1);
1370
1371 const auto elapsed =
1372 std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
1373 .count();
1374 EXPECT_LE(elapsed, 2500)
1375 << "NTT performance regression detected";
1376}
Simple dynamic array with automatic resizing and functional operations.
Definition tpl_array.H:139
static Array create(size_t n)
Create an array with n logical elements.
Definition tpl_array.H:194
constexpr size_t size() const noexcept
Return the number of elements stored in the array.
Definition tpl_array.H:351
constexpr bool is_empty() const noexcept
Checks if the container is empty.
Definition tpl_array.H:348
T & append(const T &data)
Append a copy of data.
Definition tpl_array.H:245
T & get_last() noexcept
Return a modifiable reference to the last element.
Definition tpl_array.H:366
void reserve(size_t cap)
Reserves cap cells in the array.
Definition tpl_array.H:315
static Array< coeff_type > pmultiply(ThreadPool &pool, const Array< uint64_t > &a, const Array< uint64_t > &b, const size_t chunk_size=0)
Exact parallel polynomial multiplication.
Definition ntt.H:2704
static constexpr bool supports_product_size(const size_t required) noexcept
Check whether a target product length is supported.
Definition ntt.H:2656
static Array< coeff_type > multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Exact sequential polynomial multiplication.
Definition ntt.H:2673
Number Theoretic Transform over Z / MOD Z.
Definition ntt.H:115
static Array< uint64_t > negacyclic_multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Negacyclic convolution modulo x^N + 1.
Definition ntt.H:1777
static Array< uint64_t > multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Multiply two polynomials modulo MOD.
Definition ntt.H:1729
A reusable thread pool for efficient parallel task execution.
std::mt19937_64 rng_
Definition ntt_test.cc:671
Array< uint64_t > random_poly(const size_t n)
Definition ntt_test.cc:674
__gmp_expr< typename __gmp_resolve_expr< T, V >::value_type, __gmp_binary_expr< __gmp_expr< T, U >, __gmp_expr< V, W >, __gmp_remainder_function > > remainder(const __gmp_expr< T, U > &expr1, const __gmp_expr< V, W > &expr2)
Definition gmpfrxx.h:4115
__gmp_expr< T, __gmp_binary_expr< __gmp_expr< T, U >, unsigned long int, __gmp_root_function > > root(const __gmp_expr< T, U > &expr, unsigned long int l)
Definition gmpfrxx.h:4060
Safe modular arithmetic, extended Euclidean algorithm, and Chinese Remainder Theorem.
Main namespace for Aleph-w library functions.
Definition ah-arena.H:89
and
Check uniqueness with explicit hash + equality functors.
uint64_t mod_inv(const uint64_t a, const uint64_t m)
Modular Inverse.
Divide_Conquer_DP_Result< Cost > divide_and_conquer_partition_dp(const size_t groups, const size_t n, Transition_Cost_Fn transition_cost, const Cost inf=dp_optimization_detail::default_inf< Cost >())
Optimize partition DP using divide-and-conquer optimization.
static void prefix(Node *root, DynList< Node * > &acc)
uint64_t mod_exp(uint64_t base, uint64_t exp, const uint64_t m)
Modular exponentiation.
T product(const Container &container, const T &init=T{1})
Compute product of all elements.
uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t m)
Safe 64-bit modular multiplication.
T sum(const Container &container, const T &init=T{})
Compute sum of all elements.
double mod(double a, double b)
Industrial-grade Number Theoretic Transform core for modular polynomial multiplication.
TEST_F(NTTIndustrialTest, ReportsSupportedSizesAndRoots)
Definition ntt_test.cc:683
FooMap m(5, fst_unit_pair_hash, snd_unit_pair_hash)
static int * k
A modern, efficient thread pool for parallel task execution.
ofstream output
Definition writeHeap.C:215