Skip to content

Commit e7210db

Browse files
CKKS: Add CCH+23 Noise Model
1 parent da9c98f commit e7210db

File tree

9 files changed

+587
-7
lines changed

9 files changed

+587
-7
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package(
2+
default_applicable_licenses = ["@heir//:license"],
3+
default_visibility = ["//visibility:public"],
4+
)
5+
6+
cc_library(
7+
name = "NoiseAnalysis",
8+
srcs = [
9+
"NoiseAnalysis.cpp",
10+
],
11+
hdrs = [
12+
],
13+
deps = [
14+
":NoiseByVarianceCoeffModel",
15+
"@heir//lib/Analysis:Utils",
16+
"@heir//lib/Analysis/DimensionAnalysis",
17+
"@heir//lib/Analysis/LevelAnalysis",
18+
"@heir//lib/Analysis/NoiseAnalysis",
19+
"@heir//lib/Analysis/NoiseAnalysis:Noise",
20+
"@heir//lib/Analysis/ScaleAnalysis",
21+
"@heir//lib/Dialect/Mgmt/IR:Dialect",
22+
"@heir//lib/Dialect/Secret/IR:Dialect",
23+
"@heir//lib/Dialect/TensorExt/IR:Dialect",
24+
"@heir//lib/Parameters/BGV:Params",
25+
"@llvm-project//llvm:Support",
26+
"@llvm-project//mlir:ArithDialect",
27+
"@llvm-project//mlir:CallOpInterfaces",
28+
"@llvm-project//mlir:IR",
29+
"@llvm-project//mlir:Support",
30+
"@llvm-project//mlir:TensorDialect",
31+
],
32+
)
33+
34+
cc_library(
35+
name = "NoiseByVarianceCoeffModel",
36+
srcs = [
37+
"NoiseByVarianceCoeffModel.cpp",
38+
],
39+
hdrs = [
40+
"NoiseByVarianceCoeffModel.h",
41+
],
42+
deps = [
43+
"@heir//lib/Analysis/NoiseAnalysis:Noise",
44+
"@heir//lib/Parameters/CKKS:Params",
45+
"@heir//lib/Utils:MathUtils",
46+
],
47+
)
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
#include "lib/Analysis/NoiseAnalysis/NoiseAnalysis.h"
2+
3+
#include <functional>
4+
5+
#include "lib/Analysis/DimensionAnalysis/DimensionAnalysis.h"
6+
#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h"
7+
#include "lib/Analysis/NoiseAnalysis/CKKS/NoiseByVarianceCoeffModel.h"
8+
#include "lib/Analysis/ScaleAnalysis/ScaleAnalysis.h"
9+
#include "lib/Analysis/Utils.h"
10+
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
11+
#include "lib/Dialect/Secret/IR/SecretOps.h"
12+
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
13+
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
14+
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
15+
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
16+
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
17+
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
18+
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
19+
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
20+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
21+
22+
#define DEBUG_TYPE "NoiseAnalysis"
23+
24+
namespace mlir {
25+
namespace heir {
26+
27+
// explicit specialization of NoiseAnalysis for NoiseByBoundCoeffModel
28+
template <typename NoiseModel>
29+
void NoiseAnalysis<NoiseModel>::setToEntryState(LatticeType *lattice) {
30+
// At an entry point, we have no information about the noise.
31+
this->propagateIfChanged(lattice, lattice->join(NoiseState::uninitialized()));
32+
}
33+
34+
// explicit specialization of NoiseAnalysis for NoiseByBoundCoeffModel
35+
template <typename NoiseModel>
36+
void NoiseAnalysis<NoiseModel>::visitExternalCall(
37+
CallOpInterface call, ArrayRef<const LatticeType *> argumentLattices,
38+
ArrayRef<LatticeType *> resultLattices) {
39+
auto callback =
40+
std::bind(&NoiseAnalysis<NoiseModel>::propagateIfChangedWrapper, this,
41+
std::placeholders::_1, std::placeholders::_2);
42+
::mlir::heir::visitExternalCall<NoiseState, LatticeType>(
43+
call, argumentLattices, resultLattices, callback);
44+
}
45+
46+
// explicit specialization of NoiseAnalysis for NoiseByBoundCoeffModel
47+
template <typename NoiseModel>
48+
LogicalResult NoiseAnalysis<NoiseModel>::visitOperation(
49+
Operation *op, ArrayRef<const LatticeType *> operands,
50+
ArrayRef<LatticeType *> results) {
51+
auto getLocalParam = [&](Value value) {
52+
auto level = getLevelFromMgmtAttr(value);
53+
auto dimension = getDimensionFromMgmtAttr(value);
54+
auto scale = getScaleFromMgmtAttr(value);
55+
return LocalParamType(&schemeParam, level, dimension, scale);
56+
};
57+
58+
auto propagate = [&](Value value, NoiseState noise) {
59+
LLVM_DEBUG(llvm::dbgs()
60+
<< "Propagating "
61+
<< NoiseModel::toLogBoundString(getLocalParam(value), noise)
62+
<< " to " << value << "\n");
63+
LatticeType *lattice = this->getLatticeElement(value);
64+
auto changeResult = lattice->join(noise);
65+
this->propagateIfChanged(lattice, changeResult);
66+
};
67+
68+
auto getOperandNoises = [&](Operation *op,
69+
SmallVectorImpl<NoiseState> &noises) {
70+
SmallVector<OpOperand *> secretOperands;
71+
SmallVector<OpOperand *> nonSecretOperands;
72+
this->getSecretOperands(op, secretOperands);
73+
this->getNonSecretOperands(op, nonSecretOperands);
74+
75+
for (auto *operand : secretOperands) {
76+
noises.push_back(this->getLatticeElement(operand->get())->getValue());
77+
}
78+
for (auto *operand : nonSecretOperands) {
79+
(void)operand;
80+
// at least one operand is secret
81+
auto localParam = getLocalParam(secretOperands[0]->get());
82+
noises.push_back(NoiseModel::evalConstant(localParam));
83+
}
84+
};
85+
86+
auto res =
87+
llvm::TypeSwitch<Operation &, LogicalResult>(*op)
88+
.Case<secret::GenericOp>([&](auto genericOp) {
89+
Block *body = genericOp.getBody();
90+
for (Value &arg : body->getArguments()) {
91+
auto localParam = getLocalParam(arg);
92+
NoiseState encrypted = NoiseModel::evalEncrypt(localParam);
93+
propagate(arg, encrypted);
94+
}
95+
return success();
96+
})
97+
.template Case<arith::MulFOp, arith::MulIOp>([&](auto mulOp) {
98+
SmallVector<OpResult> secretResults;
99+
this->getSecretResults(mulOp, secretResults);
100+
if (secretResults.empty()) {
101+
return success();
102+
}
103+
104+
SmallVector<NoiseState, 2> operandNoises;
105+
getOperandNoises(mulOp, operandNoises);
106+
107+
auto localParam = getLocalParam(mulOp.getResult());
108+
auto lhsParam = getLocalParam(mulOp.getOperand(0));
109+
auto rhsParam = getLocalParam(mulOp.getOperand(1));
110+
NoiseState mult =
111+
NoiseModel::evalMul(localParam, lhsParam, rhsParam,
112+
operandNoises[0], operandNoises[1]);
113+
propagate(mulOp.getResult(), mult);
114+
return success();
115+
})
116+
.template Case<arith::AddFOp, arith::SubFOp, arith::AddIOp,
117+
arith::SubIOp>([&](auto addOp) {
118+
SmallVector<OpResult> secretResults;
119+
this->getSecretResults(addOp, secretResults);
120+
if (secretResults.empty()) {
121+
return success();
122+
}
123+
124+
SmallVector<NoiseState, 2> operandNoises;
125+
getOperandNoises(addOp, operandNoises);
126+
NoiseState add =
127+
NoiseModel::evalAdd(operandNoises[0], operandNoises[1]);
128+
propagate(addOp.getResult(), add);
129+
return success();
130+
})
131+
.template Case<tensor_ext::RotateOp>([&](auto rotateOp) {
132+
// implicitly assumed secret
133+
auto localParam = getLocalParam(rotateOp.getOperand(0));
134+
135+
// assume relinearize immediately after rotate
136+
// when we support hoisting relinearize, we need to change
137+
// this
138+
NoiseState rotate = NoiseModel::evalRelinearize(
139+
localParam, operands[0]->getValue());
140+
propagate(rotateOp.getResult(), rotate);
141+
return success();
142+
})
143+
// NOTE: special case for ExtractOp... it is a mulconst+rotate
144+
// if not annotated with slot_extract
145+
// TODO(#1174): decide packing earlier in the pipeline instead
146+
// of annotation
147+
//.template Case<tensor::ExtractOp>([&](auto extractOp) {
148+
// auto localParam = getLocalParam(extractOp.getOperand(0));
149+
150+
// // extract = mul_plain 1 + rotate
151+
// // although the cleartext is 1, when encoded (i.e. CRT
152+
// // packing), the value multiplied to the ciphertext is not 1,
153+
// // If we can know the encoded value, we can bound it more
154+
// // precisely.
155+
// NoiseState one = NoiseModel::evalConstant(localParam);
156+
// NoiseState extract =
157+
// NoiseModel::evalMul(localParam, operands[0]->getValue(), one);
158+
// // assume relinearize immediately after rotate
159+
// // when we support hoisting relinearize, we need to change
160+
// // this
161+
// NoiseState rotate =
162+
// NoiseModel::evalRelinearize(localParam, extract);
163+
// propagate(extractOp.getResult(), extract);
164+
// return success();
165+
//})
166+
.template Case<mgmt::ModReduceOp>([&](auto modReduceOp) {
167+
// No-op for B/FV
168+
modReduceOp->emitWarning("ModReduceOp encountered in CKKS");
169+
propagate(modReduceOp.getResult(), operands[0]->getValue());
170+
return success();
171+
})
172+
.template Case<mgmt::RelinearizeOp>([&](auto relinearizeOp) {
173+
auto localParam = getLocalParam(relinearizeOp.getInput());
174+
175+
NoiseState relinearize = NoiseModel::evalRelinearize(
176+
localParam, operands[0]->getValue());
177+
propagate(relinearizeOp.getResult(), relinearize);
178+
return success();
179+
})
180+
.Default([&](auto &op) {
181+
// condition on result secretness
182+
SmallVector<OpResult> secretResults;
183+
this->getSecretResults(&op, secretResults);
184+
if (secretResults.empty()) {
185+
return success();
186+
}
187+
188+
if (!mlir::isa<arith::ConstantOp, arith::ExtSIOp, arith::ExtUIOp,
189+
arith::ExtFOp>(op)) {
190+
op.emitError()
191+
<< "Unsupported operation for noise analysis encountered.";
192+
}
193+
194+
SmallVector<OpOperand *> secretOperands;
195+
this->getSecretOperands(&op, secretOperands);
196+
if (secretOperands.empty()) {
197+
return success();
198+
}
199+
200+
// inherit noise from the first secret operand
201+
NoiseState first;
202+
for (auto *operand : secretOperands) {
203+
auto &noise = this->getLatticeElement(operand->get())->getValue();
204+
if (!noise.isInitialized()) {
205+
return success();
206+
}
207+
first = noise;
208+
break;
209+
}
210+
211+
for (auto result : secretResults) {
212+
propagate(result, first);
213+
}
214+
return success();
215+
});
216+
return res;
217+
}
218+
219+
// template instantiation
220+
// for variance
221+
template class NoiseAnalysis<ckks::NoiseByVarianceCoeffModel>;
222+
223+
} // namespace heir
224+
} // namespace mlir

0 commit comments

Comments
 (0)