Skip to content

Commit b703e09

Browse files
committed
Improve the runtime of secret-insert-mgmt-* passes
1 parent 41ce880 commit b703e09

File tree

15 files changed

+868
-677
lines changed

15 files changed

+868
-677
lines changed

MODULE.bazel.lock

Lines changed: 366 additions & 385 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/Analysis/LevelAnalysis/LevelAnalysis.cpp

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,67 +34,77 @@ namespace heir {
3434
// LevelAnalysis (Forward)
3535
//===----------------------------------------------------------------------===//
3636

37-
LogicalResult LevelAnalysis::visitOperation(
38-
Operation* op, ArrayRef<const LevelLattice*> operands,
39-
ArrayRef<LevelLattice*> results) {
40-
auto propagate = [&](Value value, const LevelState& state) {
41-
auto* lattice = getLatticeElement(value);
42-
ChangeResult changed = lattice->join(state);
43-
propagateIfChanged(lattice, changed);
44-
};
45-
46-
LLVM_DEBUG(llvm::dbgs() << "Forward Propagate visiting " << op->getName()
47-
<< "\n");
48-
49-
llvm::TypeSwitch<Operation&>(*op)
50-
.Case<mgmt::ModReduceOp>([&](auto modReduceOp) {
37+
FailureOr<int64_t> deriveResultLevel(Operation* op,
38+
ArrayRef<const LevelLattice*> operands) {
39+
return llvm::TypeSwitch<Operation&, FailureOr<int64_t>>(*op)
40+
.Case<mgmt::ModReduceOp>([&](auto modReduceOp) -> FailureOr<int64_t> {
5141
// implicitly ensure that the operand is secret
5242
const auto* operandLattice = operands[0];
5343
if (!operandLattice->getValue().isInitialized()) {
54-
return;
44+
return failure();
5545
}
56-
auto level = operandLattice->getValue().getLevel();
57-
propagate(modReduceOp.getResult(), LevelState(level + 1));
46+
int64_t level = operandLattice->getValue().getLevel();
47+
return level + 1;
5848
})
59-
.Case<mgmt::LevelReduceOp>([&](auto levelReduceOp) {
49+
.Case<mgmt::LevelReduceOp>([&](auto levelReduceOp) -> FailureOr<int64_t> {
6050
// implicitly ensure that the operand is secret
6151
const auto* operandLattice = operands[0];
6252
if (!operandLattice->getValue().isInitialized()) {
63-
return;
53+
return failure();
6454
}
65-
auto level = operandLattice->getValue().getLevel();
66-
propagate(levelReduceOp.getResult(),
67-
LevelState(level + levelReduceOp.getLevelToDrop()));
55+
return operandLattice->getValue().getLevel() +
56+
levelReduceOp.getLevelToDrop();
6857
})
69-
.Case<mgmt::BootstrapOp>([&](auto bootstrapOp) {
58+
.Case<mgmt::BootstrapOp>([&](auto bootstrapOp) -> FailureOr<int64_t> {
7059
// implicitly ensure that the result is secret
7160
// reset level to 0
7261
// TODO(#1207): reset level to currentLevel - bootstrapDepth
73-
propagate(bootstrapOp.getResult(), LevelState(0));
62+
return 0;
7463
})
75-
.Default([&](auto& op) {
76-
// condition on result secretness
77-
SmallVector<OpResult> secretResults;
78-
getSecretResults(&op, secretResults);
79-
if (secretResults.empty()) {
80-
return;
81-
}
82-
64+
.Default([&](auto& op) -> FailureOr<int64_t> {
8365
auto levelResult = 0;
84-
SmallVector<OpOperand*> secretOperands;
85-
getSecretOperands(&op, secretOperands);
86-
for (auto* operand : secretOperands) {
87-
auto& levelState = getLatticeElement(operand->get())->getValue();
88-
if (!levelState.isInitialized()) {
89-
return;
66+
for (auto* levelState : operands) {
67+
if (!levelState || !levelState->getValue().isInitialized()) {
68+
continue;
9069
}
91-
levelResult = std::max(levelResult, levelState.getLevel());
70+
levelResult =
71+
std::max(levelResult, levelState->getValue().getLevel());
9272
}
9373

94-
for (auto result : secretResults) {
95-
propagate(result, LevelState(levelResult));
96-
}
74+
return levelResult;
9775
});
76+
}
77+
78+
LogicalResult LevelAnalysis::visitOperation(
79+
Operation* op, ArrayRef<const LevelLattice*> operands,
80+
ArrayRef<LevelLattice*> results) {
81+
auto propagate = [&](Value value, const LevelState& state) {
82+
auto* lattice = getLatticeElement(value);
83+
ChangeResult changed = lattice->join(state);
84+
propagateIfChanged(lattice, changed);
85+
};
86+
87+
LLVM_DEBUG(llvm::dbgs() << "Forward Propagate visiting " << op->getName()
88+
<< "\n");
89+
90+
SmallVector<OpOperand*> secretOperands;
91+
getSecretOperands(op, secretOperands);
92+
SmallVector<const LevelLattice*, 2> secretOperandLattices;
93+
for (auto* operand : secretOperands) {
94+
secretOperandLattices.push_back(getLatticeElement(operand->get()));
95+
}
96+
FailureOr<int64_t> resultLevel = deriveResultLevel(op, secretOperandLattices);
97+
if (failed(resultLevel)) {
98+
// Ignore failure and continue
99+
return success();
100+
}
101+
102+
SmallVector<OpResult> secretResults;
103+
getSecretResults(op, secretResults);
104+
for (auto result : secretResults) {
105+
propagate(result, LevelState(resultLevel.value()));
106+
}
107+
98108
return success();
99109
}
100110

lib/Analysis/LevelAnalysis/LevelAnalysis.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class LevelState {
3535
assert(isInitialized());
3636
return level.value();
3737
}
38+
void setLevel(LevelType value) {
39+
level = std::make_optional<LevelType>(value);
40+
}
3841
LevelType get() const { return getLevel(); }
3942

4043
bool operator==(const LevelState& rhs) const { return level == rhs.level; }
@@ -118,6 +121,9 @@ class LevelAnalysis
118121
}
119122
};
120123

124+
FailureOr<int64_t> deriveResultLevel(Operation* op,
125+
ArrayRef<const LevelLattice*> operands);
126+
121127
/// Backward Analyse the level of plaintext Value
122128
///
123129
/// This analysis should be run after the (forward) LevelAnalysis

lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.cpp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,25 @@
2020
namespace mlir {
2121
namespace heir {
2222

23+
FailureOr<int64_t> deriveResultMulDepth(
24+
Operation* op, ArrayRef<const MulDepthLattice*> operands) {
25+
auto isMul = false;
26+
if (isa<arith::MulIOp, arith::MulFOp, mgmt::AdjustScaleOp>(op)) {
27+
isMul = true;
28+
}
29+
30+
int64_t operandsMulDepth = 0;
31+
for (auto* operand : operands) {
32+
if (!operand || !operand->getValue().isInitialized()) {
33+
continue;
34+
}
35+
operandsMulDepth =
36+
std::max(operandsMulDepth, operand->getValue().getMulDepth());
37+
}
38+
39+
return operandsMulDepth + (isMul ? 1 : 0);
40+
}
41+
2342
LogicalResult MulDepthAnalysis::visitOperation(
2443
Operation* op, ArrayRef<const MulDepthLattice*> operands,
2544
ArrayRef<MulDepthLattice*> results) {
@@ -39,35 +58,26 @@ LogicalResult MulDepthAnalysis::visitOperation(
3958
})
4059
.Default([&](auto& op) {
4160
// condition on result secretness
42-
SmallVector<OpResult> secretDepths;
43-
getSecretResults(&op, secretDepths);
44-
if (secretDepths.empty()) {
61+
SmallVector<OpResult> secretResults;
62+
getSecretResults(&op, secretResults);
63+
if (secretResults.empty()) {
4564
return;
4665
}
4766

48-
auto isMul = false;
49-
50-
if (isa<arith::MulIOp, arith::MulFOp, mgmt::AdjustScaleOp>(op)) {
51-
isMul = true;
52-
}
53-
54-
// inherit mul depth from secret operands
55-
int64_t operandsMulDepth = 0;
5667
SmallVector<OpOperand*> secretOperands;
5768
getSecretOperands(&op, secretOperands);
69+
SmallVector<const MulDepthLattice*, 2> secretOperandLattices;
5870
for (auto* operand : secretOperands) {
59-
auto& mulDepthState = getLatticeElement(operand->get())->getValue();
60-
if (!mulDepthState.isInitialized()) {
61-
return;
62-
}
63-
operandsMulDepth =
64-
std::max(operandsMulDepth, mulDepthState.getMulDepth());
71+
secretOperandLattices.push_back(getLatticeElement(operand->get()));
72+
}
73+
FailureOr<int64_t> resultsMulDepth =
74+
deriveResultMulDepth(&op, secretOperandLattices);
75+
if (failed(resultsMulDepth)) {
76+
return;
6577
}
6678

67-
int64_t resultsMulDepth = operandsMulDepth + (isMul ? 1 : 0);
68-
69-
for (auto result : secretDepths) {
70-
propagate(result, MulDepthState(resultsMulDepth));
79+
for (auto result : secretResults) {
80+
propagate(result, MulDepthState(resultsMulDepth.value()));
7181
}
7282
});
7383
return success();

lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ class MulDepthState {
3737
return mulDepth.value();
3838
}
3939

40+
void setMulDepth(int64_t depth) {
41+
mulDepth = std::make_optional<int64_t>(depth);
42+
}
43+
4044
bool operator==(const MulDepthState& rhs) const {
4145
return mulDepth == rhs.mulDepth;
4246
}
@@ -102,6 +106,9 @@ class MulDepthAnalysis
102106
}
103107
};
104108

109+
FailureOr<int64_t> deriveResultMulDepth(
110+
Operation* op, ArrayRef<const MulDepthLattice*> operands);
111+
105112
int64_t getMaxMulDepth(Operation* op, DataFlowSolver& solver);
106113

107114
} // namespace heir

lib/Transforms/SecretInsertMgmt/BUILD

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cc_library(
1717
"Passes.h",
1818
],
1919
deps = [
20+
":Pipeline",
2021
":SecretInsertMgmtPatterns",
2122
":pass_inc_gen",
2223
"@heir//lib/Analysis/LevelAnalysis",
@@ -41,6 +42,39 @@ cc_library(
4142
],
4243
)
4344

45+
cc_library(
46+
name = "Pipeline",
47+
srcs = [
48+
"Pipeline.cpp",
49+
"Pipeline.h",
50+
],
51+
hdrs = [
52+
"Passes.h",
53+
],
54+
deps = [
55+
":SecretInsertMgmtPatterns",
56+
"@heir//lib/Analysis/LevelAnalysis",
57+
"@heir//lib/Analysis/MulDepthAnalysis",
58+
"@heir//lib/Analysis/SecretnessAnalysis",
59+
"@heir//lib/Dialect:ModuleAttributes",
60+
"@heir//lib/Dialect/BGV/IR:Dialect",
61+
"@heir//lib/Dialect/CKKS/IR:Dialect",
62+
"@heir//lib/Dialect/Mgmt/IR:Dialect",
63+
"@heir//lib/Dialect/Mgmt/Transforms",
64+
"@heir//lib/Dialect/Mgmt/Transforms:AnnotateMgmt",
65+
"@heir//lib/Dialect/Secret/IR:Dialect",
66+
"@llvm-project//llvm:Support",
67+
"@llvm-project//mlir:Analysis",
68+
"@llvm-project//mlir:ArithDialect",
69+
"@llvm-project//mlir:IR",
70+
"@llvm-project//mlir:Pass",
71+
"@llvm-project//mlir:Support",
72+
"@llvm-project//mlir:TensorDialect",
73+
"@llvm-project//mlir:TransformUtils",
74+
"@llvm-project//mlir:Transforms",
75+
],
76+
)
77+
4478
cc_library(
4579
name = "SecretInsertMgmtPatterns",
4680
srcs = [

0 commit comments

Comments
 (0)