einsum: local-slice fast path for batched (Hadamard) contractions#561
Closed
evaleev wants to merge 1 commit into
Closed
einsum: local-slice fast path for batched (Hadamard) contractions#561evaleev wants to merge 1 commit into
evaleev wants to merge 1 commit into
Conversation
The generalized batched-contraction path -- Hadamard indices coexisting with external/contracted indices, e.g. C(b,i,k) = A(b,i,j) * B(b,j,k) -- ran one MPI_Comm_split + a fresh sub-World + a sub-World fence per Hadamard tile. That is O(#Hadamard-tiles) collectives, and the per-tile sub-World construction/teardown dominates wall time even at np=1. Add a fast path: a Hadamard slice whose (single) input tiles are all owned by one rank is contracted locally on that rank with a direct Tensor::gemm -- no comm-split, no sub-World, no make_array/DistEval, no fence. Whether a slice is local is decided purely from global pmap/trange metadata, so every rank reaches the same verdict and the comm-splits for the remaining (genuinely cross-rank) slices stay in lockstep. For batch-blocked data every slice is single-owner, so the whole batched contraction runs communication-free. Slices that span ranks, multi-tile external/contracted dimensions, and tensor-of-tensor tiles fall back to the unchanged sub-World path. Measured speedup over the legacy path (C(b,i,k) = A(b,i,j) * B(b,j,k), 64 batch tiles, flops held constant): 12.9x at np=1, 5.3x at np=2; neutral for a single Hadamard tile. The per-Hadamard-tile machinery cost -- which previously grew the overhead from 2x to 26x as the batch was split into more tiles -- is gone; overhead vs the raw-BLAS flop floor is now flat in the number of Hadamard tiles. The fast path is on by default; set TA_EINSUM_HADAMARD_LOCAL_FASTPATH_DISABLED=1 (or flip detail::einsum_hadamard_local_fastpath_disabled()) to force the legacy sub-World path as a safety valve / differential-correctness hook. examples/tot_bench/batched_contraction_attribution.cpp is an attribution benchmark for this case: legacy vs fast path vs a raw-BLAS flop floor, with a constant-flops granularity sweep.
Member
Author
|
this has no effect on CSV-CC, only affects plain tensor products with batching indices |
Member
Author
|
this will be superceded by proper general product support in expression layer. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The generalized batched-contraction
einsumpath — Hadamard indices coexisting with external/contracted indices, e.g.C(b,i,k) = A(b,i,j) * B(b,j,k)— previously ran oneMPI_Comm_split+ a fresh sub-World + a sub-World fence per Hadamard tile. That is O(#Hadamard-tiles) collectives, and the per-tile sub-World construction/teardown dominates wall time even at np=1.This PR adds a local-slice fast path: a Hadamard slice whose (single) input tiles are all owned by one rank is contracted locally on that rank with a direct
Tensor::gemm— no comm-split, no sub-World, nomake_array/DistEval, no fence. Whether a slice is local is decided purely from global pmap/trange metadata, so every rank reaches the same verdict and the comm-splits for the remaining (genuinely cross-rank) slices stay in lockstep. For batch-blocked data every slice is single-owner, so the whole batched contraction runs communication-free.Slices that span ranks, multi-tile external/contracted dimensions, and tensor-of-tensor tiles fall back to the unchanged sub-World path.
Why it's correct
pmap/trangemetadata (identical on all ranks) → no divergence in the collective sequence for the distributed-fallback slices.C(e_A,e_B) = A(e_A,i) * B(e_B,i)is a canonicalGemmHelper(NoTranspose, Transpose, ...);e = (a|b)-(a&b)orders A-externals before B-externals, matchinggemm's(left_outer, right_outer)output, so the existing permute/harvest fixes up the layout unchanged.Performance
Speedup over the legacy path for
C(b,i,k) = A(b,i,j) * B(b,j,k), 64 batch tiles, flops held constant (examples/tot_bench/batched_contraction_attribution.cpp):The legacy path's overhead grew from ~2× to ~26× over the raw-BLAS flop floor as the batch was split into more tiles; with the fast path that per-Hadamard-tile machinery cost is gone and the overhead is flat in the number of Hadamard tiles.
Toggle
On by default. Set
TA_EINSUM_HADAMARD_LOCAL_FASTPATH_DISABLED=1(or flipdetail::einsum_hadamard_local_fastpath_disabled()) to force the legacy sub-World path — kept as a safety valve and differential-correctness hook (mirrorsregime_a_strided_disabled()).Testing
einsum+einsum_totsuites pass at np=1 and np=2 with both the new and legacy paths (Release build).TA_ASSERT_THROW) suite passes at np=1.index_list,bipartite_index_list); unrelated to this change.Notes / follow-ups
Out of scope here, but the benchmark also surfaces remaining headroom: the multi-tile-external case still uses the sub-World fallback (could be a local tiled-gemm loop); the small per-batch GEMMs run at a fraction of peak (batched/strided BLAS would lift the flop floor); and the entry
world.gop.fence()hotfix may now be removable on the common path.