@@ -34,67 +34,77 @@ namespace heir {
3434// LevelAnalysis (Forward)
3535// ===----------------------------------------------------------------------===//
3636
37- LogicalResult LevelAnalysis::visitOperation (
38- Operation* op, ArrayRef<const LevelLattice*> operands,
39- ArrayRef<LevelLattice*> results) {
40- auto propagate = [&](Value value, const LevelState& state) {
41- auto * lattice = getLatticeElement (value);
42- ChangeResult changed = lattice->join (state);
43- propagateIfChanged (lattice, changed);
44- };
45-
46- LLVM_DEBUG (llvm::dbgs () << " Forward Propagate visiting " << op->getName ()
47- << " \n " );
48-
49- llvm::TypeSwitch<Operation&>(*op)
50- .Case <mgmt::ModReduceOp>([&](auto modReduceOp) {
37+ FailureOr<int64_t > deriveResultLevel (Operation* op,
38+ ArrayRef<const LevelLattice*> operands) {
39+ return llvm::TypeSwitch<Operation&, FailureOr<int64_t >>(*op)
40+ .Case <mgmt::ModReduceOp>([&](auto modReduceOp) -> FailureOr<int64_t > {
5141 // implicitly ensure that the operand is secret
5242 const auto * operandLattice = operands[0 ];
5343 if (!operandLattice->getValue ().isInitialized ()) {
54- return ;
44+ return failure () ;
5545 }
56- auto level = operandLattice->getValue ().getLevel ();
57- propagate (modReduceOp. getResult (), LevelState ( level + 1 )) ;
46+ int64_t level = operandLattice->getValue ().getLevel ();
47+ return level + 1 ;
5848 })
59- .Case <mgmt::LevelReduceOp>([&](auto levelReduceOp) {
49+ .Case <mgmt::LevelReduceOp>([&](auto levelReduceOp) -> FailureOr< int64_t > {
6050 // implicitly ensure that the operand is secret
6151 const auto * operandLattice = operands[0 ];
6252 if (!operandLattice->getValue ().isInitialized ()) {
63- return ;
53+ return failure () ;
6454 }
65- auto level = operandLattice->getValue ().getLevel ();
66- propagate (levelReduceOp.getResult (),
67- LevelState (level + levelReduceOp.getLevelToDrop ()));
55+ return operandLattice->getValue ().getLevel () +
56+ levelReduceOp.getLevelToDrop ();
6857 })
69- .Case <mgmt::BootstrapOp>([&](auto bootstrapOp) {
58+ .Case <mgmt::BootstrapOp>([&](auto bootstrapOp) -> FailureOr< int64_t > {
7059 // implicitly ensure that the result is secret
7160 // reset level to 0
7261 // TODO(#1207): reset level to currentLevel - bootstrapDepth
73- propagate (bootstrapOp. getResult (), LevelState ( 0 )) ;
62+ return 0 ;
7463 })
75- .Default ([&](auto & op) {
76- // condition on result secretness
77- SmallVector<OpResult> secretResults;
78- getSecretResults (&op, secretResults);
79- if (secretResults.empty ()) {
80- return ;
81- }
82-
64+ .Default ([&](auto & op) -> FailureOr<int64_t > {
8365 auto levelResult = 0 ;
84- SmallVector<OpOperand*> secretOperands;
85- getSecretOperands (&op, secretOperands);
86- for (auto * operand : secretOperands) {
87- auto & levelState = getLatticeElement (operand->get ())->getValue ();
88- if (!levelState.isInitialized ()) {
89- return ;
66+ for (auto * levelState : operands) {
67+ if (!levelState || !levelState->getValue ().isInitialized ()) {
68+ continue ;
9069 }
91- levelResult = std::max (levelResult, levelState.getLevel ());
70+ levelResult =
71+ std::max (levelResult, levelState->getValue ().getLevel ());
9272 }
9373
94- for (auto result : secretResults) {
95- propagate (result, LevelState (levelResult));
96- }
74+ return levelResult;
9775 });
76+ }
77+
78+ LogicalResult LevelAnalysis::visitOperation (
79+ Operation* op, ArrayRef<const LevelLattice*> operands,
80+ ArrayRef<LevelLattice*> results) {
81+ auto propagate = [&](Value value, const LevelState& state) {
82+ auto * lattice = getLatticeElement (value);
83+ ChangeResult changed = lattice->join (state);
84+ propagateIfChanged (lattice, changed);
85+ };
86+
87+ LLVM_DEBUG (llvm::dbgs () << " Forward Propagate visiting " << op->getName ()
88+ << " \n " );
89+
90+ SmallVector<OpOperand*> secretOperands;
91+ getSecretOperands (op, secretOperands);
92+ SmallVector<const LevelLattice*, 2 > secretOperandLattices;
93+ for (auto * operand : secretOperands) {
94+ secretOperandLattices.push_back (getLatticeElement (operand->get ()));
95+ }
96+ FailureOr<int64_t > resultLevel = deriveResultLevel (op, secretOperandLattices);
97+ if (failed (resultLevel)) {
98+ // Ignore failure and continue
99+ return success ();
100+ }
101+
102+ SmallVector<OpResult> secretResults;
103+ getSecretResults (op, secretResults);
104+ for (auto result : secretResults) {
105+ propagate (result, LevelState (resultLevel.value ()));
106+ }
107+
98108 return success ();
99109}
100110
0 commit comments