Aleph-w 3.0
A C++ Library for Data Structures and Algorithms
Loading...
Searching...
No Matches
ntt_example.cc
Go to the documentation of this file.
1/*
2 Aleph_w
3
4 Data structures & Algorithms
5 version 2.0.0b
6 https://github.com/lrleon/Aleph-w
7
8 This file is part of Aleph-w library
9
10 Copyright (c) 2002-2026 Leandro Rabindranath Leon
11
12 Permission is hereby granted, free of charge, to any person obtaining a copy
13 of this software and associated documentation files (the "Software"), to deal
14 in the Software without restriction, including without limitation the rights
15 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 copies of the Software, and to permit persons to whom the Software is
17 furnished to do so, subject to the following conditions:
18
19 The above copyright notice and this permission notice shall be included in all
20 copies or substantial portions of the Software.
21
22 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 SOFTWARE.
29*/
30
35# include <algorithm>
36# include <iostream>
37# include <string>
38
39# include <ntt.H>
40# include <print_rule.H>
41# include <thread_pool.H>
42
43using namespace Aleph;
44using namespace std;
45
46namespace
47{
48 string
50 {
51 if (value == 0)
52 return "0";
53
54 string digits;
55 while (value > 0)
56 {
57 digits.push_back(static_cast<char>('0' + static_cast<unsigned>(value % 10)));
58 value /= 10;
59 }
60
61 reverse(digits.begin(), digits.end());
62 return digits;
63 }
64
66 decimal_digits_from_string(const string & value)
67 {
69 digits.reserve(value.size());
70 for (size_t i = value.size(); i > 0; --i)
71 digits.append(static_cast<uint64_t>(value[i - 1] - '0'));
72
73 while (digits.size() > 1 and digits.get_last() == 0)
74 static_cast<void>(digits.remove_last());
75 return digits.is_empty() ? Array<uint64_t>({0}) : digits;
76 }
77
78 string
80 {
81 if (digits.is_empty())
82 return "0";
83
84 size_t used = digits.size();
85 while (used > 1 and digits[used - 1] == 0)
86 --used;
87
88 string value;
89 value.reserve(used);
90 for (size_t i = used; i > 0; --i)
91 value.push_back(static_cast<char>('0' + digits[i - 1]));
92 return value;
93 }
94
95 bool
97 const Array<uint64_t> & rhs)
98 {
99 if (lhs.size() != rhs.size())
100 return false;
101
102 for (size_t i = 0; i < lhs.size(); ++i)
103 if (lhs[i] != rhs[i])
104 return false;
105
106 return true;
107 }
108
// Prints a polynomial to cout on one line as `<name>(x) = c0 + c1x^1 + c2x^2 ...`.
//   name: label written before "(x) = ".
//   poly: coefficients, written in index order (poly[i] is paired with x^i).
// Every coefficient is printed, including zeros; an empty poly prints only
// the label part followed by a newline.
109 void
110 print_coeffs(const char * const name,
111 const Array<uint64_t> & poly)
112 {
113 cout << name << "(x) = ";
114 for (size_t i = 0; i < poly.size(); ++i)
115 {
// Separator goes between terms only, never before the first one.
116 if (i > 0)
117 cout << " + ";
118 cout << poly[i];
// The index-0 term is printed bare; later terms get an explicit x^i suffix.
119 if (i > 0)
120 cout << "x^" << i;
121 }
122 cout << '\n';
123 }
124
125 void
127 {
128 for (size_t i = 0; i < batch.size(); ++i)
129 {
130 cout << " item[" << i << "] = [";
131 for (size_t j = 0; j < batch[i].size(); ++j)
132 cout << batch[i][j] << (j + 1 == batch[i].size() ? "" : ", ");
133 cout << "]\n";
134 }
135 }
136}
137
138int
140{
141 using DefaultNTT = NTT<>;
142
143 cout << "\n=== Number Theoretic Transform (NTT) ===\n\n";
144 cout << "Active SIMD backend: " << DefaultNTT::simd_backend_name() << '\n';
145 cout << " AVX2 available: "
146 << (DefaultNTT::avx2_dispatch_available() ? "yes" : "no") << '\n';
147 cout << " NEON available: "
148 << (DefaultNTT::neon_dispatch_available() ? "yes" : "no") << "\n\n";
149
150 cout << "[1] Static polynomial multiplication\n";
151 print_rule();
152 Array<uint64_t> a = {1, 2, 3, 4};
153 Array<uint64_t> b = {5, 6, 7};
154 const Array<uint64_t> product = DefaultNTT::multiply(a, b);
155 print_coeffs("A", a);
156 print_coeffs("B", b);
157 print_coeffs("A * B mod 998244353", product);
158 cout << '\n';
159
160 cout << "[2] Exact CRT multiplication beyond a single modulus\n";
161 print_rule();
162 const Array<uint64_t> exact_lhs = {
163 1000000000000ULL,
164 DefaultNTT::mod + 77ULL,
165 1234567890123ULL
166 };
167 const Array<uint64_t> exact_rhs = {
168 777777777777ULL,
169 42ULL,
170 DefaultNTT::mod + 11ULL
171 };
173 cout << "Exact coefficients with CRT:\n [";
174 for (size_t i = 0; i < exact_product.size(); ++i)
175 cout << to_string_u128(exact_product[i])
176 << (i + 1 == exact_product.size() ? "" : ", ");
177 cout << "]\n\n";
178
179 cout << "[3] Reusable arbitrary-size plan (Bluestein)\n";
180 print_rule();
181 DefaultNTT::Plan plan(7);
182 Array<uint64_t> signal = {3, 1, 4, 1, 5, 9, 2};
183 Array<uint64_t> spectrum = signal;
184 plan.transform(spectrum, false);
185 cout << "Forward transform:\n [";
186 for (size_t i = 0; i < spectrum.size(); ++i)
187 cout << spectrum[i] << (i + 1 == spectrum.size() ? "" : ", ");
188 cout << "]\n";
189 plan.transform(spectrum, true);
190 cout << "Inverse transform recovers:\n [";
191 for (size_t i = 0; i < spectrum.size(); ++i)
192 cout << spectrum[i] << (i + 1 == spectrum.size() ? "" : ", ");
193 cout << "]\n\n";
194
195 cout << "[4] Batch transforms\n";
196 print_rule();
198 Array<uint64_t>({1, 0, 0, 0, 0, 0, 0}),
199 Array<uint64_t>({0, 1, 0, 0, 0, 0, 0}),
200 Array<uint64_t>({1, 1, 1, 1, 0, 0, 0})
201 };
202 plan.transform_batch(batch, false);
204 cout << '\n';
205
206 cout << "[5] Parallel multiplication with ThreadPool\n";
207 print_rule();
208 ThreadPool pool(4);
210 DefaultNTT::pmultiply(pool, a, b);
211 print_coeffs("Parallel A * B", parallel_product);
212
215 cout << "Parallel exact CRT product first coefficient: "
216 << to_string_u128(parallel_exact[0]) << '\n';
217
218 cout << "\n[6] Polynomial algebra modulo 998244353\n";
219 print_rule();
220 const Array<uint64_t> unit_series = {1, 4, 7, 2};
221 const Array<uint64_t> log_series = DefaultNTT::poly_log(unit_series, 6);
222 const Array<uint64_t> exp_series = DefaultNTT::poly_exp(log_series, 6);
223 const Array<uint64_t> power_series = DefaultNTT::poly_power(unit_series, 3, 6);
224 cout << "log(1 + 4x + 7x^2 + 2x^3) mod x^6:\n [";
225 for (size_t i = 0; i < log_series.size(); ++i)
226 cout << log_series[i] << (i + 1 == log_series.size() ? "" : ", ");
227 cout << "]\n";
228 cout << "exp(log(series)) recovers:\n [";
229 for (size_t i = 0; i < exp_series.size(); ++i)
230 cout << exp_series[i] << (i + 1 == exp_series.size() ? "" : ", ");
231 cout << "]\n";
232 cout << "series^3 mod x^6:\n [";
233 for (size_t i = 0; i < power_series.size(); ++i)
234 cout << power_series[i] << (i + 1 == power_series.size() ? "" : ", ");
235 cout << "]\n";
236
237 const Array<uint64_t> interp_points = {0, 1, 2, 5};
239 DefaultNTT::multipoint_eval(Array<uint64_t>({9, 7, 5, 3}), interp_points);
241 DefaultNTT::interpolate(interp_points, interp_values);
242 cout << "Interpolation from sampled values:\n [";
243 for (size_t i = 0; i < recovered_poly.size(); ++i)
244 cout << recovered_poly[i]
245 << (i + 1 == recovered_poly.size() ? "" : ", ");
246 cout << "]\n";
247
248 cout << "\n[7] Big integer multiplication in base 10\n";
249 print_rule();
250 const string decimal_a = "12345678901234567890";
251 const string decimal_b = "98765432109876543210";
255 DefaultNTT::bigint_multiply<10>(bigint_a, bigint_b);
257 DefaultNTT::pbigint_multiply<10>(pool, bigint_a, bigint_b);
258 cout << decimal_a << " * " << decimal_b << " =\n "
260 cout << "Parallel product matches sequential: "
261 << (equal_digits(bigint_parallel, bigint_product) ? "yes" : "no") << '\n';
262
263 cout << "\n[8] Negacyclic multiplication modulo (x^4 + 1)\n";
264 print_rule();
265 const Array<uint64_t> neg_lhs = {1, 1, 1, 1};
266 const Array<uint64_t> neg_rhs = {1, 1, 0, 0};
268 DefaultNTT::negacyclic_multiply(neg_lhs, neg_rhs);
269 print_coeffs("A", neg_lhs);
270 print_coeffs("B", neg_rhs);
271 print_coeffs("A * B mod (x^4 + 1)", neg_product);
272
273 cout << "\nDone.\n";
274 return 0;
275}
Simple dynamic array with automatic resizing and functional operations.
Definition tpl_array.H:139
constexpr size_t size() const noexcept
Return the number of elements stored in the array.
Definition tpl_array.H:351
void reserve(size_t cap)
Reserves cap cells into the array.
Definition tpl_array.H:315
static Array< coeff_type > pmultiply(ThreadPool &pool, const Array< uint64_t > &a, const Array< uint64_t > &b, const size_t chunk_size=0)
Exact parallel polynomial multiplication.
Definition ntt.H:2704
static Array< coeff_type > multiply(const Array< uint64_t > &a, const Array< uint64_t > &b)
Exact sequential polynomial multiplication.
Definition ntt.H:2673
Number Theoretic Transform over Z / MOD Z.
Definition ntt.H:115
A reusable thread pool for efficient parallel task execution.
Main namespace for Aleph-w library functions.
Definition ah-arena.H:89
and
Check uniqueness with explicit hash + equality functors.
void reverse(Itor beg, Itor end)
Reverse elements in a range.
Definition ahAlgo.H:1094
size_t size(Node *root) noexcept
Divide_Conquer_DP_Result< Cost > divide_and_conquer_partition_dp(const size_t groups, const size_t n, Transition_Cost_Fn transition_cost, const Cost inf=dp_optimization_detail::default_inf< Cost >())
Optimize partition DP using divide-and-conquer optimization.
T product(const Container &container, const T &init=T{1})
Compute product of all elements.
void print_rule()
Prints a horizontal rule for example output separation.
Definition print_rule.H:39
STL namespace.
Industrial-grade Number Theoretic Transform core for modular polynomial multiplication.
int main()
A modern, efficient thread pool for parallel task execution.