36# include <gtest/gtest.h>
54 const char *
const value)
59 const int rc =
setenv(name, value, 1);
80 bool had_old_ =
false;
81 std::string old_value_;
84 ScopedEnvVar(
const char *
const name,
85 const char *
const value)
88 if (
const char *current = std::getenv(name); current !=
nullptr)
106 template <
typename NTTType>
111 for (
size_t i = 0; i <
input.size(); ++i)
116 template <
typename NTTType>
125 for (
size_t i = 0; i <
output.size(); ++i)
128 for (
size_t i = 0; i < a.
size(); ++i)
129 for (
size_t j = 0; j < b.
size(); ++j)
132 mod_mul(a[i] % NTTType::mod, b[j] % NTTType::mod, NTTType::mod);
140 template <
typename NTTType>
145 if (
input.is_empty())
152 mod_inv(NTTType::primitive_root_of_unity(n), mod) :
153 NTTType::primitive_root_of_unity(n);
157 for (
size_t k = 0;
k < n; ++
k)
160 for (
size_t j = 0; j < n; ++j)
184 const auto digit =
static_cast<unsigned>(value % 10);
201 for (
size_t i = 0; i <
output.size(); ++i)
204 for (
size_t i = 0; i < a.
size(); ++i)
205 for (
size_t j = 0; j < b.
size(); ++j)
211 template <
typename NTTType>
217 throw std::invalid_argument(
"naive_negacyclic: mismatched sizes");
220 for (
size_t i = 0; i <
output.size(); ++i)
223 for (
size_t i = 0; i < a.
size(); ++i)
224 for (
size_t j = 0; j < b.
size(); ++j)
227 mod_mul(a[i] % NTTType::mod, b[j] % NTTType::mod, NTTType::mod);
228 const size_t pos = i + j;
229 const size_t slot = pos < a.
size() ? pos : pos - a.
size();
242 template <u
int64_t Base>
248 for (
size_t i = 0; i <
input.size(); ++i)
252 static_cast<void>(
output.remove_last());
271 for (
size_t i = value.size(); i > 0; --i)
273 const char c = value[i - 1];
274 if (c <
'0' or c >
'9')
275 throw std::invalid_argument(
"decimal_digits_from_string: non-digit");
282 template <u
int64_t Base>
295 for (
size_t i = 0; i < coeffs.
size(); ++i)
298 for (
size_t i = 0; i <
lhs.size(); ++i)
299 for (
size_t j = 0; j <
rhs.size(); ++j)
305 for (
size_t i = 0; i < coeffs.
size(); ++i)
329 for (
size_t i = 1; i < n; ++i)
332 for (
size_t i = n + 1; i <
output.size(); ++i)
337 template <
typename NTTType>
345 template <
typename NTTType>
351 for (
size_t i = 0; i <
input.size(); ++i)
357 template <
typename NTTType>
363 for (
size_t i = 0; i < n; ++i)
368 template <
typename NTTType>
372 if (coeffs.
size() <= 1)
376 for (
size_t i = 1; i < coeffs.
size(); ++i)
378 static_cast<uint64_t>(i) % NTTType::mod,
383 template <
typename NTTType>
389 for (
size_t i = 0; i < n; ++i)
392 for (
size_t i = 0; i < coeffs.
size()
and i + 1 < n; ++i)
400 template <
typename NTTType>
409 template <
typename NTTType>
419 for (
size_t i = 0; i < n; ++i)
424 for (
size_t i = 1; i < n; ++i)
427 for (
size_t j = 1; j <= i; ++j)
430 mod_mul(f[j], g[i - j], NTTType::mod);
442 template <
typename NTTType>
451 throw std::invalid_argument(
"naive_poly_divmod: zero divisor");
458 for (
size_t i = 0; i <
quotient.size(); ++i)
473 const size_t pos = shift + i;
489 template <
typename NTTType>
504 template <
typename NTTType>
514 for (
size_t i = 0; i < n; ++i)
518 for (
size_t m = 1;
m < n; ++
m)
521 for (
size_t i = 1; i <=
m; ++i)
540 template <
typename NTTType>
552 for (
size_t i = 0; i < n; ++i)
559 for (
size_t i = 0; i < n; ++i)
570 template <
typename NTTType>
576 for (
size_t i = 0; i < points.
size(); ++i)
579 const uint64_t x = points[i] % NTTType::mod;
580 for (
size_t j = coeffs.
size(); j > 0; --j)
582 value =
mod_mul(value, x, NTTType::mod);
585 + coeffs[j - 1] % NTTType::mod) % NTTType::mod);
592 template <
typename NTTType>
599 for (
size_t i = 0; i < points.
size(); ++i)
605 for (
size_t j = 0; j < points.
size(); ++j)
618 mod_mul(values[i] % NTTType::mod,
622 for (
size_t k = 0;
k <
basis.size(); ++
k)
624 scale, NTTType::mod);
644 const char *
const ctx)
647 for (
size_t i = 0; i <
lhs.size(); ++i)
654 const char *
const ctx)
657 for (
size_t i = 0; i <
lhs.size(); ++i)
659 << ctx <<
" index=" << i
671 std::mt19937_64
rng_{123456789ULL};
677 for (
size_t i = 0; i < n; ++i)
685 EXPECT_EQ(DefaultNTT::max_transform_size(), 1ULL << 23);
691 EXPECT_TRUE(DefaultNTT::supports_size(1ULL << 23));
695 EXPECT_FALSE(DefaultNTT::supports_size((1ULL << 23) + 2));
697 const uint64_t root8 = DefaultNTT::primitive_root_of_unity(8);
701 const uint64_t root7 = DefaultNTT::primitive_root_of_unity(7);
704 EXPECT_EQ(DefaultNTT::primitive_root_of_unity(1), 1ULL);
709 ScopedEnvVar
auto_mode(
"ALEPH_NTT_SIMD",
"auto");
711 const std::string
backend = DefaultNTT::simd_backend_name();
714 and DefaultNTT::neon_dispatch_available());
721 for (
size_t i = 0; i < sizes.
size(); ++i)
722 for (
size_t sample = 0; sample < 6; ++sample)
727 DefaultNTT::transform(values,
false);
728 DefaultNTT::transform(values,
true);
743 ScopedEnvVar
scalar_mode(
"ALEPH_NTT_SIMD",
"scalar");
744 EXPECT_EQ(std::string(DefaultNTT::simd_backend_name()),
"scalar");
746 scalar_product = DefaultNTT::multiply(
lhs,
rhs);
750 ScopedEnvVar
avx2_mode(
"ALEPH_NTT_SIMD",
"avx2");
755 if (DefaultNTT::avx2_dispatch_available())
756 EXPECT_EQ(std::string(DefaultNTT::simd_backend_name()),
"avx2");
758 EXPECT_EQ(std::string(DefaultNTT::simd_backend_name()),
"scalar");
765 ScopedEnvVar
neon_mode(
"ALEPH_NTT_SIMD",
"neon");
770 if (DefaultNTT::neon_dispatch_available())
771 EXPECT_EQ(std::string(DefaultNTT::simd_backend_name()),
"neon");
773 EXPECT_EQ(std::string(DefaultNTT::simd_backend_name()),
"scalar");
787 DefaultNTT::multipoint_eval(poly, points);
811 for (
size_t i = 0; i < n; ++i)
838 DefaultNTT::poly_power(
unit, 3, n);
860 "poly_sqrt shifted root");
871 DefaultNTT::multipoint_eval(coeffs, points);
877 DefaultNTT::interpolate(points, values);
901 "base-1000 bigint product");
909 "base-2^15 bigint product");
922 DefaultNTT::pbigint_multiply<10>(pool,
lhs,
rhs);
936 const size_t digits = 4096;
938 for (
size_t i = 0; i <
digits; ++i)
950 for (
size_t i = 0; i < sizes.
size(); ++i)
964 "negacyclic demo product");
969 if (
const char *
env = std::getenv(
"ENABLE_PERF_TESTS");
970 env ==
nullptr or std::string(
env) !=
"1")
972 GTEST_SKIP() <<
"Set ENABLE_PERF_TESTS=1 to run million-digit bigint checks";
975 const size_t digits = 1000000;
977 for (
size_t i = 0; i <
digits; ++i)
992 for (
size_t i = 0; i < sizes.
size(); ++i)
993 for (
size_t sample = 0; sample < 4; ++sample)
998 AlternateNTT::transform(values,
false);
999 AlternateNTT::transform(values,
true);
1009 for (
size_t i = 0; i < sizes.
size(); ++i)
1010 for (
size_t sample = 0; sample < 3; ++sample)
1022 "Bluestein forward transform");
1026 "Bluestein inverse round trip");
1028 "Naive inverse round trip");
1037 for (
size_t i = 0; i <
lhs_sizes.size(); ++i)
1038 for (
size_t j = 0; j <
rhs_sizes.size(); ++j)
1051 typename DefaultNTT::Plan
plan(64);
1074 typename DefaultNTT::Plan
plan(16);
1078 for (
size_t i = 0; i < 4; ++i)
1086 for (
size_t i = 0; i <
expected.size(); ++i)
1087 DefaultNTT::transform(
expected(i),
false);
1090 for (
size_t i = 0; i <
batch.size(); ++i)
1094 for (
size_t i = 0; i <
batch.size(); ++i)
1096 "Batch inverse transform");
1101 typename DefaultNTT::Plan
plan(7);
1111 "Bluestein plan inverse round trip");
1115 for (
size_t i = 0; i < 3; ++i)
1120 DefaultNTT::transform(transformed,
false);
1121 reference.
append(transformed);
1125 for (
size_t i = 0; i <
batch.size(); ++i)
1137 DefaultNTT::transform(sequential,
false);
1138 DefaultNTT::ptransform(pool, parallel,
false);
1141 DefaultNTT::transform(sequential,
true);
1142 DefaultNTT::ptransform(pool, parallel,
true);
1155 typename DefaultNTT::Plan
plan(256);
1158 for (
size_t i = 0; i < 6; ++i)
1164 plan.transform_batch(sequential,
false);
1165 plan.ptransform_batch(pool, parallel,
false);
1167 for (
size_t i = 0; i < sequential.
size(); ++i)
1177 DefaultNTT::multiply_inplace(
lhs,
rhs);
1195 DefaultNTT::mod + 77ULL,
1201 DefaultNTT::mod + 11ULL
1212 (1ULL << 63) - 25ULL,
1214 (1ULL << 63) - 17ULL
1229 for (
size_t i = 0; i <
lhs.size(); ++i)
1230 lhs(i) = DefaultNTT::mod + 1000ULL + (rng_() % 1000000ULL);
1231 for (
size_t i = 0; i <
rhs.size(); ++i)
1232 rhs(i) = DefaultNTT::mod + 500ULL + (rng_() % 2000000ULL);
1243 for (
size_t i = 0; i < sequential.size(); ++i)
1246 <<
"Residue mod 998244353 at index " << i;
1248 <<
"Residue mod 469762049 at index " << i;
1250 <<
"Residue mod 1004535809 at index " << i;
1259 std::invalid_argument);
1269 EXPECT_THROW(DefaultNTT::transform(empty,
false), std::invalid_argument);
1270 EXPECT_THROW(DefaultNTT::transform(
bad,
false), std::invalid_argument);
1272 std::invalid_argument);
1273 EXPECT_THROW(
static_cast<void>(DefaultNTT::primitive_root_of_unity(3)),
1274 std::invalid_argument);
1275 EXPECT_THROW(
static_cast<void>(DefaultNTT::primitive_root_of_unity(10)),
1276 std::invalid_argument);
1282 std::invalid_argument);
1287 std::invalid_argument);
1289 std::invalid_argument);
1293 std::invalid_argument);
1296 std::invalid_argument);
1299 std::invalid_argument);
1301 std::invalid_argument);
1303 std::invalid_argument);
1305 std::invalid_argument);
1308 std::invalid_argument);
1311 std::invalid_argument);
1314 std::invalid_argument);
1317 std::invalid_argument);
1322 for (
size_t i = 0; i < 16; ++i)
1329 std::invalid_argument);
1332 std::invalid_argument);
1333 EXPECT_THROW(
static_cast<void>(DefaultNTT::bigint_multiply<(1ULL << 48)>(
1339 (1ULL << 48) - 1}))),
1340 std::invalid_argument);
1346 std::invalid_argument);
1351 if (std::getenv(
"ENABLE_PERF_TESTS") ==
nullptr)
1352 GTEST_SKIP() <<
"Set ENABLE_PERF_TESTS=1 to run NTT performance checks";
1355 const size_t n = 1 << 20;
1359 for (
size_t i = 0; i < n; ++i)
1365 const auto start = std::chrono::steady_clock::now();
1367 const auto end = std::chrono::steady_clock::now();
1371 const auto elapsed =
1372 std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
1375 <<
"NTT performance regression detected";
Simple dynamic array with automatic resizing and functional operations.
static Array create(size_t n)
Create an array with n logical elements.
constexpr size_t size() const noexcept
Return the number of elements stored in the stack.
constexpr bool is_empty() const noexcept
Checks if the container is empty.
T & append(const T &data)
Append a copy of data
T & get_last() noexcept
return a modifiable reference to the last element.
void reserve(size_t cap)
Reserves cap cells into the array.
static Array< coeff_type > pmultiply(ThreadPool &pool, const Array< uint64_t > &a, const Array< uint64_t > &b, const size_t chunk_size=0)
Exact parallel polynomial multiplication.
static constexpr bool supports_product_size(const size_t required) noexcept
Check whether a target product length is supported.
static Array< coeff_type > multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Exact sequential polynomial multiplication.
Number Theoretic Transform over Z / MOD Z.
static Array< uint64_t > negacyclic_multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Negacyclic convolution modulo x^N + 1.
static Array< uint64_t > multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Multiply two polynomials modulo MOD.
A reusable thread pool for efficient parallel task execution.
Array< uint64_t > random_poly(const size_t n)
__gmp_expr< typename __gmp_resolve_expr< T, V >::value_type, __gmp_binary_expr< __gmp_expr< T, U >, __gmp_expr< V, W >, __gmp_remainder_function > > remainder(const __gmp_expr< T, U > &expr1, const __gmp_expr< V, W > &expr2)
__gmp_expr< T, __gmp_binary_expr< __gmp_expr< T, U >, unsigned long int, __gmp_root_function > > root(const __gmp_expr< T, U > &expr, unsigned long int l)
Safe modular arithmetic, extended Euclidean algorithm, and Chinese Remainder Theorem.
Main namespace for Aleph-w library functions.
and
Check uniqueness with explicit hash + equality functors.
uint64_t mod_inv(const uint64_t a, const uint64_t m)
Modular Inverse.
Divide_Conquer_DP_Result< Cost > divide_and_conquer_partition_dp(const size_t groups, const size_t n, Transition_Cost_Fn transition_cost, const Cost inf=dp_optimization_detail::default_inf< Cost >())
Optimize partition DP using divide-and-conquer optimization.
static void prefix(Node *root, DynList< Node * > &acc)
uint64_t mod_exp(uint64_t base, uint64_t exp, const uint64_t m)
Modular exponentiation.
T product(const Container &container, const T &init=T{1})
Compute product of all elements.
uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t m)
Safe 64-bit modular multiplication.
T sum(const Container &container, const T &init=T{})
Compute sum of all elements.
double mod(double a, double b)
Industrial-grade Number Theoretic Transform core for modular polynomial multiplication.
TEST_F(NTTIndustrialTest, ReportsSupportedSizesAndRoots)
FooMap m(5, fst_unit_pair_hash, snd_unit_pair_hash)
A modern, efficient thread pool for parallel task execution.