Skip to content

Commit cc216d2

Browse files
committed
generalize rotate-and-reduce kernel implementation to support ct-ct BSGS
1 parent 05fdb07 commit cc216d2

File tree

1 file changed

+145
-30
lines changed

1 file changed

+145
-30
lines changed

lib/Kernel/KernelImplementation.h

Lines changed: 145 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,41 +49,57 @@ implementMatvec(KernelName kernelName, const T& matrix, const T& vector) {
4949
return accumulatedSum;
5050
}
5151

52-
// Returns an arithmetic DAG that implements a rotate and reduce op. Ensure
53-
// this is only generated for T a subclass of AbstractValue.
52+
// Returns an arithmetic DAG that implements a logarithmic rotate-and-reduce
53+
// accumulation of an input ciphertext.
54+
//
55+
// This is a special case of `tensor_ext.rotate_and_reduce`
5456
template <typename T>
5557
std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
5658
std::shared_ptr<ArithmeticDagNode<T>>>
57-
implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
58-
int64_t period, int64_t steps,
59-
const std::string& reduceOp = "arith.addi") {
59+
implementRotateAndReduceAccumulation(
60+
const T& vector, int64_t period, int64_t steps,
61+
std::function<std::shared_ptr<ArithmeticDagNode<T>>(
62+
std::shared_ptr<ArithmeticDagNode<T>>,
63+
std::shared_ptr<ArithmeticDagNode<T>>)>
64+
reduceFunc) {
6065
using NodeTy = ArithmeticDagNode<T>;
6166
auto vectorDag = NodeTy::leaf(vector);
6267

63-
auto performReduction = [&](std::shared_ptr<NodeTy> left,
64-
std::shared_ptr<NodeTy> right) {
65-
if (reduceOp == "arith.addi" || reduceOp == "arith.addf") {
66-
return NodeTy::add(left, right);
67-
}
68-
69-
if (reduceOp == "arith.muli" || reduceOp == "arith.mulf") {
70-
return NodeTy::mul(left, right);
71-
}
72-
73-
// Default to add for unknown operations
74-
return NodeTy::add(left, right);
75-
};
76-
77-
if (!plaintexts.has_value()) {
78-
for (int64_t shiftSize = steps / 2; shiftSize > 0; shiftSize /= 2) {
79-
auto rotated = NodeTy::leftRotate(vectorDag, shiftSize * period);
80-
auto reduced = performReduction(vectorDag, rotated);
81-
vectorDag = reduced;
82-
}
83-
return vectorDag;
68+
for (int64_t shiftSize = steps / 2; shiftSize > 0; shiftSize /= 2) {
69+
auto rotated = NodeTy::leftRotate(vectorDag, shiftSize * period);
70+
auto reduced = performReduction(vectorDag, rotated);
71+
vectorDag = reduced;
8472
}
73+
return vectorDag;
74+
}
8575

86-
auto plaintextsDag = NodeTy::leaf(*plaintexts);
76+
// Returns an arithmetic DAG that implements a baby-step-giant-step
77+
// rotate-and-reduce accumulation between an input ciphertext
78+
// (giantSteppedOperand) and an abstraction over the other argument
79+
// (babySteppedOperand). In particular, the babySteppedOperand may be a list of
80+
// plaintexts like in Halevi-Shoup matvec, or a single ciphertext like in
81+
// bicyclic matmul, and this abstracts over both by taking in an extraction
82+
// callback.
83+
//
84+
// This is a special case of `tensor_ext.rotate_and_reduce`, but with the added
85+
// abstractions it also supports situations not currently expressible by
86+
// `tensor_ext.rotate_and_reduce`.
87+
template <typename T>
88+
std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
89+
std::shared_ptr<ArithmeticDagNode<T>>>
90+
implementBabyStepGiantStep(const T& giantSteppedOperand,
91+
const T& babySteppedOperand, int64_t period,
92+
int64_t steps,
93+
std::function<std::shared_ptr<ArithmeticDagNode<T>>(
94+
std::shared_ptr<ArithmeticDagNode<T>>, int64_t)>
95+
extractFunc,
96+
std::function<std::shared_ptr<ArithmeticDagNode<T>>(
97+
std::shared_ptr<ArithmeticDagNode<T>>,
98+
std::shared_ptr<ArithmeticDagNode<T>>)>
99+
reduceFunc) {
100+
using NodeTy = ArithmeticDagNode<T>;
101+
auto giantSteppedDag = NodeTy::leaf(giantSteppedOperand);
102+
auto babySteppedDag = NodeTy::leaf(babySteppedOperand);
87103

88104
// Use a value of sqrt(n) as the baby step / giant step size.
89105
int64_t numBabySteps = static_cast<int64_t>(std::floor(std::sqrt(steps)));
@@ -112,9 +128,9 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
112128

113129
// Compute sqrt(n) ciphertext rotations of the input as baby-steps.
114130
SmallVector<std::shared_ptr<NodeTy>> babyStepVals;
115-
babyStepVals.push_back(vectorDag); // rot by zero
131+
babyStepVals.push_back(giantSteppedDag); // rot by zero
116132
for (int64_t i = 1; i < numBabySteps; ++i) {
117-
babyStepVals.push_back(NodeTy::leftRotate(vectorDag, period * i));
133+
babyStepVals.push_back(NodeTy::leftRotate(giantSteppedDag, period * i));
118134
}
119135

120136
// Compute the inner baby step sums.
@@ -125,7 +141,7 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
125141
int64_t plaintextRotationAmount = -giantStepSize * j * period;
126142
for (int64_t i = 0; i < numBabySteps; ++i) {
127143
size_t extractionIndex = i + j * giantStepSize;
128-
auto plaintext = NodeTy::extract(plaintextsDag, extractionIndex);
144+
auto plaintext = extractFunc(babySteppedDag, extractionIndex);
129145
auto rotatedPlaintext =
130146
NodeTy::leftRotate(plaintext, plaintextRotationAmount);
131147
auto multiplied = NodeTy::mul(rotatedPlaintext, babyStepVals[i]);
@@ -141,6 +157,105 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
141157
return result;
142158
}
143159

160+
// Returns an arithmetic DAG that implements a tensor_ext.rotate_and_reduce op.
161+
//
162+
// See TensorExtOps.td docs for RotateAndReduceOp for more details.
163+
//
164+
// The `vector` argument is a ciphertext value that will be rotated O(sqrt(n))
165+
// times when the `plaintexts` argument is set (Baby Step Giant Step), or
166+
// O(log(n)) times when the `plaintexts` argument is not set (log-style
167+
// rotate-and-reduce accumulation).
168+
//
169+
// The `plaintexts` argument, when present, represents a vector of pre-packed
170+
// plaintexts that will be rotated and multiplied with the rotated `vector`
171+
// argument in BSGS style.
172+
//
173+
// Note that using this kernel results in places in the pipeline where a
174+
// plaintext type is rotated, but most FHE implementations don't have a
175+
// plaintext rotation operation (it would be wasteful) and instead expect the
176+
// "plaintext rotation" to apply to the cleartext. HEIR has places in the
177+
// pipeline that support this by converting a rotate(encode(cleartext)) to
178+
// encode(rotate(cleartext)).
179+
template <typename T>
180+
std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
181+
std::shared_ptr<ArithmeticDagNode<T>>>
182+
implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
183+
int64_t period, int64_t steps,
184+
const std::string& reduceOp = "arith.addi") {
185+
using NodeTy = ArithmeticDagNode<T>;
186+
auto performReduction = [&](std::shared_ptr<NodeTy> left,
187+
std::shared_ptr<NodeTy> right) {
188+
if (reduceOp == "arith.addi" || reduceOp == "arith.addf") {
189+
return NodeTy::add(left, right);
190+
}
191+
192+
if (reduceOp == "arith.muli" || reduceOp == "arith.mulf") {
193+
return NodeTy::mul(left, right);
194+
}
195+
196+
// Default to add for unknown operations
197+
return NodeTy::add(left, right);
198+
};
199+
200+
if (!plaintexts.has_value()) {
201+
return implementRotateAndReduceAccumulation<T>(vector, period, steps,
202+
performReduction);
203+
}
204+
205+
auto extractFunc = [](std::shared_ptr<NodeTy> babySteppedDag,
206+
int64_t extractionIndex) {
207+
return NodeTy::extract(babySteppedDag, extractionIndex);
208+
};
209+
210+
return implementBabyStepGiantStep<T>(vector, plaintexts.value(), period,
211+
steps, extractFunc, performReduction);
212+
}
213+
214+
// Returns an arithmetic DAG that implements a baby-step-giant-step between
215+
// ciphertexts.
216+
//
217+
// This implements equation 21 in 6.2.2 of LKAA25: "Tricycle: Private
218+
// Transformer Inference with Tricyclic Encodings"
219+
// https://eprint.iacr.org/2025/1200
220+
//
221+
// This differs from the above implementRotateAndReduce in that, instead of a
222+
// set of pre-computed plaintexts, both arguments are individual ciphertexts.
223+
// Normally with one ciphertext, the naive approach uses n - 1 rotations that
224+
// BSGS reduces to c sqrt(n) + O(1) rotations, if both inputs are ciphertexts
225+
// then it converts 2n - 2 total rotations to n + c sqrt(n) + O(1) rotations.
226+
// Essentially, the "n to sqrt(n)" redution applies to the `vector` argument
227+
// only, while the `plaintexts` argument still gets n-1 rotations.
228+
template <typename T>
229+
std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
230+
std::shared_ptr<ArithmeticDagNode<T>>>
231+
implementCiphertextCiphertextBabyStepGiantStep(
232+
const T& giantSteppedOperand, const T& babySteppedOperand, int64_t period,
233+
int64_t steps, const std::string& reduceOp = "arith.addi") {
234+
using NodeTy = ArithmeticDagNode<T>;
235+
auto performReduction = [&](std::shared_ptr<NodeTy> left,
236+
std::shared_ptr<NodeTy> right) {
237+
if (reduceOp == "arith.addi" || reduceOp == "arith.addf") {
238+
return NodeTy::add(left, right);
239+
}
240+
241+
if (reduceOp == "arith.muli" || reduceOp == "arith.mulf") {
242+
return NodeTy::mul(left, right);
243+
}
244+
245+
// Default to add for unknown operations
246+
return NodeTy::add(left, right);
247+
};
248+
249+
// Avoid replicating and re-extracting by simulating the extraction step by
250+
// just returning the single ciphertext.
251+
auto extractFunc = [](std::shared_ptr<NodeTy> babySteppedDag,
252+
int64_t extractionIndex) { return babySteppedDag; };
253+
254+
return implementBabyStepGiantStep<T>(giantSteppedOperand, babySteppedOperand,
255+
period, steps, extractFunc,
256+
performReduction);
257+
}
258+
144259
// Returns an arithmetic DAG that implements the Halevi-Shoup matrix
145260
// multiplication algorithm. This implementation uses a rotate-and-reduce
146261
// operation, followed by a summation of partial sums if the matrix is not

0 commit comments

Comments
 (0)