Skip to content

Commit e603924

Browse files
Merge pull request #2359 from j2kun:bicyclic-matmul
PiperOrigin-RevId: 826599179
2 parents 6f60cd8 + 4ea5c46 commit e603924

File tree

7 files changed

+211
-0
lines changed

7 files changed

+211
-0
lines changed

lib/Kernel/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ cc_library(
7373
":AbstractValue",
7474
":ArithmeticDag",
7575
":Kernel",
76+
"@heir//lib/Utils:APIntUtils",
7677
"@heir//lib/Utils:MathUtils",
7778
"@llvm-project//mlir:Support",
7879
],

lib/Kernel/KernelImplementation.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "lib/Kernel/AbstractValue.h"
1616
#include "lib/Kernel/ArithmeticDag.h"
1717
#include "lib/Kernel/KernelName.h"
18+
#include "lib/Utils/APIntUtils.h"
1819
#include "lib/Utils/MathUtils.h"
1920
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
2021

@@ -178,6 +179,57 @@ implementHaleviShoup(const T& vector, const T& matrix,
178179
return summedShifts;
179180
}
180181

182+
// Returns an arithmetic DAG that implements the bicyclic matrix multiplication
183+
// algorithm.
184+
//
185+
// The input matrices packedA and packedB are assumed to be properly packed to
186+
// meet the conditions for bicyclic multiplication. That is: both matrices are
187+
// zero-padded so that their dimensions are coprime, they are cyclically
188+
// repeated to fill all the slots of the ciphertext, and they are packed
189+
// according to the bicyclic ordering.
190+
template <typename T>
191+
std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
192+
std::shared_ptr<ArithmeticDagNode<T>>>
193+
implementBicyclicMatmul(const T& packedA, const T& packedB, int64_t m,
194+
int64_t n, int64_t p) {
195+
using NodeTy = ArithmeticDagNode<T>;
196+
auto packedADag = NodeTy::leaf(packedA);
197+
auto packedBDag = NodeTy::leaf(packedB);
198+
199+
// This implements the BMM-I algorithm from https://eprint.iacr.org/2024/1762
200+
// with a simplification of the rotation formula in Sec 5.2.1 (equation 9) of
201+
// https://eprint.iacr.org/2025/1200
202+
//
203+
// C = sum_{c=0}^{n-1} rot(A, r1(c)) * rot(B, r2(c))
204+
//
205+
// where
206+
//
207+
// r1(c) = cm(m^{-1} mod n) mod mn
208+
// r2(c) = cp(p^{-1} mod n) mod np
209+
//
210+
APInt mAPInt = APInt(64, m);
211+
APInt nAPInt = APInt(64, n);
212+
APInt pAPInt = APInt(64, p);
213+
214+
APInt mInvModN = multiplicativeInverse(mAPInt, nAPInt);
215+
APInt pInvModN = multiplicativeInverse(pAPInt, nAPInt);
216+
217+
// The part of r1(c), r2(c) that is independent of the loop iterations
218+
int64_t r1Const = m * mInvModN.getSExtValue();
219+
int64_t r2Const = p * pInvModN.getSExtValue();
220+
221+
std::shared_ptr<NodeTy> result = nullptr;
222+
for (int i = 0; i < n; ++i) {
223+
int64_t shiftA = (i * r1Const) % (m * n);
224+
int64_t shiftB = (i * r2Const) % (n * p);
225+
auto rotA = NodeTy::leftRotate(packedADag, shiftA);
226+
auto rotB = NodeTy::leftRotate(packedBDag, shiftB);
227+
auto term = NodeTy::mul(rotA, rotB);
228+
result = result == nullptr ? term : NodeTy::add(result, term);
229+
}
230+
return result;
231+
}
232+
181233
} // namespace kernel
182234
} // namespace heir
183235
} // namespace mlir

lib/Kernel/KernelImplementationTest.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,41 @@ TEST(KernelImplementationTest, Test2DConvWithLayout) {
138138
EXPECT_EQ(extractedResult, expected);
139139
}
140140

141+
TEST(KernelImplementationTest, BicyclicMatmul) {
142+
MLIRContext context;
143+
std::vector<std::vector<int>> matrixA = {
144+
{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}};
145+
std::vector<std::vector<int>> matrixB = {
146+
{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}};
147+
int m = 3;
148+
int n = 5;
149+
int p = 2;
150+
int numSlots = m * n * p;
151+
152+
auto layoutA = getBicyclicLayoutRelation(
153+
RankedTensorType::get({m, n}, IndexType::get(&context)), numSlots);
154+
auto packedA = evaluateLayoutOnMatrix(layoutA, matrixA);
155+
156+
auto layoutB = getBicyclicLayoutRelation(
157+
RankedTensorType::get({n, p}, IndexType::get(&context)), numSlots);
158+
auto packedB = evaluateLayoutOnMatrix(layoutB, matrixB);
159+
160+
LiteralValue packedAValue = packedA[0];
161+
LiteralValue packedBValue = packedB[0];
162+
163+
auto dag = implementBicyclicMatmul(packedAValue, packedBValue, m, n, p);
164+
LiteralValue result = evalKernel(dag);
165+
auto resultVec = std::get<std::vector<int>>(result.getTensor());
166+
167+
auto resultLayout = getBicyclicLayoutRelation(
168+
RankedTensorType::get({m, p}, IndexType::get(&context)), numSlots);
169+
auto unpackedResult =
170+
unpackLayoutToMatrix<int>(resultLayout, {resultVec}, {m, p});
171+
172+
std::vector<std::vector<int>> expected = {{95, 110}, {220, 260}, {345, 410}};
173+
EXPECT_EQ(unpackedResult, expected);
174+
}
175+
141176
} // namespace
142177
} // namespace kernel
143178
} // namespace heir

lib/Utils/Layout/Evaluate.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,27 @@ std::vector<std::vector<T>> evaluateLayoutOnMatrix(
8787
return evaluateLayout(relation, getValueFn);
8888
}
8989

90+
template <typename T>
91+
std::vector<std::vector<T>> unpackLayoutToMatrix(
92+
const presburger::IntegerRelation& relation,
93+
const std::vector<std::vector<T>>& packed,
94+
ArrayRef<int64_t> originalShape) {
95+
std::vector<std::vector<T>> result(originalShape[0],
96+
std::vector<T>(originalShape[1], 0));
97+
98+
// Get all points in the relation.
99+
PointPairCollector collector(relation.getNumDomainVars(), /*rangeDims=*/2);
100+
enumeratePoints(relation, collector);
101+
102+
for (const auto& pointPair : collector.points) {
103+
std::vector<int64_t> domainPoint = pointPair.first;
104+
std::vector<int64_t> rangePoint = pointPair.second;
105+
result[domainPoint[0]][domainPoint[1]] =
106+
packed[rangePoint[0]][rangePoint[1]];
107+
}
108+
return result;
109+
}
110+
90111
} // namespace heir
91112
} // namespace mlir
92113

lib/Utils/Layout/Utils.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,64 @@ presburger::IntegerRelation getDiagonalLayoutRelation(
208208
return result;
209209
}
210210

211+
presburger::IntegerRelation getBicyclicLayoutRelation(
212+
RankedTensorType matrixType, int64_t numSlots) {
213+
unsigned int rows = matrixType.getDimSize(0);
214+
unsigned int cols = matrixType.getDimSize(1);
215+
216+
assert(std::gcd(rows, cols) == 1 &&
217+
"bicyclic layout requires coprime dimensions");
218+
219+
IntegerRelation result(PresburgerSpace::getRelationSpace(
220+
matrixType.getRank(), /*numRange=*/2, /*numSymbol=*/0,
221+
/*numLocals=*/0));
222+
223+
// Add bounds for the data matrix dimensions.
224+
int domainOffset = result.getVarKindOffset(VarKind::Domain);
225+
int rangeOffset = result.getVarKindOffset(VarKind::Range);
226+
int rowVarIndex = domainOffset;
227+
int colVarIndex = domainOffset + 1;
228+
int ctVarIndex = rangeOffset;
229+
int slotVarIndex = rangeOffset + 1;
230+
231+
addBounds(result, rowVarIndex, 0, rows - 1);
232+
addBounds(result, colVarIndex, 0, cols - 1);
233+
addBounds(result, ctVarIndex, 0,
234+
std::ceil((float)matrixType.getNumElements() / numSlots) - 1);
235+
addBounds(result, slotVarIndex, 0, numSlots - 1);
236+
237+
// Let k = ct * numSlots + slot.
238+
// We need to add constraints for:
239+
// row = k % rows
240+
// col = k % cols
241+
242+
// k_mod_rows = (ct * numSlots + slot) % rows
243+
SmallVector<int64_t> kCoeffs(result.getNumCols(), 0);
244+
kCoeffs[ctVarIndex] = numSlots;
245+
kCoeffs[slotVarIndex] = 1;
246+
auto kModRows = addModConstraint(result, kCoeffs, rows);
247+
248+
// row = k_mod_rows
249+
SmallVector<int64_t> rowEquality(result.getNumCols(), 0);
250+
rowEquality[rowVarIndex] = 1;
251+
rowEquality[kModRows] = -1;
252+
result.addEquality(rowEquality);
253+
254+
// k_mod_cols = (ct * numSlots + slot) % cols
255+
kCoeffs.resize(result.getNumCols(), 0);
256+
kCoeffs[ctVarIndex] = numSlots;
257+
kCoeffs[slotVarIndex] = 1;
258+
auto kModCols = addModConstraint(result, kCoeffs, cols);
259+
260+
// col = k_mod_cols
261+
SmallVector<int64_t> colEquality(result.getNumCols(), 0);
262+
colEquality[colVarIndex] = 1;
263+
colEquality[kModCols] = -1;
264+
result.addEquality(colEquality);
265+
266+
return result;
267+
}
268+
211269
presburger::IntegerRelation getPerRowLayoutRelation(RankedTensorType matrixType,
212270
int64_t ciphertextSize) {
213271
auto domainSize = matrixType.getRank();

lib/Utils/Layout/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ presburger::IntegerRelation getRowMajorLayoutRelation(
4848
presburger::IntegerRelation getDiagonalLayoutRelation(
4949
RankedTensorType matrixType, int64_t ciphertextSize);
5050

51+
// Returns an IntegerRelation that represents a bicyclic layout for a matrix.
52+
// See https://eprint.iacr.org/2024/1762 for details.
53+
presburger::IntegerRelation getBicyclicLayoutRelation(
54+
RankedTensorType matrixType, int64_t numSlots);
55+
5156
// Returns an IntegerRelation that represents a per-row layout for a matrix
5257
// such that each row of the matrix is in a separate ciphertext.
5358
presburger::IntegerRelation getPerRowLayoutRelation(RankedTensorType matrixType,

lib/Utils/Layout/UtilsTest.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,45 @@ TEST(UtilsTest, SquatDiagonalLayout) {
160160
}
161161
}
162162

163+
TEST(UtilsTest, BicyclicLayout3x5) {
164+
MLIRContext context;
165+
int64_t numSlots = 15;
166+
RankedTensorType matrixType =
167+
RankedTensorType::get({3, 5}, IndexType::get(&context));
168+
IntegerRelation bicyclicRelation =
169+
getBicyclicLayoutRelation(matrixType, numSlots);
170+
171+
std::vector<std::vector<int>> matrix = {
172+
{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}};
173+
std::vector<std::vector<int>> packedMatrix =
174+
evaluateLayoutOnMatrix(bicyclicRelation, matrix);
175+
176+
std::vector<std::vector<int>> expected = {
177+
{1, 7, 13, 4, 10, 11, 2, 8, 14, 5, 6, 12, 3, 9, 15}};
178+
EXPECT_EQ(packedMatrix, expected);
179+
}
180+
181+
TEST(UtilsTest, BicyclicLayout3x5Repeated) {
182+
MLIRContext context;
183+
184+
int64_t numSlots = 32;
185+
RankedTensorType matrixType =
186+
RankedTensorType::get({3, 5}, IndexType::get(&context));
187+
IntegerRelation bicyclicRelation =
188+
getBicyclicLayoutRelation(matrixType, numSlots);
189+
190+
std::vector<std::vector<int>> matrix = {
191+
{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, {11, 12, 13, 14, 15}};
192+
std::vector<std::vector<int>> packedMatrix =
193+
evaluateLayoutOnMatrix(bicyclicRelation, matrix);
194+
195+
std::vector<std::vector<int>> expected = {
196+
{1, 7, 13, 4, 10, 11, 2, 8, 14, 5, 6, 12, 3, 9, 15,
197+
// Cyclically repeated to fill 32 slots
198+
1, 7, 13, 4, 10, 11, 2, 8, 14, 5, 6, 12, 3, 9, 15, 1, 7}};
199+
EXPECT_EQ(packedMatrix, expected);
200+
}
201+
163202
TEST(UtilsTest, TestGetRangePoints) {
164203
MLIRContext context;
165204
auto rel = getIntegerRelationFromIslStr(

0 commit comments

Comments
 (0)