Skip to content

Commit 571dd8a

Browse files
#2388 - added tricyclic batch ctct matmul to kernelImplementaion.h
1 parent 0c5ba35 commit 571dd8a

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

lib/Kernel/KernelImplementation.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,54 @@ implementBicyclicMatmul(const T& packedA, const T& packedB, int64_t m,
366366
packedA, packedB, /*period=*/m, /*steps=*/n, derivedRotationIndexFn);
367367
}
368368

369+
// Returns an arithmetic DAG that implements the tricyclic batch matrix
370+
// multiplication algorithm (ciphertext-ciphertext). Uses the tricyclic
371+
// rotation formulas from LKAA25 (Tricycle paper) and applies BSGS to
372+
// reduce rotations on the φ(A) side.
373+
//
374+
// The inputs packedA and packedB are expected to be tricyclic encodings
375+
// φ(A) and φ(B) for tensors A ∈ R^{h×m×n} and B ∈ R^{h×n×p}. The function
376+
// applies equation (22) and the ct-ct BSGS decomposition in Section 6.2.2.
377+
//
378+
// Parameters:
379+
// - packedA: tricyclic-encoded ciphertext for A (φ(A))
380+
// - packedB: tricyclic-encoded ciphertext for B (φ(B))
381+
// - h, m, n, p: tricyclic tensor dimensions (h: batch count / heads)
382+
//
383+
// Returns φ(Z) where Z = batch_matmul(A, B) as a DAG node.
384+
template <typename T>
385+
std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
386+
std::shared_ptr<ArithmeticDagNode<T>>>
387+
implementTricyclicBatchMatmul(const T& packedA, const T& packedB, int64_t h,
388+
int64_t m, int64_t n, int64_t p) {
389+
APInt hAPInt = APInt(64, h);
390+
APInt mAPInt = APInt(64, m);
391+
APInt nAPInt = APInt(64, n);
392+
APInt pAPInt = APInt(64, p);
393+
394+
APInt pInvModN = multiplicativeInverse(pAPInt.urem(nAPInt), nAPInt);
395+
396+
// This follows Eq. (22) and the ct-ct BSGS decomposition.
397+
auto derivedRotationIndexFn = [&](int64_t giantStepSize,
398+
int64_t giantStepIndex,
399+
int64_t babyStepIndex, int64_t period) {
400+
APInt c(64, giantStepIndex * giantStepSize + babyStepIndex);
401+
402+
APInt rotyInner = (c * mAPInt * pInvModN.getSExtValue()).urem(nAPInt);
403+
404+
APInt modulus = (hAPInt * nAPInt * pAPInt);
405+
APInt roty = (rotyInner * hAPInt * pAPInt).urem(modulus);
406+
407+
APInt result = roty - APInt(64, period) * APInt(64, giantStepSize) *
408+
APInt(64, giantStepIndex);
409+
return result.getSExtValue();
410+
};
411+
412+
int64_t period = h * m;
413+
return implementCiphertextCiphertextBabyStepGiantStep<T>(
414+
packedA, packedB, /*period=*/period, /*steps=*/n, derivedRotationIndexFn);
415+
}
416+
369417
} // namespace kernel
370418
} // namespace heir
371419
} // namespace mlir

0 commit comments

Comments
 (0)