@@ -366,6 +366,54 @@ implementBicyclicMatmul(const T& packedA, const T& packedB, int64_t m,
366366 packedA, packedB, /* period=*/ m, /* steps=*/ n, derivedRotationIndexFn);
367367}
368368
369+ // Returns an arithmetic DAG that implements the tricyclic batch matrix
370+ // multiplication algorithm (ciphertext-ciphertext). Uses the tricyclic
371+ // rotation formulas from LKAA25 (Tricycle paper) and applies BSGS to
372+ // reduce rotations on the φ(A) side.
373+ //
374+ // The inputs packedA and packedB are expected to be tricyclic encodings
375+ // φ(A) and φ(B) for tensors A ∈ R^{h×m×n} and B ∈ R^{h×n×p}. The function
376+ // applies equation (22) and the ct-ct BSGS decomposition in Section 6.2.2.
377+ //
378+ // Parameters:
379+ // - packedA: tricyclic-encoded ciphertext for A (φ(A))
380+ // - packedB: tricyclic-encoded ciphertext for B (φ(B))
381+ // - h, m, n, p: tricyclic tensor dimensions (h: batch count / heads)
382+ //
383+ // Returns φ(Z) where Z = batch_matmul(A, B) as a DAG node.
384+ template <typename T>
385+ std::enable_if_t <std::is_base_of<AbstractValue, T>::value,
386+ std::shared_ptr<ArithmeticDagNode<T>>>
387+ implementTricyclicBatchMatmul (const T& packedA, const T& packedB, int64_t h,
388+ int64_t m, int64_t n, int64_t p) {
389+ APInt hAPInt = APInt (64 , h);
390+ APInt mAPInt = APInt (64 , m);
391+ APInt nAPInt = APInt (64 , n);
392+ APInt pAPInt = APInt (64 , p);
393+
394+ APInt pInvModN = multiplicativeInverse (pAPInt.urem (nAPInt), nAPInt);
395+
396+ // This follows Eq. (22) and the ct-ct BSGS decomposition.
397+ auto derivedRotationIndexFn = [&](int64_t giantStepSize,
398+ int64_t giantStepIndex,
399+ int64_t babyStepIndex, int64_t period) {
400+ APInt c (64 , giantStepIndex * giantStepSize + babyStepIndex);
401+
402+ APInt rotyInner = (c * mAPInt * pInvModN.getSExtValue ()).urem (nAPInt);
403+
404+ APInt modulus = (hAPInt * nAPInt * pAPInt);
405+ APInt roty = (rotyInner * hAPInt * pAPInt).urem (modulus);
406+
407+ APInt result = roty - APInt (64 , period) * APInt (64 , giantStepSize) *
408+ APInt (64 , giantStepIndex);
409+ return result.getSExtValue ();
410+ };
411+
412+ int64_t period = h * m;
413+ return implementCiphertextCiphertextBabyStepGiantStep<T>(
414+ packedA, packedB, /* period=*/ period, /* steps=*/ n, derivedRotationIndexFn);
415+ }
416+
369417} // namespace kernel
370418} // namespace heir
371419} // namespace mlir
0 commit comments