Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ struct ConvertQuartConstantOp
mlir::arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (isa<IndexType>(op.getValue().getType())) {
return failure();
return rewriter.notifyMatchFailure(op, "value type is IndexType");
}
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

Expand Down Expand Up @@ -241,7 +241,8 @@ struct ConvertQuartConstantOp

return success();
}
return failure();
return rewriter.notifyMatchFailure(op,
"old value is not an integer attribute");
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ struct ConvertConstant : public OpConversionPattern<mlir::arith::ConstantOp> {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

if (isa<IndexType>(op.getValue().getType())) {
return failure();
return rewriter.notifyMatchFailure(op, "value type is IndexType");
}

Type newResultType = typeConverter->convertType(op.getResult().getType());
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/CGGI/Conversions/CGGIToJaxite/CGGIToJaxite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ struct AddJaxiteContextualArgs : public OpConversionPattern<func::FuncOp> {
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (!containsDialects<lwe::LWEDialect, cggi::CGGIDialect>(op)) {
return failure();
return rewriter.notifyMatchFailure(
op, "op does not contain lwe or cggi dialects");
}

auto serverKeyType = jaxite::ServerKeySetType::get(getContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ struct AddServerKeyArg : public OpConversionPattern<func::FuncOp> {
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (!containsDialects<lwe::LWEDialect, cggi::CGGIDialect>(op)) {
return failure();
return rewriter.notifyMatchFailure(
op, "op does not contain lwe or cggi dialects");
}

auto serverKeyType = tfhe_rust::ServerKeyType::get(getContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ struct AddBoolServerKeyArg : public OpConversionPattern<func::FuncOp> {
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (!containsDialects<lwe::LWEDialect, cggi::CGGIDialect>(op)) {
return failure();
return rewriter.notifyMatchFailure(
op, "op does not contain lwe or cggi dialects");
}

Type serverKeyType = tfhe_rust_bool::ServerKeyType::get(getContext());
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ struct RemoveKeyArg : public OpRewritePattern<func::FuncOp> {
}

if (argsToErase.none()) {
return failure();
return rewriter.notifyMatchFailure(op, "no key arguments to erase");
}

rewriter.modifyOpInPlace(op, [&] { (void)op.eraseArguments(argsToErase); });
Expand Down Expand Up @@ -442,7 +442,7 @@ struct ConvertRlweDecodeOp : public OpConversionPattern<DecodeOp> {

auto zeroAttr = rewriter.getZeroAttr(outputTensorType);
if (!zeroAttr) {
return op.emitOpError() << "Unsupported type for lowering";
return rewriter.notifyMatchFailure(op, "Unsupported type for lowering");
}
auto alloc =
AllocOp::create(rewriter, op.getLoc(), outputTensorType, zeroAttr);
Expand Down
14 changes: 10 additions & 4 deletions lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ struct AddCryptoContextArg : public OpConversionPattern<func::FuncOp> {
containsArgumentOfDialect<lwe::LWEDialect, bgv::BGVDialect,
ckks::CKKSDialect>(op);
if (!(containsCryptoOps || containsCryptoArg)) {
return failure();
return rewriter.notifyMatchFailure(
op, "contains neither ops nor arg types from lwe/bgv/ckks dialects");
}

auto cryptoContextType = openfhe::CryptoContextType::get(getContext());
Expand Down Expand Up @@ -239,7 +240,10 @@ struct ConvertEncodeOp : public OpConversionPattern<lwe::RLWEEncodeOp> {
// TODO (#1192): support coefficient packing in `--lwe-to-openfhe`
op.emitError() << "HEIR does not yet support coefficient encoding "
" when targeting OpenFHE";
return failure();
return rewriter.notifyMatchFailure(
op,
"HEIR does not yet support coefficient encoding when targeting "
"OpenFHE");
})
.Case<lwe::FullCRTPackingEncodingAttr>([&](auto encoding) {
rewriter.replaceOpWithNewOp<openfhe::MakePackedPlaintextOp>(
Expand All @@ -252,7 +256,7 @@ struct ConvertEncodeOp : public OpConversionPattern<lwe::RLWEEncodeOp> {
"Unexpected encoding while targeting OpenFHE. "
"If you expect this type of encoding to be supported "
"for the OpenFHE backend, please file a bug report.");
return failure();
return rewriter.notifyMatchFailure(op, "Unknown encoding");
});
}
};
Expand All @@ -267,7 +271,9 @@ struct ConvertBootstrapOp : public OpConversionPattern<ckks::BootstrapOp> {
ckks::BootstrapOp op, ckks::BootstrapOp::Adaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
FailureOr<Value> result = getContextualCryptoContext(op.getOperation());
if (failed(result)) return result;
if (failed(result)) {
return rewriter.notifyMatchFailure(op, "No crypto context arg");
}

Value cryptoContext = result.value();
rewriter.replaceOpWithNewOp<openfhe::BootstrapOp>(
Expand Down
20 changes: 12 additions & 8 deletions lib/Dialect/LWE/Conversions/LWEToPolynomial/LWEToPolynomial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ struct ConvertRLWEDecrypt : public OpConversionPattern<RLWEDecryptOp> {
// TODO (#882): For TFHE, which can support higher dimensional keys,
// plaintexts, and ciphertexts, we need to add support for encrypt and
// decrypt for those cases.
return failure();
return rewriter.notifyMatchFailure(op,
"expected 2 dimensional ciphertext");
}

ImplicitLocOpBuilder builder(loc, rewriter);
Expand Down Expand Up @@ -239,7 +240,8 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
if (!plaintextModArithType) {
op.emitError() << "Unsupported plaintext coefficient type: "
<< plaintextCoeffType;
return failure();
return rewriter.notifyMatchFailure(
op, "unsupported plaintext coefficient type");
}

// create scalar constant T in the output coefficient space
Expand Down Expand Up @@ -285,6 +287,8 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
<< "`lwe.rlwe_encrypt` only supports secret keys with a single "
"polynomial, got secret key type "
<< key.getType();
return rewriter.notifyMatchFailure(
op, "secret key has more than one polynomial");
}

// Generate random e polynomial from discrete gaussian distribution
Expand Down Expand Up @@ -405,7 +409,7 @@ struct ConvertRNegate : public OpConversionPattern<RNegateOp> {
});

if (failed(neg)) {
return failure();
return rewriter.notifyMatchFailure(op, "unsupported coefficient type");
}

rewriter.replaceOp(op, polynomial::MulScalarOp::create(
Expand All @@ -431,12 +435,11 @@ struct ConvertRMul : public OpConversionPattern<RMulOp> {
if (xT.getNumElements() != 2 || yT.getNumElements() != 2) {
op.emitError() << "`lwe.rmul` expects ciphertext as two polynomials, got "
<< xT.getNumElements() << " and " << yT.getNumElements();
return failure();
return rewriter.notifyMatchFailure(op, "ciphertext not two polynomials");
}

if (xT.getElementType() != yT.getElementType()) {
op->emitOpError() << "`lwe.rmul` expects operands of the same type";
return failure();
return rewriter.notifyMatchFailure(op, "operands not of same type");
}

ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Expand Down Expand Up @@ -485,7 +488,8 @@ struct ConvertRMulPlain : public OpConversionPattern<RMulPlainOp> {
op.emitError() << "`lwe.rmul_plain` expects ciphertext as two "
"polynomials and plaintext as 1, got "
<< xT.getNumElements() << " and " << yT.getNumElements();
return failure();
return rewriter.notifyMatchFailure(
op, "ciphertext not two polynomials or plaintext not one polynomial");
}

ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Expand All @@ -509,7 +513,7 @@ struct ConvertRelin : public OpConversionPattern<ckks::RelinearizeOp> {
ckks::RelinearizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (!op.getKeySwitchingKey()) {
return failure();
return rewriter.notifyMatchFailure(op, "no key switching key provided");
}

Value zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/LWE/IR/LWEPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ struct PutCiphertextInFirstOperand : public OpRewritePattern<Op> {
});
return success();
}
return failure();
return rewriter.notifyMatchFailure(
op,
"ciphertext not in second operand and plaintext not in first operand");
}
};

Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/LWE/Transforms/AddDebugPort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ LogicalResult insertExternalCall(func::FuncOp op, Type lwePrivateKeyType) {

LogicalResult convertFunc(func::FuncOp op) {
auto type = getPrivateKeyType(op);
if (failed(type)) return failure();
if (failed(type)) return op.emitError("failed to get private key type");
auto lwePrivateKeyType = type.value();

if (failed(op.insertArgument(0, lwePrivateKeyType, nullptr, op.getLoc()))) {
return failure();
return op.emitError("failed to insert private key argument");
}
if (failed(insertExternalCall(op, lwePrivateKeyType))) {
return failure();
return op.emitError("failed to insert external call");
}
return success();
}
Expand Down
8 changes: 4 additions & 4 deletions lib/Dialect/Lattigo/Transforms/AllocToInplace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ struct ConvertBinOp : public OpRewritePattern<BinOp> {
auto& storageInfo = (*blockToStorageInfo)[op->getBlock()];
auto storage = storageInfo.getAvailableStorage(op, liveness);
if (!storage) {
return failure();
return rewriter.notifyMatchFailure(op, "no available storage found");
}

// InplaceOp has the form: output = InplaceOp(evaluator, lhs, rhs,
Expand Down Expand Up @@ -174,7 +174,7 @@ struct ConvertUnaryOp : public OpRewritePattern<UnaryOp> {
auto& storageInfo = (*blockToStorageInfo)[op->getBlock()];
auto storage = storageInfo.getAvailableStorage(op, liveness);
if (!storage) {
return failure();
return rewriter.notifyMatchFailure(op, "no available storage found");
}

// InplaceOp has the form: output = InplaceOp(evaluator, lhs, inplace)
Expand Down Expand Up @@ -209,7 +209,7 @@ struct ConvertRotateOp : public OpRewritePattern<RotateOp> {
auto& storageInfo = (*blockToStorageInfo)[op->getBlock()];
auto storage = storageInfo.getAvailableStorage(op, liveness);
if (!storage) {
return failure();
return rewriter.notifyMatchFailure(op, "no available storage found");
}

// InplaceOp has the form: output = InplaceOp(evaluator, lhs, inplace)
Expand Down Expand Up @@ -245,7 +245,7 @@ struct ConvertDropLevelOp : public OpRewritePattern<DropLevelOp> {
auto& storageInfo = (*blockToStorageInfo)[op->getBlock()];
auto storage = storageInfo.getAvailableStorage(op, liveness);
if (!storage) {
return failure();
return rewriter.notifyMatchFailure(op, "no available storage found");
}

// InplaceOp has the form: output = InplaceOp(evaluator, lhs, inplace)
Expand Down
28 changes: 18 additions & 10 deletions lib/Dialect/ModArith/IR/ModArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,22 @@ ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
// more easily convert them to the correct bitwidth (ArrayAttr forces I64)
std::vector<APInt> parsedInts;
if (parser.parseLess() ||
parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::Square,
[&] {
APInt parsedInt;
if (parser.parseInteger(parsedInt))
return failure();
parsedInts.push_back(parsedInt);
return success();
}) ||
parser.parseCommaSeparatedList(
mlir::AsmParser::Delimiter::Square,
[&] {
APInt parsedInt;
if (parser.parseInteger(parsedInt))
return parser.emitError(
parser.getNameLoc(),
"failed to parse integer in dense list"),
failure();
parsedInts.push_back(parsedInt);
return success();
}) ||
parser.parseGreater() || parser.parseColonType(parsedType))
return failure();
return parser.emitError(parser.getNameLoc(),
"failed to parse dense constant"),
failure();
if (parsedInts.empty())
return parser.emitError(parser.getNameLoc(),
"expected at least one integer in dense list.");
Expand All @@ -209,7 +215,9 @@ ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
// Scalar case
APInt parsedInt;
if (parser.parseInteger(parsedInt) || parser.parseColonType(parsedType))
return failure();
return parser.emitError(parser.getNameLoc(),
"failed to parse scalar constant"),
failure();
// zero becomes `i64` when parsed, so truncate back down to minBitwidth
if (parsedInt.isZero()) parsedInt = parsedInt.trunc(minBitwidth);
result.addAttribute(
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/ModArith/Transforms/ConvertToMac.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ struct FindMac : public OpRewritePattern<mod_arith::AddOp> {
if (!parent) {
auto parentRhs = op.getRhs().getDefiningOp<mod_arith::MulOp>();
if (!parentRhs) {
return failure();
return rewriter.notifyMatchFailure(
op, "neither operand is a mod_arith.mul operation");
}
// Find we have a form of lhs + a x b
parent = parentRhs;
Expand Down
Loading
Loading