Skip to content

First-class general tensor products (fused + contracted + free indices) in the expression layer#562

Open
evaleev wants to merge 28 commits into
masterfrom
evaleev/feature/general-product-expr
Open

First-class general tensor products (fused + contracted + free indices) in the expression layer#562
evaleev wants to merge 28 commits into
masterfrom
evaleev/feature/general-product-expr

Conversation

@evaleev

@evaleev evaleev commented Jun 11, 2026

Copy link
Copy Markdown
Member

Summary

Adds native support for general binary tensor products — fused (Hadamard), contracted, and free indices coexisting, e.g. C("b,i,k") = A("b,i,j") * B("b,j,k") — to the expression layer (MultEngine/ContEngine) and evaluates them with a batched Summa: one distributed task graph in one World. einsum() now routes its generalized-contraction branch through this path by default, eliminating its per-Hadamard-slab decomposition (one MPI_Comm_split + sub-World + make_array + fence per slab).

On the motivating workload (n-hexane PNO-CCSD/cc-pVDZ in MPQC, block-sparse arena tensor-of-tensors), the per-run sub-World count drops 1412 → 0 and the einsum-region attribution becomes pure evaluation (was ~50% machinery: retile/make_array 25%, per-slab fences 30%, harvest+teardown 8%).

Design

  • compute_product_type(left, right, target) now returns the (previously unreachable) TensorProduct::General when a shared index survives into the target alongside contracted/free indices; the bottom-up 2-arg overload is unchanged (shared ⇒ contracted).
  • New GeneralPermutationOptimizer: canonical layouts A(h,e_A,c), B(h,c,e_B), C(h,e_A,e_B) — the GEMM-canonical layout with the fused modes leading, so the tile op folds them into the tile batch dimension by zero-copy reshape.
  • Summa generalized in place to batched contractions: optional slab count nh (default 1 = exactly the prior behavior); iteration in steps s = h*k + k; slab-offset tile ordinals; (h,k)-keyed sparse masks/groups; per-slab reduce tasks. The owner of a tile is independent of its slab (SlabbedPmap), so one slab's contraction is fully distributed over the same 2-d grid and slabs overlap in the task pipeline with no inter-slab barriers.
  • BatchedContractReduce adapts a folded-rank ContractReduce to fused-mode-carrying tiles (Tensor::gemm's nbatch loop and the arena ToT kernels already speak this convention).
  • SparseShape::gemm_batched: slab-batched norm contraction for the result shape.
  • ToT composition: the inner-tile-op builders classify the outer regime via outer_product_uses_summa() (Contraction or General ⇒ ContractReduce semantics), so arena plans and the strided-DGEMM kernels install as for a pure contraction.

einsum cutover & differential harness

Three-way runtime control:

  • default: expression route;
  • TA_EINSUM_LEGACY_SUBWORLD (or detail::einsum_legacy_subworld()): forces the legacy per-slab sub-World path, retained indefinitely as the reference implementation;
  • TA_EINSUM_DIFFERENTIAL: evaluates every general product by both routes, compares norms, and reports mismatching contractions with per-tile forensics — this is how the bug below was found.

Also adds TA_EINSUM_INSTRUMENT, a runtime-gated attribution profiler for the einsum region (per-call time buckets: retile, comm-split, contract+fence, harvest, …).

Drive-by bug fix (pre-existing)

The differential harness exposed a latent bug in Tensor::gemm's ToT×scalar strided scale paths: the per-row cleanliness probe stops at the first absent cell, and the subsequent A <= 0 ⇒ empty row shortcut dropped the entire row's contributions even when later cells were present. The legacy einsum route dodged it by accident (its canonical layout fails the path's NoTranspose gate). Fixed for both orientations + regression test.

Known, intentional route differences

  • The legacy path derives the result shape from harvested tile norms and thus implicitly hard-zeroes sub-threshold result tiles; the expression route keeps them (standard estimate-derived contraction shape). Per the TA screening philosophy norms are trusted as genuine and no implicit truncation is performed — call truncate() explicitly if desired. Downstream consumers may see ~1e-7-scale shifts vs legacy-einsum-derived baselines.
  • General products at inner nodes of an expression tree (e.g. THC-style X("p,r1") * X("q,r1") * Z("r1,r2") * …) cannot be classified bottom-up; they now produce an informative error suggesting explicit intermediates (top-down index-set deduction is future work). Targets that interleave fused and free modes are supported through einsum() (canonicalize + permute); native engine support is future work.

Testing

New general_product_suite (23 cases): classification, optimizer layouts, and differential tests against the legacy einsum oracle — dense/block-sparse, plain/ToT/mixed ToT×T, owning and arena (view) inner cells, variable inner extents, screened (absent) cells, non-leading fused indices, interleaved targets, batched outer products, THC gating + workaround; np = 1–4. All existing suites pass unchanged with the new default (einsum suites validated against their reference data). End-to-end validated in MPQC PNO-CCSD via the differential mode.

evaleev added 9 commits June 11, 2026 01:47
Buckets per einsum call: entry_fence / setup / commsplit+world /
retile/make_array / contract+fence / harvest / local_kernel / teardown,
keyed by branch (hadamard-reduction-local, generalized-subworld,
generalized-inner-perm-recurse) and contraction annotation; dumped to
stderr at exit. Zero overhead when disabled. Establishes the baseline
attribution for replacing the per-Hadamard-tile sub-World decomposition
with first-class general-product (h+e+c) support.

PNO-CCSD c6h14/cc-pVDZ baseline (np=1, 3 CC iters): 17.9 s einsum-region,
1412 Hadamard slices = 1412 sub-Worlds; retile/make_array 25.5% +
harvest 3.0% + teardown 4.2% non-numeric, contract+fence 30.5%.
…ral)

Phase A of first-class general-product (fused + contracted + free indices)
support in the expression layer (target: PNO-CC batched contractions,
replacing einsum's per-Hadamard-tile sub-World decomposition):

- compute_product_type(left, right, target) now returns
  TensorProduct::General when a shared index survives into the target
  (fused) alongside contracted and/or free indices, incl. the
  Hadamard-reduction case (args related by permutation, target drops
  indices). The 2-arg overload is unchanged (bottom-up convention:
  shared => contracted).
- new GeneralPermutationOptimizer: canonical layouts
  left (h, e_A, c), right (h, c, e_B), result (h, e_A, e_B) -- the
  GEMM-canonical layout with fused indices prepended so a consuming
  batched-GEMM op can fold them into the tile batch dimension by
  reshape; exposes the h/c/e_A/e_B partition for engine consumption.
  Requires target indices (fused-vs-contracted is undecidable
  bottom-up); validates against implicit reductions.
- BinaryEngine::init_indices_ and MultEngine/ScalMultEngine route
  General through the new optimizer; ContEngine::product_type()
  accessor admits General.
- evaluation is gated with an informative exception (use
  TiledArray::einsum() meanwhile) until the batched-Summa DistEval
  lands (Phase B); previously such expressions misclassified as pure
  contractions and died in target-permutation resolution.
- unit tests: classification, optimizer layouts/partitions/errors, and
  the end-to-end expression gate (tests/general_product.cpp).
Phase B step 1 of general-product support: Summa gains an optional slab
count nh (default 1 = exactly the prior, unbatched behavior). For nh > 1
the operands and result carry the fused (Hadamard) modes as leading
dimensions (left = (h,i,k), right = (h,k,j), result = (h,i,j)); the
contraction runs as nh independent SUMMA slabs over ONE shared 2-d
process grid and ONE task graph:

- iteration space becomes steps s = h*k_ + k; the step-task chain,
  depth control, and sparse step iteration (iterate_{row,col,sparse},
  skipped-range broadcasts) operate in step space
- every argument/result tile ordinal is offset by its slab base; the
  owner of a tile is independent of h (block-cyclic phase restarts per
  slab), so broadcast roots and the 2-d grid logic are unchanged
- per-step sparse broadcast groups are keyed by step (col: s, row:
  s + nsteps); the static dense groups use keys 2*nsteps, 2*nsteps+1;
  tile broadcast keys (global ordinals) are unique across slabs as-is
- reduce tasks: one per local result tile per slab
  (reduce_tasks_[h*local_size + i*local_cols + j]); initialize/finalize
  loop slab-by-slab
- sparse row/col masks take the slab index; get_tile owner computation
  mods out the slab

No caller passes nh yet (that lands with the General-product ContEngine
wiring); all existing suites pass unchanged.
Phase B step 2: dense (DensePolicy) general products now evaluate
natively in the expression layer, end-to-end:
C("b,i,k") = A("b,i,j") * B("b,j,k") runs as ONE distributed batched
Summa in one World -- no per-Hadamard-tile sub-Worlds.

- SlabbedPmap: replicates a base pmap over a leading slab dimension
  (owner of a tile is independent of its fused-index slab), used for the
  SUMMA phase maps of the arguments and the result pmap
- BatchedContractReduce: adapts a folded (fused-mode-free)
  ContractReduce to tiles carrying leading fused modes; folds them into
  the tile batch dimension by zero-copy reshape (modes lead => layout
  preserved), allocates the result with its full range up front, and
  lets Tensor::gemm's per-batch loop do the work; TA::Tensor tiles only
- ContEngine: init_struct_general / make_trange_general /
  make_shape_general / init_distribution_general / make_dist_eval_general
  -- fused-mode-prefixed result structure, per-slab 2-d process grid,
  batched Summa construction (nh = product of fused-mode tile extents,
  K = per-slab contracted tile count)
- MultEngine routes General to these; ScalMultEngine still gates
- not yet supported (clear errors): block-sparse shapes (per-slab shape
  gemm TODO), tensors-of-tensors, targets interleaving fused and free
  modes

Differential-tested against the einsum free function (norm(diff) <=
1e-10): multi-tile uneven dims, permuted argument layouts, batched outer
product; np = 1, 2, 3, 4. All existing suites unchanged.
Phase B3: SparsePolicy general products (fused + contracted + free
indices) now evaluate natively in the expression layer.

- SparseShape::gemm_batched(other, factor, gemm_helper, nfused): the
  batched analogue of gemm. The leading nfused modes of both shapes and
  the result are fused; each fused-index slab is contracted exactly as
  in gemm with the *folded* (fused-mode-free) GEMM helper. The
  contracted-mode size vector is slab-invariant (contracted modes follow
  the fused modes), the norm scaling loops extend over slabs naturally,
  and the per-slab norm GEMMs run as one batched Tensor::gemm via the
  zero-copy fused-modes-into-nbatch reshape; same hard-zero threshold
  pass and outer-product (k_rank == 0) handling as gemm.
- ContEngine::make_shape_general routes SparseShape to it (the dense
  branch is unchanged); this removes the last block-sparse gate, so the
  batched Summa's sparse path ((h,k)-keyed masks, groups, and step
  iteration, landed earlier) is now reachable.
- tests: block-sparse differential tests vs einsum (batched contraction
  and batched outer product, deterministic block-sparsity patterns);
  pass at np = 1, 2, 3, 4.
Phase C: ToT general products (fused + contracted + free outer indices,
nested inner product) now evaluate natively in the expression layer via
the batched Summa.

- the inner-tile-op builders (init_inner_tile_op and the owning-cell
  variant) now classify the outer regime via outer_product_uses_summa()
  (pure contraction OR general product): for both, the tile op is a
  ContractReduce consumed by a (batched) SUMMA, so the per-cell ops
  accumulate in place, no per-cell result permutation is applied, and
  the arena plans / strided-DGEMM ops install as for a pure contraction
- the strided-DGEMM install gates derive the outer-contracted rank from
  the fused-mode-free outer sizes (n_fused_outer_modes() helper)
- init_struct_general gains the ToT arm, mirroring init_struct: builds
  the folded-rank ContractReduce with the inner element op and the
  arena plan, installs the strided ce+e / ce+ce ops; a non-identity
  inner result permutation is gated (the batched op must be perm-free)
- BatchedContractReduce now admits ToT tiles: the folded result is
  allocated by the wrapped op itself (engaging its tile-type-specific
  construction, e.g. the arena reserve) and unfolded by a zero-copy
  reshape; this also gives plain tiles the beta=0 first-accumulation
  fast path
- MultEngine initializes the inner tile op before init_struct_general

Differential-tested against einsum (inner Hadamard and inner
contraction, owning cells) at np = 1, 2, 3; all existing suites
unchanged. Arena (view-cell) general products compile via the same
paths; their end-to-end validation comes with the mpqc/einsum cutover.
Phase D (partial): einsum can evaluate its generalized-contraction branch
through the expression layer's native general-product support (one batched
Summa in one World) instead of the legacy per-Hadamard-slab sub-World
decomposition. The engine receives the canonical (fused..., left-free...,
right-free...) result layout; arbitrary einsum targets are reached by a
final permutation assignment.

Three-way runtime control (detail::einsum_legacy_subworld /
detail::einsum_differential, env TA_EINSUM_LEGACY_SUBWORLD /
TA_EINSUM_DIFFERENTIAL):
- legacy (DEFAULT for now, see below)
- expression route (TA_EINSUM_LEGACY_SUBWORLD=0)
- differential: evaluates BOTH routes per general product, compares
  squared norms, reports mismatching contractions (with annotations) to
  stderr, returns the legacy result. The legacy path is retained
  indefinitely as the reference implementation for such testing.

Status: with the expression route, PNO-CCSD (c6h14/cc-pVDZ) runs with ZERO
sub-Worlds (legacy: 1412) and the einsum-region attribution collapses from
17.9 s to 12.4 s of pure evaluation -- but the energy is WRONG (-238.09 vs
-236.35). TA_EINSUM_DIFFERENTIAL isolates the mismatching shapes:
 (1) ToT x T (inner Scale) with a non-leading fused index and interleaved
     target, e.g. (i4,i1,mu;a) * (mu,i4,K) -> (i1,i4,K;a)
 (2) phantom-unit (denest-internal) general products, e.g.
     (mu,i1,i4;a) * (i1,i4,K;a,phantom) -> (mu,i1,K;phantom)
Synthetic unit reproductions of (1) with fixed inner extents PASS, so the
trigger involves CSV specifics (variable per-block inner extents and/or
arena cell layout); under investigation. Until resolved the legacy path is
the default and the expression route is opt-in.

All unit suites green in both modes; einsum suites were also validated
green against their reference data with the expression route as default
before the flip.
The GEMM-based ToT x scalar scale paths of Tensor::gemm (and the
T x ToT mirror) probe each row (column) for cleanliness; the presence
probe stops at the first ABSENT cell, leaving the probed inner size
A == -1 when the leading cell is absent. The subsequent 'A <= 0 =>
empty row, nothing to do' shortcut then dropped the ENTIRE row's
contributions even when later cells were present. Rows whose leading
contracted cell is absent (common for screened tensor-of-tensors, e.g.
PNO-CC CSV intermediates) silently lost their contraction.

Fix: when the probe ends with A <= 0, scan the full row (column) for
any present cell; only a fully absent row is skipped, anything else
takes the per-cell AXPY fallback. Also guard the engine's scale
fallback element op against absent cells.

This bug predates the general-product work but was masked on the
legacy einsum route, whose canonical operand layout fails the
NoTranspose gate of the strided path; the expression route's
GEMM-canonical layout exposed it. Found with TA_EINSUM_DIFFERENTIAL on
c6h14 PNO-CCSD: the opt-in expression-route energy error drops from
1.7 Eh to 6.9e-7 (the small residual is a screening-semantics
difference -- the legacy route hard-zeroes sub-threshold result tiles
that the engine route keeps -- plus a small systematic difference in
phantom-unit denest products, both under review).

Adds the CSV-like reproduction test (arena view cells, SparsePolicy,
variable inner extents, screened cells, non-leading fused index,
interleaved target): expression route and einsum routes agree to
1e-10, deterministically.
With the strided-scale-path fix in place, the TA_EINSUM_DIFFERENTIAL
audit of c6h14 PNO-CCSD shows the two routes agree except for:
- sub-threshold result tiles that the legacy path implicitly hard-zeroes
  (its result shape derives from the harvested tile norms) while the
  expression route keeps them (standard estimate-derived contraction
  shape). Per the TA screening philosophy, norms are trusted as genuine
  and no implicit truncation is performed; users wanting the tighter
  shape call truncate() explicitly.
- floating-point summation-order noise in tiny, heavily-cancelling
  tensors (absolute tile-norm^2 differences <= 1e-9, no structural
  pattern).

Neither is a defect, so general products in einsum now default to the
expression-layer evaluation (TensorProduct::General -> batched Summa:
one task graph in one World, ZERO per-slab sub-Worlds). The legacy path
remains available via TA_EINSUM_LEGACY_SUBWORLD (or
detail::einsum_legacy_subworld()) as the reference implementation for
differential testing.

All suites pass with the new default (einsum suites against their
reference data; the 2 pre-existing assign_subblock_block_base1 failures
are unrelated); np = 1, 2, 3.

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds first-class support for general binary tensor products (fused + contracted + free indices coexisting) in the expression layer by introducing a batched-SUMMA evaluation path, and updates einsum() to route generalized contractions through this new path by default (retaining the legacy sub-World implementation for reference and differential testing).

Changes:

  • Added TensorProduct::General classification/layout support (new GeneralPermutationOptimizer) and integrated it into expression engines (MultEngine/ContEngine).
  • Generalized SUMMA to support batched (slabbed) contractions and introduced supporting utilities (SlabbedPmap, BatchedContractReduce, SparseShape::gemm_batched).
  • Updated einsum() routing + added runtime-gated instrumentation/differential modes, plus a new comprehensive general_product_suite test suite.

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
tests/general_product.cpp New test suite covering classification, optimizer layouts, expression-vs-legacy routing, and sparse/ToT/arena scenarios.
tests/CMakeLists.txt Adds the new general_product.cpp test target source.
src/TiledArray/tile_op/batched_contract_reduce.h New tile-op adapter to fold fused leading modes into a batch dimension for GEMM-based contraction/reduction.
src/TiledArray/tensor/tensor.h Fixes a ToT×scalar strided-scale “empty row/col” probe bug by correctly scanning for later non-empty cells.
src/TiledArray/sparse_shape.h Adds SparseShape::gemm_batched to compute slab-batched contraction shapes for general products.
src/TiledArray/pmap/slabbed_pmap.h New pmap that replicates a base mapping across a slab dimension (slab-independent ownership).
src/TiledArray/expressions/product.h Enables 3-arg classification to return TensorProduct::General when target keeps shared indices.
src/TiledArray/expressions/permopt.h Adds GeneralPermutationOptimizer and routes TensorProduct::General through it.
src/TiledArray/expressions/mult_engine.h Integrates general-product routing, inner-node gating checks, and general distribution/eval hooks.
src/TiledArray/expressions/cont_engine.h Implements general-product structure/distribution/evaluator (batched SUMMA + batched tile op).
src/TiledArray/expressions/binary_engine.h Extends index initialization template to allow TensorProduct::General optimizer selection.
src/TiledArray/einsum/tiledarray.h Adds routing toggles, differential mode, instrumentation hooks, and routes generalized contraction via expression layer by default.
src/TiledArray/einsum/einsum_instrument.h New lightweight, runtime-gated einsum attribution profiler.
src/TiledArray/dist_eval/contraction_eval.h Generalizes SUMMA implementation to batched slabs (step space expanded to nh * k).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/TiledArray/einsum/tiledarray.h
Comment thread src/TiledArray/tile_op/batched_contract_reduce.h
Comment thread src/TiledArray/pmap/slabbed_pmap.h
Comment thread src/TiledArray/expressions/cont_engine.h
Comment thread src/TiledArray/sparse_shape.h Outdated
Comment thread tests/general_product.cpp Outdated
Comment thread tests/general_product.cpp Outdated
Comment thread tests/general_product.cpp Outdated
evaleev added 5 commits June 11, 2026 18:43
…gruence checks

- einsum/tiledarray.h, tile_op/batched_contract_reduce.h, pmap/slabbed_pmap.h:
  include what is used (<cstdlib>, <string_view>, util/vector.h, <memory>,
  <utility>) instead of relying on transitive includes
- cont_engine.h: re-initialize K_ in init_distribution_general() (defensive;
  engines are single-use, but mirrors the n_slabs_ reset)
- sparse_shape.h: gemm_batched() now TA_ASSERTs that the argument ranks match
  the folded gemm ranks plus the fused modes and that the fused and contracted
  mode extents of the two shapes are congruent (the batched analogue of the
  checks GemmHelper::compute_matrix_sizes performs for plain gemm)
ScopedEinsumRoute restores the previous einsum_legacy_subworld() value on
scope exit (ForceLegacyEinsum is now the legacy=true special case), so a
throwing TA::einsum can no longer leak the toggle into later test cases.

Also restores einsum_expression_route_matches_legacy to its intent: it was
written when the legacy sub-World path was einsum's default, so after the
default flip (ef0066c) its "legacy" reference silently took the expression
route (vacuous comparison) and the trailing manual toggle left the legacy
path enabled for the rest of the test module.
…t target (Phase E)

An index shared by the two children of a product is fused iff the node's
target carries it, contracted otherwise; an index neither the sibling nor
the target carries is consumed within the child subtree and not demanded
of it. available_indices() (per-subtree leaf-annotation union, valid
before init) supplies the up-pass; each child dictates the ORDER of its
demand via preferred_layout() (canonical (fused, left-free, right-free)
for products, pass-through elsewhere). Expressions consumed without a
target (reductions) retain the bottom-up contraction convention.
…reaming re-permute

A target that differs from the canonical (fused..., left-free...,
right-free...) result layout cannot be folded into the batched tile op
(BatchedContractReduce must be perm-free); evaluate canonically (Summa
over a slab-replicated pmap) and re-permute to the target with a
streaming UnaryEvalImpl. Honors the implicit-permute contract: when the
consumer fuses the permutation into its own operation (transposed GEMM),
only the tile ordinals/trange are remapped and contents stay canonical.
Replaces the interleaved-target gate, enabling general products at inner
expression-tree nodes and non-canonical root targets.
evaleev added 10 commits June 12, 2026 10:00
- GeneralRepermuteOp: store/apply only the OUTER result-layout permutation.
  The streaming re-permute wrapper exists to reorder a general product's
  outer (result) layout; inner (within-cell) permutation of ToT results is
  handled separately (init_struct_general / implicit_permute_inner_). Using
  the full bipartite perm_ would also permute inner cells -- a no-op when the
  inner perm is identity, but a latent double-apply if an inner perm is ever
  deferred to a downstream op. Pass outer(perm_) into the op; the host
  UnaryEvalImpl still receives full perm_ for ordinal/trange remap.
- MultEngine down-pass: materialize each child demand as a named lvalue
  before preferred_layout(), which returns a reference to its argument for
  leaf/binary engines -- avoids a needlessly fragile bind-to-temporary.
…e-deduction down-pass

The Phase E child-demand deduction moves from MultEngine into
BinaryEngine::init_children_indices and ScalMultEngine adopts it, along
with the full MultEngine routing for general products
(inner_product_type_ classification + inner-General gate,
init_struct_general, init_distribution_general, make_trange_general,
make_dist_eval_general), replacing its use-einsum-instead exception.
…factor in inner-Scale ops

The general-product ToT gate fired on a non-null but IDENTITY inner
permutation (the bipartite perm is constructed whole when only the outer
modes are re-permuted by the streaming wrapper); require a genuinely
non-identity inner perm. The inner-Scale element ops (mixed T x ToT)
never carried the expression-level scalar prefactor -- invisible while
only MultEngine (factor == 1) reached them; the fallback op now absorbs
factor_ and the factor-free fused arena ops are gated to factor == 1.
…der products, kitchen-sink, blocks in trees

A ToT x ToT general product with no external (free) outer indices --
every outer index fused or contracted -- segfaulted in the folded GEMM;
gate it with an informative error (einsum() evaluates this shape
natively via its no-external regime).

New tests: a SUM nested under a product with a general summand (the
down-pass prunes summand-internal contraction indices from the sum's
demand by construction); the kitchen-sink expression combining a
THC-like batching index, a mixed T x ToT general product, a ToT x ToT
general product with an inner outer-product, and a ScalMult prefactor;
a block leaf under an inner general node; a re-permuted general product
assigned into a block view; the no-external gate.
… left-external mode

A general product whose every outer index is fused or contracted (e.g.
C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b")) folds to a GEMM with no
free modes, i.e. rank-0 tensors, which the tile kernels do not support
(this shape used to segfault through wild stride reads). Evaluate it
with a synthetic unit left-external mode instead: the folded product
becomes (1,K) x (K) -> (1), the exact shape of the already-supported
one-sided neB == 0 case. The unit mode lives only in the tile op's
GemmHelper; tranges, shapes and tiles carry the true (external-free)
ranks, and BatchedContractReduce / SparseShape::gemm_batched detect the
synthetic mode from the one-rank mismatch and pad their folded views
with a unit extent. Replaces the interim gate.

Tests: dense ToT (incl. the no-external root fed by a general T x ToT
inner node), plain dense (the Hadamard-reduction shape), and
block-sparse (exercising the gemm_batched unit handling), all
differential-tested against legacy einsum.
Infrastructure for a 3-d (proc_h x proc_r x proc_c) batched-Summa grid
that distributes the fused/batch (h, slab) dimension of a general product
across process planes:

- ProcGrid gains a rank-subset constructor (tagged rank_subset to avoid
  colliding with the same-arity test-only ctor) that builds a 2-d grid
  over a contiguous interval [rank_offset, rank_offset + nprocs) of the
  world's ranks; map_row/map_col and the row/col group factories emit
  world-correct ranks via the offset. The legacy full-world ctor is
  unchanged (offset 0).
- SlabbedPmap gains a 3-d variant (proc_h, proc_h_stride): slab h belongs
  to plane h % proc_h of proc_h_stride contiguous ranks, and the per-slab
  base map's plane-local owners are offset by the slab's plane. The
  original 3-argument form (proc_h == 1, slab-replicated) is unchanged.
Distribute the fused/batch (h, slab) dimension of a general product over a
third process-grid axis proc_h, recovering parallelism when the result is
small (M*N result tiles < P ranks) -- most acutely no-external products
(M=N=1, e.g. the PNO-CCSD PPL intermediate), where the 2-d grid otherwise
degenerates to a single rank.

The world's first proc_h * proc_h_stride ranks form proc_h h-planes of
proc_h_stride = P/proc_h ranks; slab h is evaluated on plane h % proc_h,
which runs an ordinary 2-d SUMMA over its own (offset) process grid. Slabs
are communication-free (independent), so the surplus of ranks beyond one
result-tile-per-rank is spent on this axis. Summa carries per-plane state
(first_slab_, my_slabs_), restricts its slab iteration to the plane
(next_step), indexes reduce tasks by plane-local slab ordinal (slab_ord),
uses plane-unique dense broadcast keys, and computes the result-tile owner
(result_tile_owner) as the within-plane cyclic owner shifted by the plane's
world-rank offset -- matching set_tile's pmap-routed destination (the two
disagreeing was a get_tile/set_tile owner mismatch that deadlocked
cross-plane result transfers). proc_h == 1 reproduces the 2-d path exactly.

ContEngine::init_distribution_general sizes proc_h by a greedy heuristic
(spread ranks beyond min(P, M*N) over the slab axis, bounded by n_slabs)
and builds the plane-local grid + 3-d operand/result pmaps. A TODO marks
the principled co-optimization of proc_h with the 2-d aspect ratio from the
h/left-external/right-external element extents and a memory bound.
Adds general_product_distributed_suite (UNLABELED, so the CI harness runs
it at both np=1 and np=2; the existing general_product_suite is
serial-labeled and never exercised the batched Summa across ranks). Seven
differential cases vs the legacy sub-World einsum oracle: dense, sparse,
mixed T x ToT, no-external (dense + ToT), the one-expression THC
reconstruction, and dist_no_externals_3d_grid -- which engages the 3-d
(proc_h > 1) grid and asserts the no-external result distributes across the
h-planes rather than piling on one rank.
evaleev added 4 commits June 12, 2026 11:06
…uct-tree-deduction

expressions: tree-general index deduction (Phase E) — inner-node general products
…trees

expressions: mixed T x ToT products in arbitrary expression trees (Phase F)
- contraction_eval: clamp the SUMMA step-task pipeline depth to my_steps()
  (this rank's group's step count) instead of nsteps_. In the 3-d
  (proc_h_ > 1) case my_steps() < nsteps_, so clamping to nsteps_
  pre-spawned surplus step tasks that all resolved to the terminating
  step (k_ == nsteps_). No-op for the 2-d path (my_slabs_ == nh_).

- cont_engine: keep proc_h_stride_ == 0 for the ungrouped 2-d case
  (proc_h_ == 1), matching the field's documented invariant; only the
  grouped (proc_h_ > 1) grid uses P / proc_h_.

- general_product test: correct the distributed suite header comment --
  dist_inner_node_thc validates against explicit binary intermediates,
  not the legacy einsum oracle.
summa: 3-d (proc_h) process grid for batched general products
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants