
Conversation


@bon-cdp bon-cdp commented Nov 4, 2025

Summary

This PR addresses #2351 by implementing a multi-dimensional cost model for HEIR's layout optimization. I realize this may be further outside the box than what you were thinking, and I'm completely open to simplifying if it feels like too much too soon, but I wanted to share the approach I landed on after digging into the compiler literature and thinking about how to make the cost model extensible.

The Problem

From your comment on #2351, I understood the key challenges as:

  1. Need to track multiple cost metrics (not just rotations)
  2. Cost models should be swappable/configurable
  3. Backend-specific costs vary (OpenFHE vs Lattigo vs SEAL)
  4. Configuration options (parallelism, OpenMP) affect costs
  5. Eventually need scheduling simulation

My Approach

After reading Muchnick's compiler book (Chapter 17 on instruction scheduling) and seeing how classical compilers handle heterogeneous costs, I thought about applying causal reasoning to model FHE operation costs. Here's the thinking:

Core insight: operations don't have fixed costs; their cost depends on why they're expensive:

  • Multiplication → noise growth → might trigger bootstrap → latency spike
  • Rotation → key switching → memory overhead + latency
  • Backend/scheme/params → confound everything

So instead of flat cost tables, I built a causal graph that models these dependencies explicitly.

The Causal DAG (as implemented in CausalGraph.cpp)

EXOGENOUS VARIABLES (what we observe/control):
├── op_type (add=1.0, mul=2.0, rotate=3.0)
├── operand_types (ct-pt=1.0, ct-ct=2.0)
├── rotation_offset
├── backend (OpenFHE=1.0, Lattigo=2.0, SEAL=3.0)
├── scheme (CKKS=1.0, BGV=2.0, BFV=3.0)
├── security_params (ring_dimension / 16384)
├── hardware_config (thread_count)

INTERMEDIATE EFFECTS (mediators):
├── noise_growth
├── depth_consumed
├── key_switching
├── relinearization
├── critical_path
├── parallelizable
├── depth_remaining
├── memory_pressure

OBSERVABLE OUTCOMES (what we optimize):
├── latency (milliseconds)
├── memory (bytes)
├── noise_level
└── throughput (operations/sec)

Edge weights (the actual causal relationships as coded; a construction sketch follows the lists):

Direct causal effects:

  • op_type → noise_growth (weight 2.0 for multiply, 0.1 for add)
  • op_type → depth_consumed (weight 1.0 for multiply)
  • op_type → relinearization (weight 1.0 for multiply)
  • op_type → key_switching (weight 1.0 for rotate)
  • rotation_offset → key_switching (weight 1.0)
  • noise_growth → noise_level (weight 1.0)
  • key_switching → latency (weight 10.0, expensive!)
  • key_switching → memory (weight 5.0, rotation keys are large)
  • relinearization → latency (weight 8.0)
  • relinearization → memory (weight 3.0)
  • depth_consumed → depth_remaining (weight -1.0)
  • depth_remaining → latency (weight -5.0, low budget triggers bootstrap)
  • critical_path → latency (weight 2.0 penalty)
  • parallelizable → throughput (weight 3.0 boost)
  • memory_pressure → latency (weight 1.5, cache misses)

Confounding paths (not direct causes):

  • backend → latency, noise_growth, key_switching (implementation differences)
  • scheme → noise_growth, depth_consumed (algorithmic differences)
  • security_params → latency, memory, noise_level (parameter choices)
  • hardware_config → parallelizable, latency (more threads = faster)
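To make the wiring concrete, here is a minimal construction sketch. The Edge/addEdge/buildFheCostDag names are illustrative, not the actual CausalGraph.cpp API:

cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical construction sketch; names are illustrative, not the
// actual CausalGraph.cpp API.
struct Edge {
  std::string from, to;
  double weight;
};

struct CausalGraphSketch {
  std::vector<Edge> edges;
  void addEdge(std::string f, std::string t, double w) {
    edges.push_back({std::move(f), std::move(t), w});
  }
};

CausalGraphSketch buildFheCostDag() {
  CausalGraphSketch g;
  // Exogenous -> mediator edges (placeholder weights from the lists above).
  g.addEdge("op_type", "noise_growth", 2.0);          // multiply path
  g.addEdge("op_type", "key_switching", 1.0);         // rotate path
  g.addEdge("rotation_offset", "key_switching", 1.0);
  // Mediator -> mediator and mediator -> outcome edges.
  g.addEdge("depth_consumed", "depth_remaining", -1.0);
  g.addEdge("depth_remaining", "latency", -5.0);
  g.addEdge("key_switching", "latency", 10.0);
  g.addEdge("key_switching", "memory", 5.0);
  g.addEdge("relinearization", "latency", 8.0);
  return g;
}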

[Diagram: mul_depth_dag, illustrating the causal paths above]

Why this structure matters:

Multiplication is expensive through three causal paths:

  1. Causes quadratic noise growth (weight 2.0)
  2. Consumes multiplicative depth → may need bootstrap (+45ms)
  3. Requires relinearization → key switching (weight 8.0 latency)

Rotation is expensive through one main path:

  • Requires key switching (weight 10.0 latency + weight 5.0 memory)
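To put numbers on that asymmetry (reading the edge weights as additive latency contributions, which is my simplification of the model above): a ct-ct multiply contributes roughly 8.0 latency units via relinearization → latency, plus another +5.0 through the depth path once it consumes a level (the product of weights -1.0 × -5.0 along depth_consumed → depth_remaining → latency), while a rotation contributes 10.0 latency units and 5.0 memory units through key switching alone.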

Backend affects everything, but it's a confounder, not a cause; we model it separately so learned weights don't confuse correlation with causation.

What's Implemented

1. Multi-Dimensional Cost Tracking

cpp
struct CostMetrics {
  double latency_ms;
  size_t memory_bytes;
  int64_t depth_consumed;    // NEW
  int64_t rotations;
  double noise_growth;
};
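For whole kernels, these per-op metrics presumably need to accumulate. A minimal sketch of one way to fold them along a sequential chain (combineSequential is my illustration, not part of the PR):

cpp
#include <cstddef>
#include <cstdint>

struct CostMetrics {
  double latency_ms = 0;
  size_t memory_bytes = 0;
  int64_t depth_consumed = 0;
  int64_t rotations = 0;
  double noise_growth = 0;
};

// Sketch: fold per-op metrics for two ops executed in sequence.
// Latency, memory, rotations, and depth add along a sequential chain;
// noise handling is deliberately simplified (real noise analysis is subtler).
CostMetrics combineSequential(const CostMetrics &a, const CostMetrics &b) {
  return {a.latency_ms + b.latency_ms, a.memory_bytes + b.memory_bytes,
          a.depth_consumed + b.depth_consumed, a.rotations + b.rotations,
          a.noise_growth + b.noise_growth};
}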

2. Multiplicative Depth Visitor (Fixed)

Parallel to RotationCountVisitor (from PR #2347), this visitor now correctly tracks FHE depth (a sketch of the counting rule follows the list):

  • Only counts ciphertext-ciphertext multiplications (plaintext-ciphertext is free)
  • Tracks secret status through the DAG (similar to the rotation fix in #2347)
  • mul: +1 depth only if both operands are ciphertext
  • power(n): +ceil(log2(n)) depth only if base is ciphertext
  • rotation/add: +0 depth (automorphisms don't consume depth)
  • 8 comprehensive tests validating correct FHE depth reporting
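For concreteness, the counting rule in isolation; helper names are mine, and the real visitor applies this while walking the kernel DAG:

cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Depth increases only when both operands are secret (ciphertext).
int64_t mulDepth(bool lhsSecret, bool rhsSecret, int64_t lhsDepth,
                 int64_t rhsDepth) {
  int64_t d = std::max(lhsDepth, rhsDepth);
  return (lhsSecret && rhsSecret) ? d + 1 : d;  // ct-ct: +1; pt-ct: free
}

// power(n) via square-and-multiply consumes ceil(log2(n)) levels when the
// base is a ciphertext; plaintext powers are free.
int64_t powerDepth(bool baseSecret, int64_t baseDepth, int64_t n) {
  if (!baseSecret) return baseDepth;
  return baseDepth +
         static_cast<int64_t>(std::ceil(std::log2(static_cast<double>(n))));
}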

3. Causal Cost Model Foundation

cpp
class CostModel {  // Abstract interface for swappable models
  virtual CostMetrics computeCost(Operation* op, CostContext& ctx) = 0;
};

class CausalCostModel : public CostModel {
  CausalGraph causal_graph_;  // The DAG shown above
  
  CostMetrics computeCost(Operation* op, CostContext& ctx) override {
    // 1. Extract features (op_type, rotation_offset, etc.)
    // 2. Set as interventions in causal graph: do(op_type = mul)
    // 3. Propagate through DAG edges using topological sort
    // 4. Adjust for context (critical_path, parallelizable)
    // 5. Return predicted CostMetrics
  }
};
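Step 3, the propagation, could look like the following minimal linear-model sketch (names illustrative; it assumes interventions are set only on exogenous, parentless nodes, so no edge-cutting is needed):

cpp
#include <map>
#include <string>
#include <vector>

struct WeightedEdge {
  std::string from, to;
  double weight;
};

// Propagate values through the DAG in topological order. Exogenous nodes
// are pinned by interventions do(X = v); every other node accumulates the
// weighted sum of its parents' values.
std::map<std::string, double>
propagate(const std::vector<std::string> &topoOrder,
          const std::vector<WeightedEdge> &edges,
          std::map<std::string, double> values) {  // seeded with interventions
  for (const auto &node : topoOrder)
    for (const auto &e : edges)
      if (e.to == node) values[node] += e.weight * values[e.from];
  return values;
}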

4. Context-Aware Cost Adjustments

After computing the base cost from the causal DAG, we adjust for context (sketched in code after the list):

  • Critical path: latency *= 1.5 (can't parallelize)
  • Parallelizable: latency /= min(thread_count, 4)
  • Depth remaining < 1: latency += 45ms (bootstrap!)
  • Memory pressure: latency *= 1.3 (cache misses)
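As code, with the PR's placeholder constants (adjustLatency is an illustrative name, not the implementation):

cpp
#include <algorithm>
#include <cstdint>

// Apply the context adjustments listed above to the DAG-predicted latency.
double adjustLatency(double baseMs, bool onCriticalPath, bool parallelizable,
                     int threadCount, int64_t depthRemaining,
                     bool memoryPressure) {
  double ms = baseMs;
  if (onCriticalPath) ms *= 1.5;                       // serialized work
  if (parallelizable) ms /= std::min(threadCount, 4);  // capped speedup
  if (depthRemaining < 1) ms += 45.0;                  // bootstrap penalty
  if (memoryPressure) ms *= 1.3;                       // cache misses
  return ms;
}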

Why Causal Reasoning?

  1. Extensibility: Add new metrics by adding nodes/edges to DAG
  2. Correctness: Distinguishes causation from correlation
  3. Data-driven: Can learn edge weights from microbenchmarks using do-calculus
  4. Testable: Can validate with interventions (e.g., does increasing rotations cause latency increase?)

What's NOT Implemented (Yet)

This is foundation only:

  • Not integrated into `LayoutOptimization` pass yet
  • No microbenchmark suite (edge weights are placeholders)
  • No learned weights from real data
  • No scheduling simulator

I wanted to get feedback on the approach before going further. That said, per our office-hours talk I do have an idea about representing certain causal vertices at different float precisions, as well as the potential to implement some form of ERM (check out https://proceedings.neurips.cc/paper/2020/file/95a6fc111fa11c3ab209a0ed1b9abeb6-Paper.pdf for some inspiration, off the top of my head).

Testing

  • All 9 existing rotation count tests pass (no regressions)
  • 8 multiplicative depth tests pass with CORRECT FHE depths
    • Halevi-Shoup: depth 0 (plaintext matrix × ciphertext vector) ✓
    • Power operations: correct depth via optimal evaluation
    • All depths match real FHE behavior
  • CausalCostModel compiles and links correctly

Alternative: Simpler Approach

If this feels too complex, I can simplify to:

  1. Just add depth tracking (keep `MultiplicativeDepthVisitor`)
  2. Simple cost table: `map<(op_type, backend) → CostMetrics>` (sketched below)
  3. Skip causal reasoning entirely
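A minimal sketch of that table (illustrative numbers, not measurements):

cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <utility>

struct CostMetrics {
  double latency_ms;
  size_t memory_bytes;
  int64_t depth_consumed;
  int64_t rotations;
  double noise_growth;
};

// Flat lookup keyed by (op_type, backend); entries would eventually be
// filled from microbenchmarks.
using CostTable = std::map<std::pair<std::string, std::string>, CostMetrics>;

CostTable makeCostTable() {
  CostTable table;
  table[{"mul", "OpenFHE"}] = {8.0, 3u << 20, 1, 0, 2.0};
  table[{"rotate", "OpenFHE"}] = {10.0, 5u << 20, 0, 1, 0.1};
  return table;
}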

The causal framework is appealing because it naturally handles the config/backend/parallelism interactions you mentioned, but I completely understand if you'd prefer starting simpler.

Questions for You

  1. Is the causal reasoning approach interesting, or should I simplify?
  2. Should I integrate into `LayoutOptimization` now, or wait for feedback?
  3. For microbenchmarks, would you prefer I build the harness or collaborate with someone already measuring OpenFHE costs?

Thanks for your patience with this somewhat experimental approach; happy to iterate based on your feedback!

Related Issues

Addresses #2351
Builds on #2347 (rotation counting)

Signed-off-by: bon-cdp [email protected]

Critical Fix: Multiplicative Depth Tracking

Update (latest commit): Fixed MultiplicativeDepthVisitor to match the behavior of RotationCountVisitor from PR #2347.

The Bug: Previously counted ALL multiplications as depth +1, even plaintext-ciphertext multiplications (which are free in FHE).

The Fix: Now tracks secret status through the DAG and only counts ciphertext-ciphertext multiplications:

  • plaintext × ciphertext: depth 0 (free scalar multiplication in FHE) ✓
  • ciphertext × ciphertext: depth 1 (expensive, increases noise) ✓

Impact: Cost models now report CORRECT FHE depths. For example:

  • Dense layer W×x (W plaintext, x ciphertext): depth 0 (was 1)
  • Activation (z ciphertext): depth 1 (correct)
  • 2-layer network: depth 2 (was 4) ✓

This brings MultiplicativeDepthVisitor in line with how RotationCountVisitor handles plaintext rotations (which also cost nothing).
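To spell out the 2-layer example: layer 1 computes W×x (plaintext weights, depth 0) followed by a ciphertext-ciphertext activation (+1); layer 2 repeats the pattern (+1), for a total depth of 2 rather than the 4 the old visitor reported.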

…le#2351

This PR implements a foundation for multi-dimensional cost modeling
that tracks rotation count, multiplicative depth, latency, memory,
and noise growth. The implementation uses causal reasoning to model
how operation properties (e.g., mul, rotate) causally affect costs
through intermediate effects (noise growth, key switching).

Key components:
- CausalGraph: Models causal dependencies between operation properties
  and observable costs (e.g., mul → noise → depth → bootstrap cost)
- CostModel: Abstract interface for swappable cost model implementations
- CausalCostModel: Context-aware cost computation using causal inference
- MultiplicativeDepthVisitor: Tracks multiplicative depth consumption
  (parallel to RotationCountVisitor from PR google#2347)

The causal approach provides:
- Extensibility: Add new cost metrics without rebuilding
- Correctness: Distinguishes causation from correlation
- Context-awareness: Same operation costs differently based on DAG context
  (critical path, parallelism, depth budget)

Testing:
- 8 new tests for MultiplicativeDepthVisitor (all passing)
- All 9 existing rotation count tests still pass (no regressions)

This is foundation only - not yet integrated into LayoutOptimization.
Seeking feedback on approach before proceeding further.

Addresses google#2351

Signed-off-by: bon-cdp <[email protected]>
…ltiplications

Similar to PR google#2347's fix for RotationCountVisitor, this updates
MultiplicativeDepthVisitor to correctly distinguish between plaintext
and ciphertext operations.

Key changes:
- Track secret status (isSecret) for each DAG node
- Only count depth for ciphertext × ciphertext multiplications
- Plaintext × ciphertext multiplications are depth 0 (free in FHE)
- Plaintext power operations are depth 0

Impact on test results:
- Iris classifier: depth 4 (was 6) - CORRECT for real FHE!
- Dense layer: depth 0 (was 1) - plaintext weights are free
- Multi-layer networks: depth = number of layers (was 2× layers)

This makes depth tracking match actual FHE behavior where plaintext
operations are essentially free compared to ciphertext operations.