@@ -49,41 +49,57 @@ implementMatvec(KernelName kernelName, const T& matrix, const T& vector) {
4949 return accumulatedSum;
5050}
5151
52- // Returns an arithmetic DAG that implements a rotate and reduce op. Ensure
53- // this is only generated for T a subclass of AbstractValue.
52+ // Returns an arithmetic DAG that implements a logarithmic rotate-and-reduce
53+ // accumulation of an input ciphertext.
54+ //
55+ // This is a special case of `tensor_ext.rotate_and_reduce`
5456template <typename T>
5557std::enable_if_t <std::is_base_of<AbstractValue, T>::value,
5658 std::shared_ptr<ArithmeticDagNode<T>>>
57- implementRotateAndReduce (const T& vector, std::optional<T> plaintexts,
58- int64_t period, int64_t steps,
59- const std::string& reduceOp = " arith.addi" ) {
59+ implementRotateAndReduceAccumulation (
60+ const T& vector, int64_t period, int64_t steps,
61+ std::function<std::shared_ptr<ArithmeticDagNode<T>>(
62+ std::shared_ptr<ArithmeticDagNode<T>>,
63+ std::shared_ptr<ArithmeticDagNode<T>>)>
64+ reduceFunc) {
6065 using NodeTy = ArithmeticDagNode<T>;
6166 auto vectorDag = NodeTy::leaf (vector);
6267
63- auto performReduction = [&](std::shared_ptr<NodeTy> left,
64- std::shared_ptr<NodeTy> right) {
65- if (reduceOp == " arith.addi" || reduceOp == " arith.addf" ) {
66- return NodeTy::add (left, right);
67- }
68-
69- if (reduceOp == " arith.muli" || reduceOp == " arith.mulf" ) {
70- return NodeTy::mul (left, right);
71- }
72-
73- // Default to add for unknown operations
74- return NodeTy::add (left, right);
75- };
76-
77- if (!plaintexts.has_value ()) {
78- for (int64_t shiftSize = steps / 2 ; shiftSize > 0 ; shiftSize /= 2 ) {
79- auto rotated = NodeTy::leftRotate (vectorDag, shiftSize * period);
80- auto reduced = performReduction (vectorDag, rotated);
81- vectorDag = reduced;
82- }
83- return vectorDag;
68+ for (int64_t shiftSize = steps / 2 ; shiftSize > 0 ; shiftSize /= 2 ) {
69+ auto rotated = NodeTy::leftRotate (vectorDag, shiftSize * period);
70+ auto reduced = performReduction (vectorDag, rotated);
71+ vectorDag = reduced;
8472 }
73+ return vectorDag;
74+ }
8575
86- auto plaintextsDag = NodeTy::leaf (*plaintexts);
76+ // Returns an arithmetic DAG that implements a baby-step-giant-step
77+ // rotate-and-reduce accumulation between an input ciphertext
78+ // (giantSteppedOperand) and an abstraction over the other argument
79+ // (babySteppedOperand). In particular, the babySteppedOperand may be a list of
80+ // plaintexts like in Halevi-Shoup matvec, or a single ciphertext like in
81+ // bicyclic matmul, and this abstracts over both by taking in an extraction
82+ // callback.
83+ //
84+ // This is a special case of `tensor_ext.rotate_and_reduce`, but with the added
85+ // abstractions it also supports situations not currently expressible by
86+ // `tensor_ext.rotate_and_reduce`.
87+ template <typename T>
88+ std::enable_if_t <std::is_base_of<AbstractValue, T>::value,
89+ std::shared_ptr<ArithmeticDagNode<T>>>
90+ implementBabyStepGiantStep (const T& giantSteppedOperand,
91+ const T& babySteppedOperand, int64_t period,
92+ int64_t steps,
93+ std::function<std::shared_ptr<ArithmeticDagNode<T>>(
94+ std::shared_ptr<ArithmeticDagNode<T>>, int64_t )>
95+ extractFunc,
96+ std::function<std::shared_ptr<ArithmeticDagNode<T>>(
97+ std::shared_ptr<ArithmeticDagNode<T>>,
98+ std::shared_ptr<ArithmeticDagNode<T>>)>
99+ reduceFunc) {
100+ using NodeTy = ArithmeticDagNode<T>;
101+ auto giantSteppedDag = NodeTy::leaf (giantSteppedOperand);
102+ auto babySteppedDag = NodeTy::leaf (babySteppedOperand);
87103
88104 // Use a value of sqrt(n) as the baby step / giant step size.
89105 int64_t numBabySteps = static_cast <int64_t >(std::floor (std::sqrt (steps)));
@@ -112,9 +128,9 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
112128
113129 // Compute sqrt(n) ciphertext rotations of the input as baby-steps.
114130 SmallVector<std::shared_ptr<NodeTy>> babyStepVals;
115- babyStepVals.push_back (vectorDag ); // rot by zero
131+ babyStepVals.push_back (giantSteppedDag ); // rot by zero
116132 for (int64_t i = 1 ; i < numBabySteps; ++i) {
117- babyStepVals.push_back (NodeTy::leftRotate (vectorDag , period * i));
133+ babyStepVals.push_back (NodeTy::leftRotate (giantSteppedDag , period * i));
118134 }
119135
120136 // Compute the inner baby step sums.
@@ -125,7 +141,7 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
125141 int64_t plaintextRotationAmount = -giantStepSize * j * period;
126142 for (int64_t i = 0 ; i < numBabySteps; ++i) {
127143 size_t extractionIndex = i + j * giantStepSize;
128- auto plaintext = NodeTy::extract (plaintextsDag , extractionIndex);
144+ auto plaintext = extractFunc (babySteppedDag , extractionIndex);
129145 auto rotatedPlaintext =
130146 NodeTy::leftRotate (plaintext, plaintextRotationAmount);
131147 auto multiplied = NodeTy::mul (rotatedPlaintext, babyStepVals[i]);
@@ -141,6 +157,105 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
141157 return result;
142158}
143159
160+ // Returns an arithmetic DAG that implements a tensor_ext.rotate_and_reduce op.
161+ //
162+ // See TensorExtOps.td docs for RotateAndReduceOp for more details.
163+ //
164+ // The `vector` argument is a ciphertext value that will be rotated O(sqrt(n))
165+ // times when the `plaintexts` argument is set (Baby Step Giant Step), or
166+ // O(log(n)) times when the `plaintexts` argument is not set (log-style
167+ // rotate-and-reduce accumulation).
168+ //
169+ // The `plaintexts` argument, when present, represents a vector of pre-packed
170+ // plaintexts that will be rotated and multiplied with the rotated `vector`
171+ // argument in BSGS style.
172+ //
173+ // Note that using this kernel results in places in the pipeline where a
174+ // plaintext type is rotated, but most FHE implementations don't have a
175+ // plaintext rotation operation (it would be wasteful) and instead expect the
176+ // "plaintext rotation" to apply to the cleartext. HEIR has places in the
177+ // pipeline that support this by converting a rotate(encode(cleartext)) to
178+ // encode(rotate(cleartext)).
179+ template <typename T>
180+ std::enable_if_t <std::is_base_of<AbstractValue, T>::value,
181+ std::shared_ptr<ArithmeticDagNode<T>>>
182+ implementRotateAndReduce (const T& vector, std::optional<T> plaintexts,
183+ int64_t period, int64_t steps,
184+ const std::string& reduceOp = " arith.addi" ) {
185+ using NodeTy = ArithmeticDagNode<T>;
186+ auto performReduction = [&](std::shared_ptr<NodeTy> left,
187+ std::shared_ptr<NodeTy> right) {
188+ if (reduceOp == " arith.addi" || reduceOp == " arith.addf" ) {
189+ return NodeTy::add (left, right);
190+ }
191+
192+ if (reduceOp == " arith.muli" || reduceOp == " arith.mulf" ) {
193+ return NodeTy::mul (left, right);
194+ }
195+
196+ // Default to add for unknown operations
197+ return NodeTy::add (left, right);
198+ };
199+
200+ if (!plaintexts.has_value ()) {
201+ return implementRotateAndReduceAccumulation<T>(vector, period, steps,
202+ performReduction);
203+ }
204+
205+ auto extractFunc = [](std::shared_ptr<NodeTy> babySteppedDag,
206+ int64_t extractionIndex) {
207+ return NodeTy::extract (babySteppedDag, extractionIndex);
208+ };
209+
210+ return implementBabyStepGiantStep<T>(vector, plaintexts.value (), period,
211+ steps, extractFunc, performReduction);
212+ }
213+
214+ // Returns an arithmetic DAG that implements a baby-step-giant-step between
215+ // ciphertexts.
216+ //
217+ // This implements equation 21 in 6.2.2 of LKAA25: "Tricycle: Private
218+ // Transformer Inference with Tricyclic Encodings"
219+ // https://eprint.iacr.org/2025/1200
220+ //
221+ // This differs from the above implementRotateAndReduce in that, instead of a
222+ // set of pre-computed plaintexts, both arguments are individual ciphertexts.
223+ // Normally with one ciphertext, the naive approach uses n - 1 rotations that
224+ // BSGS reduces to c sqrt(n) + O(1) rotations, if both inputs are ciphertexts
225+ // then it converts 2n - 2 total rotations to n + c sqrt(n) + O(1) rotations.
226+ // Essentially, the "n to sqrt(n)" redution applies to the `vector` argument
227+ // only, while the `plaintexts` argument still gets n-1 rotations.
228+ template <typename T>
229+ std::enable_if_t <std::is_base_of<AbstractValue, T>::value,
230+ std::shared_ptr<ArithmeticDagNode<T>>>
231+ implementCiphertextCiphertextBabyStepGiantStep (
232+ const T& giantSteppedOperand, const T& babySteppedOperand, int64_t period,
233+ int64_t steps, const std::string& reduceOp = " arith.addi" ) {
234+ using NodeTy = ArithmeticDagNode<T>;
235+ auto performReduction = [&](std::shared_ptr<NodeTy> left,
236+ std::shared_ptr<NodeTy> right) {
237+ if (reduceOp == " arith.addi" || reduceOp == " arith.addf" ) {
238+ return NodeTy::add (left, right);
239+ }
240+
241+ if (reduceOp == " arith.muli" || reduceOp == " arith.mulf" ) {
242+ return NodeTy::mul (left, right);
243+ }
244+
245+ // Default to add for unknown operations
246+ return NodeTy::add (left, right);
247+ };
248+
249+ // Avoid replicating and re-extracting by simulating the extraction step by
250+ // just returning the single ciphertext.
251+ auto extractFunc = [](std::shared_ptr<NodeTy> babySteppedDag,
252+ int64_t extractionIndex) { return babySteppedDag; };
253+
254+ return implementBabyStepGiantStep<T>(giantSteppedOperand, babySteppedOperand,
255+ period, steps, extractFunc,
256+ performReduction);
257+ }
258+
144259// Returns an arithmetic DAG that implements the Halevi-Shoup matrix
145260// multiplication algorithm. This implementation uses a rotate-and-reduce
146261// operation, followed by a summation of partial sums if the matrix is not
0 commit comments