99#include < optional>
1010#include < string>
1111#include < type_traits>
12- #include < variant>
1312#include < vector>
1413
1514#include " lib/Kernel/AbstractValue.h"
@@ -23,6 +22,17 @@ namespace mlir {
2322namespace heir {
2423namespace kernel {
2524
25+ // A function that generalizes the reduction operation in all kernels in this
26+ // file. E.g., whether to use `add` or `mul`
27+ template <typename T>
28+ using DagReducer = std::function<std::shared_ptr<ArithmeticDagNode<T>>(
29+ std::shared_ptr<ArithmeticDagNode<T>>,
30+ std::shared_ptr<ArithmeticDagNode<T>>)>;
31+
32+ template <typename T>
33+ using DagExtractor = std::function<std::shared_ptr<ArithmeticDagNode<T>>(
34+ std::shared_ptr<ArithmeticDagNode<T>>, int64_t )>;
35+
2636// Returns an arithmetic DAG that implements a matvec kernel. Ensure this is
2737// only generated for T a subclass of AbstractValue.
2838template <typename T>
@@ -49,41 +59,69 @@ implementMatvec(KernelName kernelName, const T& matrix, const T& vector) {
4959 return accumulatedSum;
5060}
5161
52- // Returns an arithmetic DAG that implements a rotate and reduce op. Ensure
53- // this is only generated for T a subclass of AbstractValue.
62+ // Returns an arithmetic DAG that implements a logarithmic rotate-and-reduce
63+ // accumulation of an input ciphertext.
64+ //
65+ // This is a special case of `tensor_ext.rotate_and_reduce`
5466template <typename T>
5567std::enable_if_t <std::is_base_of<AbstractValue, T>::value,
5668 std::shared_ptr<ArithmeticDagNode<T>>>
57- implementRotateAndReduce (const T& vector, std::optional<T> plaintexts,
58- int64_t period, int64_t steps,
59- const std::string& reduceOp = " arith.addi" ) {
69+ implementRotateAndReduceAccumulation (const T& vector, int64_t period,
70+ int64_t steps, DagReducer<T> reduceFunc) {
6071 using NodeTy = ArithmeticDagNode<T>;
6172 auto vectorDag = NodeTy::leaf (vector);
6273
63- auto performReduction = [&](std::shared_ptr<NodeTy> left,
64- std::shared_ptr<NodeTy> right) {
65- if (reduceOp == " arith.addi" || reduceOp == " arith.addf" ) {
66- return NodeTy::add (left, right);
67- }
68-
69- if (reduceOp == " arith.muli" || reduceOp == " arith.mulf" ) {
70- return NodeTy::mul (left, right);
71- }
72-
73- // Default to add for unknown operations
74- return NodeTy::add (left, right);
75- };
76-
77- if (!plaintexts.has_value ()) {
78- for (int64_t shiftSize = steps / 2 ; shiftSize > 0 ; shiftSize /= 2 ) {
79- auto rotated = NodeTy::leftRotate (vectorDag, shiftSize * period);
80- auto reduced = performReduction (vectorDag, rotated);
81- vectorDag = reduced;
82- }
83- return vectorDag;
74+ for (int64_t shiftSize = steps / 2 ; shiftSize > 0 ; shiftSize /= 2 ) {
75+ auto rotated = NodeTy::leftRotate (vectorDag, shiftSize * period);
76+ auto reduced = reduceFunc (vectorDag, rotated);
77+ vectorDag = reduced;
8478 }
79+ return vectorDag;
80+ }
8581
86- auto plaintextsDag = NodeTy::leaf (*plaintexts);
82+ // A function that generalizes the choice of rotation for the "baby stepped
83+ // operand" of a baby-step giant-step algorithm. This is required because
84+ // the rotation used in Halevi-Shoup matvec differs from that of bicyclic
85+ // matmul.
86+ using DerivedRotationIndexFn = std::function<int64_t (
87+ // giant step size
88+ int64_t ,
89+ // current giant step index
90+ int64_t ,
91+ // current baby step index
92+ int64_t ,
93+ // period
94+ int64_t )>;
95+
96+ inline int64_t defaultDerivedRotationIndexFn (int64_t giantStepSize,
97+ int64_t giantStepIndex,
98+ int64_t babyStepIndex,
99+ int64_t period) {
100+ return -giantStepSize * giantStepIndex * period;
101+ }
102+
103+ // Returns an arithmetic DAG that implements a baby-step-giant-step
104+ // rotate-and-reduce accumulation between an input ciphertext
105+ // (giantSteppedOperand) and an abstraction over the other argument
106+ // (babySteppedOperand). In particular, the babySteppedOperand may be a list of
107+ // plaintexts like in Halevi-Shoup matvec, or a single ciphertext like in
108+ // bicyclic matmul, and this abstracts over both by taking in an extraction
109+ // callback.
110+ //
111+ // This is a special case of `tensor_ext.rotate_and_reduce`, but with the added
112+ // abstractions it also supports situations not currently expressible by
113+ // `tensor_ext.rotate_and_reduce`.
114+ template <typename T>
115+ std::enable_if_t <std::is_base_of<AbstractValue, T>::value,
116+ std::shared_ptr<ArithmeticDagNode<T>>>
117+ implementBabyStepGiantStep (
118+ const T& giantSteppedOperand, const T& babySteppedOperand, int64_t period,
119+ int64_t steps, DagExtractor<T> extractFunc,
120+ const DerivedRotationIndexFn& derivedRotationIndexFn =
121+ defaultDerivedRotationIndexFn) {
122+ using NodeTy = ArithmeticDagNode<T>;
123+ auto giantSteppedDag = NodeTy::leaf (giantSteppedOperand);
124+ auto babySteppedDag = NodeTy::leaf (babySteppedOperand);
87125
88126 // Use a value of sqrt(n) as the baby step / giant step size.
89127 int64_t numBabySteps = static_cast <int64_t >(std::floor (std::sqrt (steps)));
@@ -112,35 +150,124 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
112150
113151 // Compute sqrt(n) ciphertext rotations of the input as baby-steps.
114152 SmallVector<std::shared_ptr<NodeTy>> babyStepVals;
115- babyStepVals.push_back (vectorDag ); // rot by zero
153+ babyStepVals.push_back (giantSteppedDag ); // rot by zero
116154 for (int64_t i = 1 ; i < numBabySteps; ++i) {
117- babyStepVals.push_back (NodeTy::leftRotate (vectorDag , period * i));
155+ babyStepVals.push_back (NodeTy::leftRotate (giantSteppedDag , period * i));
118156 }
119157
120158 // Compute the inner baby step sums.
121159 std::shared_ptr<NodeTy> result = nullptr ;
122160 for (int64_t j = 0 ; j < numGiantSteps; ++j) {
123161 std::shared_ptr<NodeTy> innerSum = nullptr ;
124- // The rotation used for the plaintext
125- int64_t plaintextRotationAmount = -giantStepSize * j * period;
126162 for (int64_t i = 0 ; i < numBabySteps; ++i) {
163+ int64_t innerRotAmount =
164+ derivedRotationIndexFn (giantStepSize, j, i, period);
127165 size_t extractionIndex = i + j * giantStepSize;
128- auto plaintext = NodeTy::extract (plaintextsDag, extractionIndex);
129- auto rotatedPlaintext =
130- NodeTy::leftRotate (plaintext, plaintextRotationAmount);
166+ auto plaintext = extractFunc (babySteppedDag, extractionIndex);
167+ auto rotatedPlaintext = NodeTy::leftRotate (plaintext, innerRotAmount);
131168 auto multiplied = NodeTy::mul (rotatedPlaintext, babyStepVals[i]);
132- innerSum = innerSum == nullptr ? multiplied
133- : performReduction (innerSum, multiplied);
169+ innerSum =
170+ innerSum == nullptr ? multiplied : NodeTy::add (innerSum, multiplied);
134171 }
135172
136173 auto rotatedSum = NodeTy::leftRotate (innerSum, period * j * giantStepSize);
137- result =
138- result == nullptr ? rotatedSum : performReduction (result, rotatedSum);
174+ result = result == nullptr ? rotatedSum : NodeTy::add (result, rotatedSum);
139175 }
140176
141177 return result;
142178}
143179
180+ // Returns an arithmetic DAG that implements a tensor_ext.rotate_and_reduce op.
181+ //
182+ // See TensorExtOps.td docs for RotateAndReduceOp for more details.
183+ //
184+ // The `vector` argument is a ciphertext value that will be rotated O(sqrt(n))
185+ // times when the `plaintexts` argument is set (Baby Step Giant Step), or
186+ // O(log(n)) times when the `plaintexts` argument is not set (log-style
187+ // rotate-and-reduce accumulation).
188+ //
189+ // The `plaintexts` argument, when present, represents a vector of pre-packed
190+ // plaintexts that will be rotated and multiplied with the rotated `vector`
191+ // argument in BSGS style.
192+ //
193+ // Note that using this kernel results in places in the pipeline where a
194+ // plaintext type is rotated, but most FHE implementations don't have a
195+ // plaintext rotation operation (it would be wasteful) and instead expect the
196+ // "plaintext rotation" to apply to the cleartext. HEIR has places in the
197+ // pipeline that support this by converting a rotate(encode(cleartext)) to
198+ // encode(rotate(cleartext)).
199+ template <typename T>
200+ std::enable_if_t <std::is_base_of<AbstractValue, T>::value,
201+ std::shared_ptr<ArithmeticDagNode<T>>>
202+ implementRotateAndReduce (const T& vector, std::optional<T> plaintexts,
203+ int64_t period, int64_t steps,
204+ const std::string& reduceOp = " arith.addi" ) {
205+ using NodeTy = ArithmeticDagNode<T>;
206+ auto performReduction = [&](std::shared_ptr<NodeTy> left,
207+ std::shared_ptr<NodeTy> right) {
208+ if (reduceOp == " arith.addi" || reduceOp == " arith.addf" ) {
209+ return NodeTy::add (left, right);
210+ }
211+
212+ if (reduceOp == " arith.muli" || reduceOp == " arith.mulf" ) {
213+ return NodeTy::mul (left, right);
214+ }
215+
216+ // Default to add for unknown operations
217+ return NodeTy::add (left, right);
218+ };
219+
220+ if (!plaintexts.has_value ()) {
221+ return implementRotateAndReduceAccumulation<T>(vector, period, steps,
222+ performReduction);
223+ }
224+
225+ assert (reduceOp == " arith.addi" ||
226+ reduceOp == " arith.addf" &&
227+ " Baby-step-giant-step rotate-and-reduce only supports addition "
228+ " as the reduction operation" );
229+
230+ auto extractFunc = [](std::shared_ptr<NodeTy> babySteppedDag,
231+ int64_t extractionIndex) {
232+ return NodeTy::extract (babySteppedDag, extractionIndex);
233+ };
234+
235+ return implementBabyStepGiantStep<T>(vector, plaintexts.value (), period,
236+ steps, extractFunc);
237+ }
238+
239+ // Returns an arithmetic DAG that implements a baby-step-giant-step between
240+ // ciphertexts.
241+ //
242+ // This implements equation 21 in 6.2.2 of LKAA25: "Tricycle: Private
243+ // Transformer Inference with Tricyclic Encodings"
244+ // https://eprint.iacr.org/2025/1200
245+ //
246+ // This differs from the above implementRotateAndReduce in that, instead of a
247+ // set of pre-computed plaintexts, both arguments are individual ciphertexts.
248+ // Normally with one ciphertext, the naive approach uses n - 1 rotations that
249+ // BSGS reduces to c sqrt(n) + O(1) rotations, if both inputs are ciphertexts
250+ // then it converts 2n - 2 total rotations to n + c sqrt(n) + O(1) rotations.
251+ // Essentially, the "n to sqrt(n)" redution applies to the `vector` argument
252+ // only, while the `plaintexts` argument still gets n-1 rotations.
253+ template <typename T>
254+ std::enable_if_t <std::is_base_of<AbstractValue, T>::value,
255+ std::shared_ptr<ArithmeticDagNode<T>>>
256+ implementCiphertextCiphertextBabyStepGiantStep (
257+ const T& giantSteppedOperand, const T& babySteppedOperand, int64_t period,
258+ int64_t steps, DerivedRotationIndexFn derivedRotationIndexFn) {
259+ using NodeTy = ArithmeticDagNode<T>;
260+
261+ // Avoid replicating and re-extracting by simulating the extraction step by
262+ // just returning the single ciphertext.
263+ auto extractFunc = [](std::shared_ptr<NodeTy> babySteppedDag,
264+ int64_t extractionIndex) { return babySteppedDag; };
265+
266+ return implementBabyStepGiantStep<T>(giantSteppedOperand, babySteppedOperand,
267+ period, steps, extractFunc,
268+ derivedRotationIndexFn);
269+ }
270+
144271// Returns an arithmetic DAG that implements the Halevi-Shoup matrix
145272// multiplication algorithm. This implementation uses a rotate-and-reduce
146273// operation, followed by a summation of partial sums if the matrix is not
@@ -187,47 +314,56 @@ implementHaleviShoup(const T& vector, const T& matrix,
187314// zero-padded so that their dimensions are coprime, they are cyclically
188315// repeated to fill all the slots of the ciphertext, and they are packed
189316// according to the bicyclic ordering.
317+ //
318+ // This function produces a kernel using roughly n + 2sqrt(n) - 3 rotations
319+ // (for matrix dimensions all order n), by applying the baby-step-giant-step
320+ // method to reduce the number of rotations of packedA.
321+ //
322+ // This implements the BMM-I algorithm from https://eprint.iacr.org/2024/1762
323+ // with modifications from LKAA25 (https://eprint.iacr.org/2025/1200):
324+ //
325+ // - A simplification of the rotation formula in Sec 5.2.1 (equation 9).
326+ // - A baby-step-giant-step optimization of the summation below, from Sec
327+ // 6.2.2 (equation 21).
328+ //
329+ // It computes
330+ //
331+ // C = sum_{c=0}^{n-1} rot(A, r1(c)) * rot(B, r2(c))
332+ //
333+ // where
334+ //
335+ // r1(c) = cm(m^{-1} mod n) mod mn
336+ // r2(c) = cp(p^{-1} mod n) mod np
337+ //
190338template <typename T>
191339std::enable_if_t <std::is_base_of<AbstractValue, T>::value,
192340 std::shared_ptr<ArithmeticDagNode<T>>>
193341implementBicyclicMatmul (const T& packedA, const T& packedB, int64_t m,
194342 int64_t n, int64_t p) {
195- using NodeTy = ArithmeticDagNode<T>;
196- auto packedADag = NodeTy::leaf (packedA);
197- auto packedBDag = NodeTy::leaf (packedB);
198-
199- // This implements the BMM-I algorithm from https://eprint.iacr.org/2024/1762
200- // with a simplification of the rotation formula in Sec 5.2.1 (equation 9) of
201- // https://eprint.iacr.org/2025/1200
202- //
203- // C = sum_{c=0}^{n-1} rot(A, r1(c)) * rot(B, r2(c))
204- //
205- // where
206- //
207- // r1(c) = cm(m^{-1} mod n) mod mn
208- // r2(c) = cp(p^{-1} mod n) mod np
209- //
210343 APInt mAPInt = APInt (64 , m);
211344 APInt nAPInt = APInt (64 , n);
212345 APInt pAPInt = APInt (64 , p);
213346
214- APInt mInvModN = multiplicativeInverse (mAPInt , nAPInt);
215- APInt pInvModN = multiplicativeInverse (pAPInt, nAPInt);
347+ APInt mInvModN = multiplicativeInverse (mAPInt . urem (nAPInt) , nAPInt);
348+ APInt pInvModN = multiplicativeInverse (pAPInt. urem (nAPInt) , nAPInt);
216349
217- // The part of r1(c), r2(c) that is independent of the loop iterations
218- int64_t r1Const = m * mInvModN .getSExtValue ();
219- int64_t r2Const = p * pInvModN.getSExtValue ();
350+ auto derivedRotationIndexFn = [&](int64_t giantStepSize,
351+ int64_t giantStepIndex,
352+ int64_t babyStepIndex, int64_t period) {
353+ APInt c (64 , giantStepIndex * giantStepSize + babyStepIndex);
354+ APInt mAPInt (64 , m);
220355
221- std::shared_ptr<NodeTy> result = nullptr ;
222- for (int i = 0 ; i < n; ++i) {
223- int64_t shiftA = (i * r1Const) % (m * n);
224- int64_t shiftB = (i * r2Const) % (n * p);
225- auto rotA = NodeTy::leftRotate (packedADag, shiftA);
226- auto rotB = NodeTy::leftRotate (packedBDag, shiftB);
227- auto term = NodeTy::mul (rotA, rotB);
228- result = result == nullptr ? term : NodeTy::add (result, term);
229- }
230- return result;
356+ // RotY(c) = (p * (c * m * p^{-1} mod n)) mod (n * p)
357+ APInt rotyInner = (c * mAPInt * pInvModN.getSExtValue ()).urem (nAPInt);
358+ APInt roty = (rotyInner * pAPInt).urem (nAPInt * pAPInt);
359+
360+ APInt result = roty - APInt (64 , period) * APInt (64 , giantStepSize) *
361+ APInt (64 , giantStepIndex);
362+ return result.getSExtValue ();
363+ };
364+
365+ return implementCiphertextCiphertextBabyStepGiantStep<T>(
366+ packedA, packedB, /* period=*/ m, /* steps=*/ n, derivedRotationIndexFn);
231367}
232368
233369} // namespace kernel
0 commit comments