Skip to content

Commit a6758f9

Browse files
committed
use ArithmeticDag for all lower-eval methods, and drop small coeffs
1 parent 41ce880 commit a6758f9

File tree

16 files changed

+728
-299
lines changed

16 files changed

+728
-299
lines changed

lib/Transforms/LowerPolynomialEval/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ cc_library(
3434
"@heir//lib/Utils:MathUtils",
3535
"@heir//lib/Utils/Polynomial",
3636
"@heir//lib/Utils/Polynomial:ChebyshevPatersonStockmeyer",
37+
"@heir//lib/Utils/Polynomial:Horner",
38+
"@heir//lib/Utils/Polynomial:PatersonStockmeyer",
3739
"@llvm-project//llvm:Support",
3840
"@llvm-project//mlir:ArithDialect",
3941
"@llvm-project//mlir:IR",

lib/Transforms/LowerPolynomialEval/LowerPolynomialEval.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,30 @@ struct LowerPolynomialEval
2525
MLIRContext* context = &getContext();
2626
RewritePatternSet patterns(context);
2727

28-
if (method == PolynomialApproximationMethod::Automatic) {
29-
patterns.add<LowerViaHorner, LowerViaPatersonStockmeyerChebyshev,
30-
LowerViaPatersonStockmeyerMonomial>(context,
31-
/*force=*/false);
32-
} else if (method == PolynomialApproximationMethod::Horner) {
33-
patterns.add<LowerViaHorner>(context, /*force=*/true);
34-
} else if (method == PolynomialApproximationMethod::PatersonStockmeyer) {
35-
patterns.add<LowerViaPatersonStockmeyerMonomial>(context,
36-
/*force=*/true);
37-
} else if (method ==
38-
PolynomialApproximationMethod::PatersonStockmeyerChebyshev) {
39-
patterns.add<LowerViaPatersonStockmeyerChebyshev>(context,
40-
/*force=*/true);
41-
} else {
42-
getOperation()->emitError() << "Unknown lowering method: " << method;
43-
signalPassFailure();
44-
return;
28+
switch (method) {
29+
case PolynomialApproximationMethod::Automatic:
30+
patterns.add<LowerViaHorner, LowerViaPatersonStockmeyerMonomial>(
31+
context, /*force=*/false);
32+
patterns.add<LowerViaPatersonStockmeyerChebyshev>(
33+
context,
34+
/*force=*/false, minCoefficientThreshold);
35+
break;
36+
case PolynomialApproximationMethod::Horner:
37+
patterns.add<LowerViaHorner>(context, /*force=*/true);
38+
break;
39+
case PolynomialApproximationMethod::PatersonStockmeyer:
40+
patterns.add<LowerViaPatersonStockmeyerMonomial>(context,
41+
/*force=*/true);
42+
break;
43+
case PolynomialApproximationMethod::PatersonStockmeyerChebyshev:
44+
patterns.add<LowerViaPatersonStockmeyerChebyshev>(
45+
context,
46+
/*force=*/true, minCoefficientThreshold);
47+
break;
48+
default:
49+
getOperation()->emitError() << "Unknown lowering method: " << method;
50+
signalPassFailure();
51+
return;
4552
}
4653

4754
walkAndApplyPatterns(getOperation(), std::move(patterns));

lib/Transforms/LowerPolynomialEval/LowerPolynomialEval.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ def LowerPolynomialEval : Pass<"lower-polynomial-eval"> {
4444
clEnumValN(mlir::heir::PolynomialApproximationMethod::PatersonStockmeyerChebyshev,
4545
"pscheb", "Paterson-Stockmeyer method (Chebyshev basis)")
4646
)}]>,
47+
Option<"minCoefficientThreshold", "min-coefficient-threshold", "double",
48+
/*default=*/"1e-12",
49+
"Minimum threshold for coefficients to be included in the lowered polynomial. "
50+
"Coefficients with absolute value below this threshold will be dropped.">,
4751
];
4852
}
4953

lib/Transforms/LowerPolynomialEval/Patterns.cpp

Lines changed: 37 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,22 @@
11
#include "lib/Transforms/LowerPolynomialEval/Patterns.h"
22

3-
#include <algorithm>
4-
#include <cmath>
53
#include <cstdint>
6-
#include <vector>
74

85
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
96
#include "lib/Dialect/Polynomial/IR/PolynomialOps.h"
107
#include "lib/Kernel/ArithmeticDag.h"
118
#include "lib/Kernel/IRMaterializingVisitor.h"
12-
#include "lib/Kernel/KernelImplementation.h"
13-
#include "lib/Utils/MathUtils.h"
149
#include "lib/Utils/Polynomial/ChebyshevPatersonStockmeyer.h"
10+
#include "lib/Utils/Polynomial/Horner.h"
11+
#include "lib/Utils/Polynomial/PatersonStockmeyer.h"
1512
#include "lib/Utils/Polynomial/Polynomial.h"
1613
#include "lib/Utils/Utils.h"
17-
#include "llvm/include/llvm/ADT/SmallVectorExtras.h" // from @llvm-project
18-
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
19-
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
20-
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
21-
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
22-
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
23-
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
24-
#include "mlir/include/mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
14+
#include "llvm/include/llvm/ADT/SmallVectorExtras.h" // from @llvm-project
15+
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
16+
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
17+
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
18+
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
19+
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
2520
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
2621
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
2722
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
@@ -30,7 +25,6 @@
3025
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
3126
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
3227
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
33-
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
3428

3529
#define DEBUG_TYPE "lower-polynomial-eval"
3630

@@ -46,12 +40,6 @@ using polynomial::TypedFloatPolynomialAttr;
4640

4741
LogicalResult LowerViaHorner::matchAndRewrite(EvalOp op,
4842
PatternRewriter& rewriter) const {
49-
Type evaluatedType = op.getValue().getType();
50-
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
51-
b.setInsertionPoint(op);
52-
53-
LLVM_DEBUG(llvm::dbgs() << "evaluatedType: " << evaluatedType << "\n");
54-
5543
auto attr =
5644
dyn_cast<polynomial::TypedFloatPolynomialAttr>(op.getPolynomialAttr());
5745
if (!attr) return failure();
@@ -62,154 +50,61 @@ LogicalResult LowerViaHorner::matchAndRewrite(EvalOp op,
6250
const int degreeThreshold = 5;
6351
if (!shouldForce() && maxDegree > degreeThreshold) return failure();
6452

53+
// Convert coefficient map to std::map<int64_t, double>
6554
auto monomialMap = attr.getValue().getPolynomial().getCoeffMap();
66-
DenseMap<int64_t, TypedAttr> attributeMap;
55+
std::map<int64_t, double> coefficients;
6756
for (auto& [key, monomial] : monomialMap) {
68-
attributeMap.insert(
69-
{key, getScalarOrDenseAttr(evaluatedType, monomial.getCoefficient())});
57+
double coeffValue = monomial.getCoefficient().convertToDouble();
58+
coefficients[key] = coeffValue;
7059
}
7160

72-
// Start with the coefficient of the highest degree term
73-
Value result =
74-
arith::ConstantOp::create(b, evaluatedType, attributeMap[maxDegree]);
75-
76-
// Apply Horner's method, accounting for possible missing terms
77-
auto x = op.getOperand();
78-
for (int64_t i = maxDegree - 1; i >= 0; i--) {
79-
// Multiply by x
80-
result = arith::MulFOp::create(b, result, x);
61+
// Create ArithmeticDag nodes
62+
auto xNode =
63+
kernel::ArithmeticDagNode<kernel::SSAValue>::leaf(op.getOperand());
64+
auto resultNode =
65+
polynomial::hornerMonomialPolynomialEvaluation(xNode, coefficients);
8166

82-
// Add coefficient if this term exists, otherwise continue
83-
if (attributeMap.find(i) != attributeMap.end()) {
84-
auto coeffConst =
85-
arith::ConstantOp::create(b, evaluatedType, attributeMap.at(i));
86-
result = arith::AddFOp::create(b, result, coeffConst);
87-
}
88-
}
67+
// Use IRMaterializingVisitor to convert to MLIR
68+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
69+
kernel::IRMaterializingVisitor visitor(b, op.getValue().getType());
70+
Value finalOutput = resultNode->visit(visitor);
8971

90-
rewriter.replaceOp(op, result);
72+
rewriter.replaceOp(op, finalOutput);
9173
return success();
9274
}
9375

9476
LogicalResult LowerViaPatersonStockmeyerMonomial::matchAndRewrite(
9577
EvalOp op, PatternRewriter& rewriter) const {
96-
Type evaluatedType = op.getValue().getType();
97-
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
98-
b.setInsertionPoint(op);
99-
10078
auto attr =
10179
dyn_cast<polynomial::TypedFloatPolynomialAttr>(op.getPolynomialAttr());
10280
if (!attr) return failure();
10381

10482
FloatPolynomial polynomial = attr.getValue().getPolynomial();
10583
auto terms = polynomial.getTerms();
106-
10784
int64_t maxDegree = terms.back().getExponent().getSExtValue();
10885
const int degreeThreshold = 5;
10986
if (!shouldForce() && maxDegree > degreeThreshold) return failure();
11087

88+
// Convert coefficient map to std::map<int64_t, double>
11189
auto monomialMap = attr.getValue().getPolynomial().getCoeffMap();
112-
DenseMap<int64_t, TypedAttr> attributeMap;
90+
std::map<int64_t, double> coefficients;
11391
for (auto& [key, monomial] : monomialMap) {
114-
attributeMap[key] =
115-
getScalarOrDenseAttr(evaluatedType, monomial.getCoefficient());
92+
double coeffValue = monomial.getCoefficient().convertToDouble();
93+
coefficients[key] = coeffValue;
11694
}
11795

118-
// Choose k optimally - sqrt of maxDegree is typically a good choice
119-
int64_t k = std::max(static_cast<int64_t>(std::ceil(std::sqrt(maxDegree))),
120-
static_cast<int64_t>(1));
121-
122-
// Precompute x^1, x^2, ..., x^k
123-
Value x = op.getOperand();
124-
std::vector<Value> xPowers(k + 1);
125-
xPowers[0] =
126-
arith::ConstantOp::create(b, evaluatedType, b.getOneAttr(evaluatedType));
127-
xPowers[1] = x;
128-
for (int64_t i = 2; i <= k; i++) {
129-
if (i % 2 == 0) {
130-
// x^{2k} = (x^{k})^2
131-
xPowers[i] =
132-
arith::MulFOp::create(b, xPowers[i / 2], xPowers[i / 2]).getResult();
133-
} else {
134-
// x^{2k+1} = x^{k}x^{k+1}
135-
xPowers[i] = arith::MulFOp::create(b, xPowers[i / 2], xPowers[i / 2 + 1])
136-
.getResult();
137-
}
138-
}
139-
140-
// Number of chunks we'll need
141-
int64_t m =
142-
static_cast<int64_t>(std::ceil(static_cast<double>(maxDegree + 1) / k));
143-
std::vector<Value> chunkValues(m, nullptr);
144-
145-
for (int64_t i = 0; i < m; i++) {
146-
// Start with coefficient of degree (i+1)*k-1, if present
147-
int64_t highestDegreeInChunk = std::min((i + 1) * k - 1, maxDegree);
148-
int64_t lowestDegreeInChunk = i * k;
96+
// Create ArithmeticDag nodes
97+
auto xNode =
98+
kernel::ArithmeticDagNode<kernel::SSAValue>::leaf(op.getOperand());
99+
auto resultNode = polynomial::patersonStockmeyerMonomialPolynomialEvaluation(
100+
xNode, coefficients);
149101

150-
Value chunkValue = nullptr;
151-
bool hasTerms = false;
152-
153-
for (int64_t j = lowestDegreeInChunk; j <= highestDegreeInChunk; j++) {
154-
if (attributeMap.count(j)) {
155-
// Get the power index relative to the chunk's starting point
156-
int64_t powerIndex = j - lowestDegreeInChunk;
157-
158-
Value coeff =
159-
arith::ConstantOp::create(b, evaluatedType, attributeMap[j]);
160-
Value term;
161-
162-
if (powerIndex == 0) {
163-
term = coeff; // x^0 = 1
164-
} else {
165-
term = arith::MulFOp::create(b, coeff, xPowers[powerIndex]);
166-
}
167-
168-
if (!hasTerms) {
169-
chunkValue = term;
170-
hasTerms = true;
171-
} else {
172-
chunkValue = arith::AddFOp::create(b, chunkValue, term);
173-
}
174-
}
175-
}
176-
177-
if (hasTerms) {
178-
chunkValues[i] = chunkValue;
179-
} else {
180-
chunkValues[i] = arith::ConstantOp::create(b, evaluatedType,
181-
b.getZeroAttr(evaluatedType));
182-
}
183-
}
184-
185-
// Combine chunks using Horner's method with x^k
186-
Value result = nullptr;
187-
bool hasNonEmptyChunk = false;
188-
189-
for (int64_t i = m - 1; i >= 0; i--) {
190-
if (chunkValues[i]) {
191-
if (!hasNonEmptyChunk) {
192-
// First non-empty chunk encountered
193-
result = chunkValues[i];
194-
hasNonEmptyChunk = true;
195-
} else {
196-
// Multiply previous result by x^k and add this chunk
197-
result = arith::MulFOp::create(b, result, xPowers[k]);
198-
result = arith::AddFOp::create(b, result, chunkValues[i]);
199-
}
200-
} else if (hasNonEmptyChunk) {
201-
// Empty chunk but we have previous chunks
202-
result = arith::MulFOp::create(b, result, xPowers[k]);
203-
}
204-
}
205-
206-
// Handle the case where no terms were found
207-
if (!hasNonEmptyChunk) {
208-
result = arith::ConstantOp::create(b, evaluatedType,
209-
b.getZeroAttr(evaluatedType));
210-
}
102+
// Use IRMaterializingVisitor to convert to MLIR
103+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
104+
kernel::IRMaterializingVisitor visitor(b, op.getValue().getType());
105+
Value finalOutput = resultNode->visit(visitor);
211106

212-
rewriter.replaceOp(op, result);
107+
rewriter.replaceOp(op, finalOutput);
213108
return success();
214109
}
215110

@@ -261,7 +156,7 @@ LogicalResult LowerViaPatersonStockmeyerChebyshev::matchAndRewrite(
261156
SSAValue xNode(xInput);
262157

263158
auto resultNode = polynomial::patersonStockmeyerChebyshevPolynomialEvaluation(
264-
xNode, chebCoeffs);
159+
xNode, chebCoeffs, getMinCoefficientThreshold());
265160

266161
IRMaterializingVisitor visitor(b, op.getValue().getType());
267162
Value finalOutput = resultNode->visit(visitor);

lib/Transforms/LowerPolynomialEval/Patterns.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace mlir {
1212
namespace heir {
1313

1414
struct LoweringBase : public OpRewritePattern<polynomial::EvalOp> {
15-
LoweringBase(mlir::MLIRContext* context, bool force = false)
15+
LoweringBase(MLIRContext* context, bool force = false)
1616
: mlir::OpRewritePattern<polynomial::EvalOp>(context), force(force) {}
1717

1818
bool shouldForce() const { return force; }
@@ -23,6 +23,21 @@ struct LoweringBase : public OpRewritePattern<polynomial::EvalOp> {
2323
const bool force;
2424
};
2525

26+
struct ChebyshevLoweringBase : public LoweringBase {
27+
using LoweringBase::LoweringBase;
28+
29+
ChebyshevLoweringBase(MLIRContext* context, bool force = false,
30+
double minCoefficientThreshold = 1e-12)
31+
: LoweringBase(context, force),
32+
minCoefficientThreshold(minCoefficientThreshold) {}
33+
34+
double getMinCoefficientThreshold() const { return minCoefficientThreshold; }
35+
36+
private:
37+
// Minimum threshold for coefficients to be included in the lowered polynomial
38+
const double minCoefficientThreshold;
39+
};
40+
2641
// Lower polynomial.eval that uses a monomial float polynomial to a series of
2742
// adds and muls via Horner's method. Supports scalar and tensor operands of
2843
// floating point types.
@@ -46,8 +61,8 @@ struct LowerViaPatersonStockmeyerMonomial : public LoweringBase {
4661
// Lower polynomial.eval that uses a Chebyshev float polynomial to a series of
4762
// adds and muls via the Paterson-Stockmeyer method. Supports scalar and tensor
4863
// operands of floating point types.
49-
struct LowerViaPatersonStockmeyerChebyshev : public LoweringBase {
50-
using LoweringBase::LoweringBase;
64+
struct LowerViaPatersonStockmeyerChebyshev : public ChebyshevLoweringBase {
65+
using ChebyshevLoweringBase::ChebyshevLoweringBase;
5166

5267
LogicalResult matchAndRewrite(polynomial::EvalOp op,
5368
PatternRewriter& rewriter) const override;

0 commit comments

Comments
 (0)