
Commit 021d561

generalize rotate-and-reduce kernel implementation to support ct-ct BSGS
1 parent e603924 commit 021d561

File tree

5 files changed: +242, -82 lines changed


lib/Dialect/TensorExt/IR/TensorExtOps.td

Lines changed: 9 additions & 12 deletions
@@ -150,30 +150,27 @@ def TensorExt_RotateAndReduceOp : TensorExt_Op<"rotate_and_reduce",[Pure, AllTyp
 
     In almost full generality, the reduction performed is
 
-    \[
-      \sum_{i \in [0, n]} p(P, T*i) \cdot rotate(v, T*i)
-    \]
+    $$ \sum_{i \in [0, n]} p(P, iT) \cdot rotate(v, iT) $$
 
-    where $f$ is a function, $p(P, T*i)$ is a function of a plaintext $P$ and
-    $rotate(v, T*i)$ is a rotation of the ciphertext $v$ with period $T$. The
+    where $f$ is a function, $p(P, iT)$ is a function of a plaintext $P$ and
+    $rotate(v, iT)$ is a rotation of the ciphertext $v$ with period $T$. The
     operation takes as input the ciphertext vector $v$, the period $T$, the
-    number of reductions $n$, and a tensor of plaintext values `[p(P, 0), p(P,
-    T), ..., p(P, T*(n-1))]`.
+    number of reductions $n$, and a tensor of plaintext values
+
+    `[p(P, 0), p(P, T), ..., p(P, (n-1)T)]`
 
     This can be used to implement a matrix vector product that uses a
     Halevi-Shoup diagonalization of the plaintext matrix. In this case, the
     reduction is
 
-    \[
-      \sum_{i \in [0, n]} P(i) \cdot rotate(v, i)
-    \]
+    $$ \sum_{i \in [0, n]} P(i) \cdot rotate(v, i) $$
 
     where $P(i)$ is the $i$th diagonal of the plaintext matrix and the period
     $T$ is $1$.
 
     An accumulation of the ciphertext slots is also handled via this operation
-    by omitting the plaintext $p(P, T*i)$ argument and using a period of 1 with
-    `n = |v|` so that the reduction is simply a sum of all rotation of the
+    by omitting the plaintext $p(P, Ti)$ argument and using a period of 1 with
+    `n = |v|` so that the reduction is simply a sum of all rotations of the
     ciphertext.
 
     If `reduceOp` is set to an MLIR operation name (e.g., `arith.mulf`), then
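
For intuition, the Halevi-Shoup case of this reduction (period $T = 1$, with $P(i)$ the $i$-th generalized diagonal of the matrix, i.e. $P(i)[j] = M[j][(j+i) \bmod n]$) can be evaluated directly on cleartext data. The sketch below is illustrative only and is not part of this commit or of the HEIR API; all names in it are hypothetical.

```cpp
// Minimal cleartext sketch of sum_i P(i) * rotate(v, i) for a square matrix.
#include <cstdint>
#include <vector>

using Vec = std::vector<int64_t>;

// Left rotation: out[j] = v[(j + amount) mod n].
Vec leftRotate(const Vec& v, int64_t amount) {
  int64_t n = v.size();
  Vec out(n);
  for (int64_t j = 0; j < n; ++j) out[j] = v[(j + amount) % n];
  return out;
}

// Computes sum_i P(i) * rotate(v, i), which equals the matrix-vector product.
Vec haleviShoupMatvec(const std::vector<Vec>& matrix, const Vec& v) {
  int64_t n = v.size();
  Vec result(n, 0);
  for (int64_t i = 0; i < n; ++i) {
    Vec rotated = leftRotate(v, i);
    for (int64_t j = 0; j < n; ++j) {
      int64_t diag = matrix[j][(j + i) % n];  // P(i)[j]
      result[j] += diag * rotated[j];
    }
  }
  return result;
}
```

Slot $j$ of the result accumulates $\sum_i M[j][(j+i) \bmod n] \cdot v[(j+i) \bmod n] = (Mv)[j]$, which is why a single rotate-and-reduce suffices for a square matrix.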

lib/Kernel/BUILD

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ cc_test(
         ":ArithmeticDag",
         ":Kernel",
         ":KernelImplementation",
+        ":RotationCountVisitor",
         ":TestingUtils",
         "@googletest//:gtest_main",
         "@heir//lib/Utils/Layout:Evaluate",

lib/Kernel/KernelImplementation.h

Lines changed: 205 additions & 69 deletions
@@ -9,7 +9,6 @@
 #include <optional>
 #include <string>
 #include <type_traits>
-#include <variant>
 #include <vector>
 
 #include "lib/Kernel/AbstractValue.h"
@@ -23,6 +22,17 @@ namespace mlir {
 namespace heir {
 namespace kernel {
 
+// A function that generalizes the reduction operation in all kernels in this
+// file. E.g., whether to use `add` or `mul`
+template <typename T>
+using DagReducer = std::function<std::shared_ptr<ArithmeticDagNode<T>>(
+    std::shared_ptr<ArithmeticDagNode<T>>,
+    std::shared_ptr<ArithmeticDagNode<T>>)>;
+
+template <typename T>
+using DagExtractor = std::function<std::shared_ptr<ArithmeticDagNode<T>>(
+    std::shared_ptr<ArithmeticDagNode<T>>, int64_t)>;
+
 // Returns an arithmetic DAG that implements a matvec kernel. Ensure this is
 // only generated for T a subclass of AbstractValue.
 template <typename T>
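
The two callback types added above are what the rest of this header plugs different strategies into: `DagReducer` picks the combining operation and `DagExtractor` picks how the baby-stepped operand is indexed. A minimal sketch of instantiating them, assuming the `ArithmeticDagNode` API used elsewhere in this file and some concrete `AbstractValue` subclass (here called `MyValue`, a hypothetical name):

```cpp
// Illustrative only: MyValue stands in for a concrete AbstractValue subclass.
using NodeTy = ArithmeticDagNode<MyValue>;

// Reduce by addition, as implementRotateAndReduce does for "arith.addi".
DagReducer<MyValue> addReducer = [](std::shared_ptr<NodeTy> lhs,
                                    std::shared_ptr<NodeTy> rhs) {
  return NodeTy::add(lhs, rhs);
};

// Extract the i-th pre-packed plaintext, as the Halevi-Shoup path does.
DagExtractor<MyValue> tensorExtractor = [](std::shared_ptr<NodeTy> packed,
                                           int64_t index) {
  return NodeTy::extract(packed, index);
};
```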
@@ -49,41 +59,69 @@ implementMatvec(KernelName kernelName, const T& matrix, const T& vector) {
   return accumulatedSum;
 }
 
-// Returns an arithmetic DAG that implements a rotate and reduce op. Ensure
-// this is only generated for T a subclass of AbstractValue.
+// Returns an arithmetic DAG that implements a logarithmic rotate-and-reduce
+// accumulation of an input ciphertext.
+//
+// This is a special case of `tensor_ext.rotate_and_reduce`
 template <typename T>
 std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
                  std::shared_ptr<ArithmeticDagNode<T>>>
-implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
-                         int64_t period, int64_t steps,
-                         const std::string& reduceOp = "arith.addi") {
+implementRotateAndReduceAccumulation(const T& vector, int64_t period,
+                                     int64_t steps, DagReducer<T> reduceFunc) {
   using NodeTy = ArithmeticDagNode<T>;
   auto vectorDag = NodeTy::leaf(vector);
 
-  auto performReduction = [&](std::shared_ptr<NodeTy> left,
-                              std::shared_ptr<NodeTy> right) {
-    if (reduceOp == "arith.addi" || reduceOp == "arith.addf") {
-      return NodeTy::add(left, right);
-    }
-
-    if (reduceOp == "arith.muli" || reduceOp == "arith.mulf") {
-      return NodeTy::mul(left, right);
-    }
-
-    // Default to add for unknown operations
-    return NodeTy::add(left, right);
-  };
-
-  if (!plaintexts.has_value()) {
-    for (int64_t shiftSize = steps / 2; shiftSize > 0; shiftSize /= 2) {
-      auto rotated = NodeTy::leftRotate(vectorDag, shiftSize * period);
-      auto reduced = performReduction(vectorDag, rotated);
-      vectorDag = reduced;
-    }
-    return vectorDag;
+  for (int64_t shiftSize = steps / 2; shiftSize > 0; shiftSize /= 2) {
+    auto rotated = NodeTy::leftRotate(vectorDag, shiftSize * period);
+    auto reduced = reduceFunc(vectorDag, rotated);
+    vectorDag = reduced;
   }
+  return vectorDag;
+}
 
-  auto plaintextsDag = NodeTy::leaf(*plaintexts);
+// A function that generalizes the choice of rotation for the "baby stepped
+// operand" of a baby-step giant-step algorithm. This is required because
+// the rotation used in Halevi-Shoup matvec differs from that of bicyclic
+// matmul.
+using DerivedRotationIndexFn = std::function<int64_t(
+    // giant step size
+    int64_t,
+    // current giant step index
+    int64_t,
+    // current baby step index
+    int64_t,
+    // period
+    int64_t)>;
+
+inline int64_t defaultDerivedRotationIndexFn(int64_t giantStepSize,
+                                             int64_t giantStepIndex,
+                                             int64_t babyStepIndex,
+                                             int64_t period) {
+  return -giantStepSize * giantStepIndex * period;
+}
+
+// Returns an arithmetic DAG that implements a baby-step-giant-step
+// rotate-and-reduce accumulation between an input ciphertext
+// (giantSteppedOperand) and an abstraction over the other argument
+// (babySteppedOperand). In particular, the babySteppedOperand may be a list of
+// plaintexts like in Halevi-Shoup matvec, or a single ciphertext like in
+// bicyclic matmul, and this abstracts over both by taking in an extraction
+// callback.
+//
+// This is a special case of `tensor_ext.rotate_and_reduce`, but with the added
+// abstractions it also supports situations not currently expressible by
+// `tensor_ext.rotate_and_reduce`.
+template <typename T>
+std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
+                 std::shared_ptr<ArithmeticDagNode<T>>>
+implementBabyStepGiantStep(
+    const T& giantSteppedOperand, const T& babySteppedOperand, int64_t period,
+    int64_t steps, DagExtractor<T> extractFunc,
+    const DerivedRotationIndexFn& derivedRotationIndexFn =
+        defaultDerivedRotationIndexFn) {
+  using NodeTy = ArithmeticDagNode<T>;
+  auto giantSteppedDag = NodeTy::leaf(giantSteppedOperand);
+  auto babySteppedDag = NodeTy::leaf(babySteppedOperand);
 
   // Use a value of sqrt(n) as the baby step / giant step size.
   int64_t numBabySteps = static_cast<int64_t>(std::floor(std::sqrt(steps)));
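
As a concrete check of the logarithmic accumulation above: with `period = 1` and `steps = 8`, the loop rotates by 4, 2, and 1 and reduces after each rotation, so every slot ends up holding the sum of all eight original slots after only three rotations, versus seven for the naive sum. A cleartext sketch of the same loop, illustrative only and assuming `steps` is a power of two:

```cpp
// Cleartext mirror of implementRotateAndReduceAccumulation with an additive
// reducer and period = 1; not HEIR code.
#include <cstdint>
#include <vector>

std::vector<int64_t> rotateAndReduceSum(std::vector<int64_t> v, int64_t steps) {
  auto rotate = [](const std::vector<int64_t>& in, int64_t amount) {
    std::vector<int64_t> out(in.size());
    for (size_t j = 0; j < in.size(); ++j)
      out[j] = in[(j + amount) % in.size()];
    return out;
  };
  for (int64_t shiftSize = steps / 2; shiftSize > 0; shiftSize /= 2) {
    auto rotated = rotate(v, shiftSize);  // period = 1
    for (size_t j = 0; j < v.size(); ++j) v[j] += rotated[j];  // reduceFunc = add
  }
  return v;
}
```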
@@ -112,35 +150,124 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
 
   // Compute sqrt(n) ciphertext rotations of the input as baby-steps.
   SmallVector<std::shared_ptr<NodeTy>> babyStepVals;
-  babyStepVals.push_back(vectorDag);  // rot by zero
+  babyStepVals.push_back(giantSteppedDag);  // rot by zero
   for (int64_t i = 1; i < numBabySteps; ++i) {
-    babyStepVals.push_back(NodeTy::leftRotate(vectorDag, period * i));
+    babyStepVals.push_back(NodeTy::leftRotate(giantSteppedDag, period * i));
   }
 
   // Compute the inner baby step sums.
   std::shared_ptr<NodeTy> result = nullptr;
   for (int64_t j = 0; j < numGiantSteps; ++j) {
     std::shared_ptr<NodeTy> innerSum = nullptr;
-    // The rotation used for the plaintext
-    int64_t plaintextRotationAmount = -giantStepSize * j * period;
     for (int64_t i = 0; i < numBabySteps; ++i) {
+      int64_t innerRotAmount =
+          derivedRotationIndexFn(giantStepSize, j, i, period);
       size_t extractionIndex = i + j * giantStepSize;
-      auto plaintext = NodeTy::extract(plaintextsDag, extractionIndex);
-      auto rotatedPlaintext =
-          NodeTy::leftRotate(plaintext, plaintextRotationAmount);
+      auto plaintext = extractFunc(babySteppedDag, extractionIndex);
+      auto rotatedPlaintext = NodeTy::leftRotate(plaintext, innerRotAmount);
       auto multiplied = NodeTy::mul(rotatedPlaintext, babyStepVals[i]);
-      innerSum = innerSum == nullptr ? multiplied
-                                     : performReduction(innerSum, multiplied);
+      innerSum =
+          innerSum == nullptr ? multiplied : NodeTy::add(innerSum, multiplied);
     }
 
     auto rotatedSum = NodeTy::leftRotate(innerSum, period * j * giantStepSize);
-    result =
-        result == nullptr ? rotatedSum : performReduction(result, rotatedSum);
+    result = result == nullptr ? rotatedSum : NodeTy::add(result, rotatedSum);
   }
 
   return result;
 }
 
+// Returns an arithmetic DAG that implements a tensor_ext.rotate_and_reduce op.
+//
+// See TensorExtOps.td docs for RotateAndReduceOp for more details.
+//
+// The `vector` argument is a ciphertext value that will be rotated O(sqrt(n))
+// times when the `plaintexts` argument is set (Baby Step Giant Step), or
+// O(log(n)) times when the `plaintexts` argument is not set (log-style
+// rotate-and-reduce accumulation).
+//
+// The `plaintexts` argument, when present, represents a vector of pre-packed
+// plaintexts that will be rotated and multiplied with the rotated `vector`
+// argument in BSGS style.
+//
+// Note that using this kernel results in places in the pipeline where a
+// plaintext type is rotated, but most FHE implementations don't have a
+// plaintext rotation operation (it would be wasteful) and instead expect the
+// "plaintext rotation" to apply to the cleartext. HEIR has places in the
+// pipeline that support this by converting a rotate(encode(cleartext)) to
+// encode(rotate(cleartext)).
+template <typename T>
+std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
+                 std::shared_ptr<ArithmeticDagNode<T>>>
+implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
+                         int64_t period, int64_t steps,
+                         const std::string& reduceOp = "arith.addi") {
+  using NodeTy = ArithmeticDagNode<T>;
+  auto performReduction = [&](std::shared_ptr<NodeTy> left,
+                              std::shared_ptr<NodeTy> right) {
+    if (reduceOp == "arith.addi" || reduceOp == "arith.addf") {
+      return NodeTy::add(left, right);
+    }
+
+    if (reduceOp == "arith.muli" || reduceOp == "arith.mulf") {
+      return NodeTy::mul(left, right);
+    }
+
+    // Default to add for unknown operations
+    return NodeTy::add(left, right);
+  };
+
+  if (!plaintexts.has_value()) {
+    return implementRotateAndReduceAccumulation<T>(vector, period, steps,
+                                                   performReduction);
+  }
+
+  assert((reduceOp == "arith.addi" || reduceOp == "arith.addf") &&
+         "Baby-step-giant-step rotate-and-reduce only supports addition "
+         "as the reduction operation");
+
+  auto extractFunc = [](std::shared_ptr<NodeTy> babySteppedDag,
+                        int64_t extractionIndex) {
+    return NodeTy::extract(babySteppedDag, extractionIndex);
+  };
+
+  return implementBabyStepGiantStep<T>(vector, plaintexts.value(), period,
+                                       steps, extractFunc);
+}
+
+// Returns an arithmetic DAG that implements a baby-step-giant-step between
+// ciphertexts.
+//
+// This implements equation 21 in 6.2.2 of LKAA25: "Tricycle: Private
+// Transformer Inference with Tricyclic Encodings"
+// https://eprint.iacr.org/2025/1200
+//
+// This differs from the above implementRotateAndReduce in that, instead of a
+// set of pre-computed plaintexts, both arguments are individual ciphertexts.
+// Normally with one ciphertext, the naive approach uses n - 1 rotations that
+// BSGS reduces to c sqrt(n) + O(1) rotations; if both inputs are ciphertexts
+// then it converts 2n - 2 total rotations to n + c sqrt(n) + O(1) rotations.
+// Essentially, the "n to sqrt(n)" reduction applies to the `vector` argument
+// only, while the `plaintexts` argument still gets n-1 rotations.
+template <typename T>
+std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
+                 std::shared_ptr<ArithmeticDagNode<T>>>
+implementCiphertextCiphertextBabyStepGiantStep(
+    const T& giantSteppedOperand, const T& babySteppedOperand, int64_t period,
+    int64_t steps, DerivedRotationIndexFn derivedRotationIndexFn) {
+  using NodeTy = ArithmeticDagNode<T>;
+
+  // Avoid replicating and re-extracting by simulating the extraction step by
+  // just returning the single ciphertext.
+  auto extractFunc = [](std::shared_ptr<NodeTy> babySteppedDag,
+                        int64_t extractionIndex) { return babySteppedDag; };
+
+  return implementBabyStepGiantStep<T>(giantSteppedOperand, babySteppedOperand,
+                                       period, steps, extractFunc,
+                                       derivedRotationIndexFn);
+}
+
 // Returns an arithmetic DAG that implements the Halevi-Shoup matrix
 // multiplication algorithm. This implementation uses a rotate-and-reduce
 // operation, followed by a summation of partial sums if the matrix is not
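
A rough count of where those rotation figures come from, using illustrative numbers: take steps = n = 64, so numBabySteps = 8 and numGiantSteps = 8. The giant-stepped ciphertext is rotated 7 times for the baby steps and roughly 8 more times for the outer giant-step sums (the j = 0 rotation by zero can be elided), about 2 sqrt(n) ciphertext rotations in total. In the plaintext case, each inner leftRotate of an extracted plaintext can be folded into the encoding as described in the comment above, so the kernel stays near 2 sqrt(n) real rotations. In the ciphertext-ciphertext case, each of the roughly n inner rotations of the baby-stepped ciphertext is a real rotation, giving about n + 2 sqrt(n) rotations instead of the 2n - 2 of the naive double sum.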
@@ -187,47 +314,56 @@ implementHaleviShoup(const T& vector, const T& matrix,
 // zero-padded so that their dimensions are coprime, they are cyclically
 // repeated to fill all the slots of the ciphertext, and they are packed
 // according to the bicyclic ordering.
+//
+// This function produces a kernel using roughly n + 2sqrt(n) - 3 rotations
+// (for matrix dimensions all order n), by applying the baby-step-giant-step
+// method to reduce the number of rotations of packedA.
+//
+// This implements the BMM-I algorithm from https://eprint.iacr.org/2024/1762
+// with modifications from LKAA25 (https://eprint.iacr.org/2025/1200):
+//
+//   - A simplification of the rotation formula in Sec 5.2.1 (equation 9).
+//   - A baby-step-giant-step optimization of the summation below, from Sec
+//     6.2.2 (equation 21).
+//
+// It computes
+//
+//   C = sum_{c=0}^{n-1} rot(A, r1(c)) * rot(B, r2(c))
+//
+// where
+//
+//   r1(c) = cm(m^{-1} mod n) mod mn
+//   r2(c) = cp(p^{-1} mod n) mod np
+//
 template <typename T>
 std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
                  std::shared_ptr<ArithmeticDagNode<T>>>
 implementBicyclicMatmul(const T& packedA, const T& packedB, int64_t m,
                         int64_t n, int64_t p) {
-  using NodeTy = ArithmeticDagNode<T>;
-  auto packedADag = NodeTy::leaf(packedA);
-  auto packedBDag = NodeTy::leaf(packedB);
-
-  // This implements the BMM-I algorithm from https://eprint.iacr.org/2024/1762
-  // with a simplification of the rotation formula in Sec 5.2.1 (equation 9) of
-  // https://eprint.iacr.org/2025/1200
-  //
-  //   C = sum_{c=0}^{n-1} rot(A, r1(c)) * rot(B, r2(c))
-  //
-  // where
-  //
-  //   r1(c) = cm(m^{-1} mod n) mod mn
-  //   r2(c) = cp(p^{-1} mod n) mod np
-  //
   APInt mAPInt = APInt(64, m);
   APInt nAPInt = APInt(64, n);
   APInt pAPInt = APInt(64, p);
 
-  APInt mInvModN = multiplicativeInverse(mAPInt, nAPInt);
-  APInt pInvModN = multiplicativeInverse(pAPInt, nAPInt);
+  APInt mInvModN = multiplicativeInverse(mAPInt.urem(nAPInt), nAPInt);
+  APInt pInvModN = multiplicativeInverse(pAPInt.urem(nAPInt), nAPInt);
 
-  // The part of r1(c), r2(c) that is independent of the loop iterations
-  int64_t r1Const = m * mInvModN.getSExtValue();
-  int64_t r2Const = p * pInvModN.getSExtValue();
+  auto derivedRotationIndexFn = [&](int64_t giantStepSize,
+                                    int64_t giantStepIndex,
+                                    int64_t babyStepIndex, int64_t period) {
+    APInt c(64, giantStepIndex * giantStepSize + babyStepIndex);
+    APInt mAPInt(64, m);
 
-  std::shared_ptr<NodeTy> result = nullptr;
-  for (int i = 0; i < n; ++i) {
-    int64_t shiftA = (i * r1Const) % (m * n);
-    int64_t shiftB = (i * r2Const) % (n * p);
-    auto rotA = NodeTy::leftRotate(packedADag, shiftA);
-    auto rotB = NodeTy::leftRotate(packedBDag, shiftB);
-    auto term = NodeTy::mul(rotA, rotB);
-    result = result == nullptr ? term : NodeTy::add(result, term);
-  }
-  return result;
+    // RotY(c) = (p * (c * m * p^{-1} mod n)) mod (n * p)
+    APInt rotyInner = (c * mAPInt * pInvModN.getSExtValue()).urem(nAPInt);
+    APInt roty = (rotyInner * pAPInt).urem(nAPInt * pAPInt);
+
+    APInt result = roty - APInt(64, period) * APInt(64, giantStepSize) *
+                              APInt(64, giantStepIndex);
+    return result.getSExtValue();
+  };
+
+  return implementCiphertextCiphertextBabyStepGiantStep<T>(
+      packedA, packedB, /*period=*/m, /*steps=*/n, derivedRotationIndexFn);
 }
 
 }  // namespace kernel
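
As a worked example of the documented rotation formulas, with dimensions chosen purely for illustration (m and p coprime to n): take m = 3, n = 4, p = 5. Then m^{-1} mod n = 3 and p^{-1} mod n = 1, so r1(c) = 9c mod 12 and r2(c) = 5c mod 20, giving rotations (0, 9, 6, 3) of A and (0, 5, 10, 15) of B for c = 0..3. The standalone sketch below (not HEIR code) just evaluates those formulas:

```cpp
// Evaluate r1(c) = cm(m^{-1} mod n) mod mn and r2(c) = cp(p^{-1} mod n) mod np
// for small illustrative dimensions.
#include <cstdint>
#include <iostream>

// Brute-force modular inverse; fine for tiny moduli in this example.
int64_t modInverse(int64_t a, int64_t mod) {
  for (int64_t x = 1; x < mod; ++x)
    if ((a * x) % mod == 1) return x;
  return -1;  // no inverse; a and mod are not coprime
}

int main() {
  int64_t m = 3, n = 4, p = 5;
  int64_t r1Const = m * modInverse(m % n, n);  // 3 * 3 = 9
  int64_t r2Const = p * modInverse(p % n, n);  // 5 * 1 = 5
  for (int64_t c = 0; c < n; ++c) {
    std::cout << "r1(" << c << ") = " << (c * r1Const) % (m * n) << ", "
              << "r2(" << c << ") = " << (c * r2Const) % (n * p) << "\n";
  }
  return 0;
}
```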
