diff --git a/examples/tot_bench/CMakeLists.txt b/examples/tot_bench/CMakeLists.txt index 5236b48e2d..0b8471ab51 100644 --- a/examples/tot_bench/CMakeLists.txt +++ b/examples/tot_bench/CMakeLists.txt @@ -20,7 +20,7 @@ # Strided-DGEMM ToT throughput benches (strided-vs-current arena DGEMM). -foreach(_exec opa_strided_arena_dgemm opb_strided_arena_dgemm regime_a_hce_e_strided_bench ce_ce_segmented_strided_bench) +foreach(_exec opa_strided_arena_dgemm opb_strided_arena_dgemm regime_a_hce_e_strided_bench ce_ce_segmented_strided_bench batched_contraction_attribution) add_ta_executable(${_exec} "${_exec}.cpp" "tiledarray") add_dependencies(examples-tiledarray ${_exec}) endforeach() diff --git a/examples/tot_bench/batched_contraction_attribution.cpp b/examples/tot_bench/batched_contraction_attribution.cpp new file mode 100644 index 0000000000..2f5ead5050 --- /dev/null +++ b/examples/tot_bench/batched_contraction_attribution.cpp @@ -0,0 +1,364 @@ +// batched_contraction_attribution.cpp +// --------------------------------------------------------------------------- +// Attribution benchmark for the *plain* (non-ToT) batched contraction +// +// C(b,i,k) = A(b,i,j) * B(b,j,k) +// +// where `b` is a Hadamard ("batch") index (present in A, B, and C), `i`/`k` +// are external indices, and `j` is contracted. This is the case that today +// can only be expressed via TA::einsum, and whose performance the batch index +// being "hidden" inside the TA::Tensor tile (folded into Tensor::nbatch) is +// suspected to hurt. +// +// The benchmark separates two cost centers: +// +// RAW : the irreducible work. For every one of the B = bt*be batch +// elements we issue exactly ONE BLAS dgemm (I x J)*(J x K) on the +// already-local tile data -- i.e. precisely the per-batch-element +// GEMM granularity that einsum's Tensor::gemm() ultimately runs, +// but with ZERO TiledArray driver overhead. RAW therefore already +// includes the "many small GEMMs, no batched-BLAS" inefficiency +// (the P1 headroom). +// +// EINSUM : the production path, einsum(A("b,i,j"), B("b,j,k"), "b,i,k"). +// +// => (EINSUM - RAW) isolates the TA/einsum *machinery*: blocking find().get() +// of every tile, eager tile.permute(), reshape-into-nbatch, make_array of +// per-slice sub-arrays, the per-Hadamard-tile MPI_Comm_split + sub-World, +// and the per-slice fences (the P2/P3/P4 headroom). +// +// The granularity SWEEP holds the total batch B and the total flop count +// FIXED while moving B from "few big batch tiles" to "many tiny batch tiles". +// Flops are constant across the sweep, so any rise in EINSUM time with the +// number of batch tiles is machinery overhead that scales with the number of +// Hadamard tiles -- the signature of the per-H-tile comm-split/sub-World path. +// +// NOTE: the RAW baseline materializes data locally and is meaningful at np=1 +// (single rank); it is the node-local flops reference. The EINSUM timing is +// valid at any world size -- run at np=2 to watch the per-H-tile collective +// machinery show up. +// --------------------------------------------------------------------------- + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace TA = TiledArray; + +using clock_type = std::chrono::steady_clock; +static double ms_since(clock_type::time_point t0) { + return std::chrono::duration_cast( + clock_type::now() - t0) + .count() / + 1.0e6; +} + +static double median(std::vector v) { + std::sort(v.begin(), v.end()); + const std::size_t n = v.size(); + if (n == 0) return 0.0; + return (n % 2) ? v[n / 2] : 0.5 * (v[n / 2 - 1] + v[n / 2]); +} + +// =========================================================================== +// CLI +// =========================================================================== + +struct Cli { + int reps = 20; // timed reps per path + int warmup = 3; // untimed warmup reps + int bt = 16; // number of batch (Hadamard) tiles + int be = 4; // extent of each batch tile (total batch B = bt*be) + int I = 32; // external index i extent (one tile) + int J = 32; // contracted index j extent (one tile) + int K = 32; // external index k extent (one tile) + int sweep = 0; // if !=0, run the constant-B granularity sweep +}; + +static void usage() { + std::fprintf(stderr, + "batched_contraction_attribution\n" + " C(b,i,k) = A(b,i,j) * B(b,j,k) (b=Hadamard, j=contracted)\n" + " --reps R timed reps per path (default 20)\n" + " --warmup W untimed warmup reps (default 3)\n" + " --bt N number of batch tiles (default 16)\n" + " --be E extent per batch tile (default 4)\n" + " --i / --j / --k matrix extents (default 32)\n" + " --sweep run constant-B granularity sweep (B = bt*be held " + "fixed)\n"); +} + +static Cli parse_cli(int argc, char** argv) { + Cli c; + for (int a = 1; a < argc; ++a) { + std::string s = argv[a]; + auto need = [&]() -> std::string { + if (a + 1 >= argc) { + usage(); + std::exit(1); + } + return argv[++a]; + }; + if (s == "--reps") + c.reps = std::stoi(need()); + else if (s == "--warmup") + c.warmup = std::stoi(need()); + else if (s == "--bt") + c.bt = std::stoi(need()); + else if (s == "--be") + c.be = std::stoi(need()); + else if (s == "--i") + c.I = std::stoi(need()); + else if (s == "--j") + c.J = std::stoi(need()); + else if (s == "--k") + c.K = std::stoi(need()); + else if (s == "--sweep") + c.sweep = 1; + else if (s == "-h" || s == "--help") { + usage(); + std::exit(0); + } else { + std::fprintf(stderr, "unknown flag: %s\n", s.c_str()); + usage(); + std::exit(1); + } + } + return c; +} + +// =========================================================================== +// helpers +// =========================================================================== + +using Arr = TA::DistArray, TA::DensePolicy>; + +static TA::TiledRange1 batch_tr1(int bt, int be) { + std::vector bounds; + bounds.reserve(bt + 1); + for (int t = 0; t <= bt; ++t) bounds.push_back(static_cast(t) * be); + return TA::TiledRange1(bounds.begin(), bounds.end()); +} + +// deterministic fills (function of global coords) so RAW and EINSUM agree +static double a_val(long b, long i, long j) { + return 1.0 + 0.5 * std::sin(0.013 * (b + 1) * (i + 1)) + 0.1 * (j + 1); +} +static double b_val(long b, long j, long k) { + return 2.0 - 0.3 * std::cos(0.017 * (b + 1) * (k + 1)) + 0.05 * (j + 1); +} + +static Arr make_A(TA::World& world, int bt, int be, int I, int J) { + TA::TiledRange tr{batch_tr1(bt, be), TA::TiledRange1{0l, I}, + TA::TiledRange1{0l, J}}; + Arr a(world, tr); + a.init_tiles([&](const TA::Range& r) { + TA::Tensor t(r); + const auto lo = r.lobound(); + const auto ex = r.extent(); + std::size_t o = 0; + for (long b = lo[0]; b < lo[0] + (long)ex[0]; ++b) + for (long i = lo[1]; i < lo[1] + (long)ex[1]; ++i) + for (long j = lo[2]; j < lo[2] + (long)ex[2]; ++j) + t.data()[o++] = a_val(b, i, j); + return t; + }); + return a; +} + +static Arr make_B(TA::World& world, int bt, int be, int J, int K) { + TA::TiledRange tr{batch_tr1(bt, be), TA::TiledRange1{0l, J}, + TA::TiledRange1{0l, K}}; + Arr b(world, tr); + b.init_tiles([&](const TA::Range& r) { + TA::Tensor t(r); + const auto lo = r.lobound(); + const auto ex = r.extent(); + std::size_t o = 0; + for (long bb = lo[0]; bb < lo[0] + (long)ex[0]; ++bb) + for (long j = lo[1]; j < lo[1] + (long)ex[1]; ++j) + for (long k = lo[2]; k < lo[2] + (long)ex[2]; ++k) + t.data()[o++] = b_val(bb, j, k); + return t; + }); + return b; +} + +// RAW baseline: one BLAS dgemm per batch element on already-local tile data. +// Returns a checksum (to defeat DCE) via out-param; time is measured outside. +static double raw_once(const Arr& A, const Arr& B, int I, int J, int K, + std::vector& cbuf) { + namespace blas = TiledArray::math::blas; + double checksum = 0.0; + const auto& tr = A.trange(); + const std::size_t nbtiles = tr.dim(0).tile_extent(); + for (std::size_t bt = 0; bt < nbtiles; ++bt) { + // batch tile index (bt, 0, 0) + std::array tidx{bt, 0, 0}; + if (!A.is_local(tidx)) continue; + auto at = A.find_local(tidx).get(); + auto bt_tile = B.find_local(tidx).get(); + const double* ap = at.data(); + const double* bp = bt_tile.data(); + const std::size_t be = at.range().extent()[0]; + for (std::size_t e = 0; e < be; ++e) { + const double* ae = ap + e * (std::size_t)I * J; // I x J + const double* be_ = bp + e * (std::size_t)J * K; // J x K + // C(I,K) = A(I,J) * B(J,K), row-major + blas::gemm(blas::Op::NoTrans, blas::Op::NoTrans, I, K, J, 1.0, ae, J, be_, + K, 0.0, cbuf.data(), K); + checksum += cbuf[0] + cbuf[(std::size_t)I * K - 1]; + } + } + return checksum; +} + +// peak-ish reference: throughput of a single large square dgemm, to +// contextualize how far the small per-batch GEMMs are from the machine's +// achievable rate. +static double peak_gemm_gflops(int n, int reps) { + namespace blas = TiledArray::math::blas; + std::vector a((std::size_t)n * n, 1.0001), + b((std::size_t)n * n, 0.9999), c((std::size_t)n * n, 0.0); + // warmup + blas::gemm(blas::Op::NoTrans, blas::Op::NoTrans, n, n, n, 1.0, a.data(), n, + b.data(), n, 0.0, c.data(), n); + std::vector ms; + for (int r = 0; r < reps; ++r) { + auto t0 = clock_type::now(); + blas::gemm(blas::Op::NoTrans, blas::Op::NoTrans, n, n, n, 1.0, a.data(), n, + b.data(), n, 0.0, c.data(), n); + ms.push_back(ms_since(t0)); + } + const double flop = 2.0 * (double)n * n * n; + const double t = median(ms) / 1.0e3; + return flop / t / 1.0e9; +} + +struct Result { + double legacy_med, fused_med, raw_med; +}; + +// Time the einsum path under a given value of the fused-Hadamard toggle. +static double time_einsum(TA::World& world, const Arr& A, const Arr& B, + bool fused, int reps, int warmup) { + TA::detail::einsum_hadamard_local_fastpath_disabled() = !fused; + for (int w = 0; w < warmup; ++w) { + auto c = einsum(A("b,i,j"), B("b,j,k"), "b,i,k"); + c.world().gop.fence(); + } + world.gop.fence(); + std::vector ms; + ms.reserve(reps); + for (int r = 0; r < reps; ++r) { + auto t0 = clock_type::now(); + auto c = einsum(A("b,i,j"), B("b,j,k"), "b,i,k"); + c.world().gop.fence(); + ms.push_back(ms_since(t0)); + } + TA::detail::einsum_hadamard_local_fastpath_disabled() = + false; // restore default + return median(ms); +} + +static Result run_case(TA::World& world, const Cli& cli, int bt, int be, + bool quiet) { + const int I = cli.I, J = cli.J, K = cli.K; + Arr A = make_A(world, bt, be, I, J); + Arr B = make_B(world, bt, be, J, K); + world.gop.fence(); + + const double flop = 2.0 * (double)bt * be * I * J * K; + + const double legacy = + time_einsum(world, A, B, /*fused=*/false, cli.reps, cli.warmup); + const double fused = + time_einsum(world, A, B, /*fused=*/true, cli.reps, cli.warmup); + + // RAW timed (node-local; meaningful at np=1) + std::vector cbuf((std::size_t)I * K, 0.0); + for (int w = 0; w < cli.warmup; ++w) raw_once(A, B, I, J, K, cbuf); + std::vector r_ms; + r_ms.reserve(cli.reps); + volatile double sink = 0.0; + for (int r = 0; r < cli.reps; ++r) { + auto t0 = clock_type::now(); + sink += raw_once(A, B, I, J, K, cbuf); + r_ms.push_back(ms_since(t0)); + } + (void)sink; + + const double rm = median(r_ms); + if (!quiet && world.rank() == 0) { + auto gf = [&](double ms) { return ms > 0 ? flop / (ms / 1e3) / 1e9 : 0.0; }; + std::printf( + "bt=%-4d be=%-3d B=%-6d | legacy %9.4f ms (%6.2f GF/s) P4b %9.4f ms " + "(%6.2f GF/s) RAW %8.4f ms | speedup %5.2fx ovhd L=%5.1fx " + "P4b=%5.1fx\n", + bt, be, bt * be, legacy, gf(legacy), fused, gf(fused), rm, + fused > 0 ? legacy / fused : 0.0, rm > 0 ? legacy / rm : 0.0, + rm > 0 ? fused / rm : 0.0); + } + return {legacy, fused, rm}; +} + +// =========================================================================== +// main +// =========================================================================== + +int main(int argc, char** argv) { + Cli cli = parse_cli(argc, argv); + auto& world = TA_SCOPED_INITIALIZE(argc, argv); + + if (world.rank() == 0) { + std::printf( + "=== batched contraction attribution: C(b,i,k)=A(b,i,j)*B(b,j,k) " + "===\n"); + std::printf("world.size=%d reps=%d warmup=%d i=%d j=%d k=%d\n", + world.size(), cli.reps, cli.warmup, cli.I, cli.J, cli.K); + std::printf( + "legacy = comm-split sub-World per Hadamard tile; P4b = parent-" + "World slices + single fence;\nRAW = one BLAS dgemm per batch " + "elem (no TA driver). speedup = legacy/P4b.\n\n"); + } + + if (!cli.sweep) { + run_case(world, cli, cli.bt, cli.be, /*quiet=*/false); + } else { + // constant total batch B = bt*be; vary granularity from coarse to fine. + const int B = cli.bt * cli.be; + std::vector> tilings; + for (int be = B; be >= 1; be /= 2) { + if (B % be != 0) continue; + tilings.emplace_back(B / be, be); // (bt, be) + } + if (world.rank() == 0) + std::printf( + "--- granularity sweep (total batch B=%d, flops CONSTANT) ---\n", B); + for (auto [bt, be] : tilings) run_case(world, cli, bt, be, /*quiet=*/false); + if (world.rank() == 0) + std::printf( + "\n(Reading: flops are identical across rows; rising EINSUM time as\n" + " bt grows = per-Hadamard-tile machinery cost. RAW isolates the\n" + " small-GEMM/flops floor.)\n"); + } + + if (world.rank() == 0) { + const double pk = peak_gemm_gflops(1024, 5); + std::printf( + "\nmachine ref: single 1024^3 dgemm = %.1f GF/s (throughput ceiling)\n", + pk); + } + + return 0; +} diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 89fd0db4cf..4b24b1368d 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -7,16 +7,49 @@ #include "TiledArray/einsum/range.h" #include "TiledArray/expressions/fwd.h" #include "TiledArray/fwd.h" +#include "TiledArray/math/blas.h" +#include "TiledArray/math/gemm_helper.h" #include "TiledArray/tensor/arena_einsum.h" #include "TiledArray/tiled_range.h" #include "TiledArray/tiled_range1.h" #include +#include +#include +#include + namespace TiledArray { enum struct DeNest { True, False }; } +namespace TiledArray::detail { + +/// Kill switch for the local-slice fast path of the *generalized* +/// batched-contraction einsum (Hadamard indices coexisting with +/// external/contracted indices). +/// +/// - false (default): a Hadamard slice whose input tiles are all owned by a +/// single rank is contracted locally on that rank with a direct +/// `Tensor::gemm` -- no `MPI_Comm_split`, no sub-World, no per-tile fence. +/// Slices that span ranks fall back to the sub-World path. +/// - true: forces the legacy path -- one `MPI_Comm_split` + a fresh sub-World +/// + a sub-World fence *per Hadamard tile* (O(#Hadamard-tiles) collectives). +/// Retained as a safety valve / differential-correctness hook. +/// +/// Toggleable at runtime (test/bench hook) and from the environment via +/// `TA_EINSUM_HADAMARD_LOCAL_FASTPATH_DISABLED` (any non-empty value other than +/// "0" forces the legacy path). Mirrors `regime_a_strided_disabled()`. +inline bool &einsum_hadamard_local_fastpath_disabled() { + static bool flag = [] { + const char *e = std::getenv("TA_EINSUM_HADAMARD_LOCAL_FASTPATH_DISABLED"); + return e != nullptr && e[0] != '\0' && std::string_view(e) != "0"; + }(); + return flag; +} + +} // namespace TiledArray::detail + namespace TiledArray::Einsum { using ::Einsum::index::small_vector; @@ -899,65 +932,193 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, std::invoke(update_tr, std::get<0>(AB)); std::invoke(update_tr, std::get<1>(AB)); - // iterates over tiles of hadamard indices + // Per-Hadamard-tile retile: gather this rank's local input tiles for slice + // `h`, (eagerly) permute them to canonical layout, fold the within-tile + // Hadamard extent into nbatch, and assemble the slice sub-array `term.ei` + // on `slice_world`. + auto retile = [](auto &term, World &slice_world, const Index &h, + size_t batch) { + term.local_tiles.clear(); + const Permutation &P = term.permutation; + + for (Index ei : term.tiles) { + auto idx = apply_inverse(P, h + ei); + if (!term.array.is_local(idx)) continue; + if (term.array.is_zero(idx)) continue; + // TODO no need for immediate evaluation + auto tile = term.array.find_local(idx).get(); + if (P) tile = tile.permute(P); + auto shape = term.ei_tiled_range.tile(ei); + tile = tile.reshape(shape, batch); + term.local_tiles.push_back({ei, tile}); + } + bool replicated = term.array.pmap()->is_replicated(); + term.ei = TiledArray::make_array( + slice_world, term.ei_tiled_range, term.local_tiles.begin(), + term.local_tiles.end(), replicated); + }; + + // Stores one completed slice result tile (external indices `e`, with the + // within-tile Hadamard extent folded into nbatch) into C_local_tiles: + // unfold nbatch back into the Hadamard modes and permute to target C + // layout. + auto store_C_tile = [&](const Index &h, const Index &e, ResultTensor tile) { + const Permutation &P = C.permutation; + auto c = apply(P, h + e); + auto shape = C.array.trange().tile(c); + shape = apply_inverse(P, shape); + tile = tile.reshape(shape); + if (P) tile = tile.permute(P); + C_local_tiles.emplace_back(std::move(c), std::move(tile)); + }; + + // Extracts the (completed) result sub-array `c_ei` of Hadamard slice `h` + // into C_local_tiles. + auto harvest = [&](const Index &h, ArrayC &c_ei) { + for (Index e : C.tiles) { + if (!c_ei.is_local(e)) continue; + if (c_ei.is_zero(e)) continue; + store_C_tile(h, e, c_ei.find_local(e).get()); + } + }; + + if (detail::einsum_hadamard_local_fastpath_disabled()) { + // ===== Legacy path: one MPI_Comm_split + sub-World + fence per + // Hadamard tile. O(#Hadamard-tiles) collectives. ===== + for (Index h : H.tiles) { + auto &[A, B] = AB; + auto own = A.own(h) || B.own(h); + auto comm = madness::blocking_invoke( + &SafeMPI::Intracomm::Split, world.mpi.comm(), own, world.rank()); + worlds.push_back(std::make_unique(comm)); + auto &owners = worlds.back(); + if (!own) continue; + size_t batch = 1; + for (size_t i = 0; i < h.size(); ++i) { + batch *= H.batch[i].at(h[i]); + } + + std::invoke(retile, std::get<0>(AB), *owners, h, batch); + std::invoke(retile, std::get<1>(AB), *owners, h, batch); + + C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners); + A.ei.defer_deleter_to_next_fence(); + B.ei.defer_deleter_to_next_fence(); + A.ei = ArrayA(); + B.ei = ArrayB(); + // why omitting this fence leads to deadlock? + owners->gop.fence(); + std::invoke(harvest, h, C.ei); + // mark for lazy deletion + C.ei = ArrayC(); + } + + build_C_array(); + + for (auto &w : worlds) { + w->gop.fence(); + } + + return C.array; + } + + // ===== Local-slice fast path: a Hadamard slice whose (single) input tiles + // are all owned by one rank is contracted *locally* on that rank with a + // direct Tensor::gemm -- no comm-split, no sub-World, no make_array/ + // DistEval, no fence. Slices that span ranks fall back to the sub-World + // path below. + // + // Whether a slice is "local" is decided purely from global pmap/trange + // metadata, so every rank reaches the same verdict and the comm-splits for + // the (remaining) distributed slices stay in lockstep across ranks. For + // batch-blocked data every slice is single-owner, so the whole batched + // contraction runs communication-free. ===== + + // The local fast path needs a value-returning per-tile product, so it is + // restricted to plain (non-nested) tensor tiles; ToT slices always take the + // distributed fallback. + constexpr bool plain_tiles = + IsArrayT && IsArrayT && IsArrayT; + + // The slice contraction is a single Tensor::gemm only when each operand's + // external+contracted modes form exactly one tile (one A tile, one B tile, + // one C tile per slice). This is a property of the tiled ranges, identical + // for every Hadamard slice. + [[maybe_unused]] std::vector a_eis, b_eis, c_es; + for (Index ei : std::get<0>(AB).tiles) a_eis.push_back(ei); + for (Index ei : std::get<1>(AB).tiles) b_eis.push_back(ei); + for (Index ee : C.tiles) c_es.push_back(ee); + [[maybe_unused]] const bool single_tile_slices = + a_eis.size() == 1 && b_eis.size() == 1 && c_es.size() == 1; + + // Canonical (NoTranspose,Transpose) GEMM for C(e_A,e_B) = + // A(e_A,i)*B(e_B,i): both operands carry the contracted indices `i` as + // their trailing modes (a_ei/b_ei = (e+i)&operand, with e ordered + // A-externals-then-B-externals), and the result lands in `e` order, + // matching C.expr. + const auto a_ei_idx = (e + i) & a; + const auto b_ei_idx = (e + i) & b; + [[maybe_unused]] const math::GemmHelper gemm_helper( + math::blas::NoTranspose, math::blas::Transpose, + static_cast(e.size()), + static_cast(a_ei_idx.size()), + static_cast(b_ei_idx.size())); + for (Index h : H.tiles) { auto &[A, B] = AB; + size_t batch = 1; + for (size_t i = 0; i < h.size(); ++i) { + batch *= H.batch[i].at(h[i]); + } + + // ---- local fast path ---- + if constexpr (plain_tiles) { + if (single_tile_slices) { + const Index &a_ei = a_eis[0]; + const Index &b_ei = b_eis[0]; + const Index &c_e = c_es[0]; + const auto a_idx = apply_inverse(A.permutation, h + a_ei); + const auto b_idx = apply_inverse(B.permutation, h + b_ei); + // global, rank-consistent verdict: both inputs on the same owner. + if (A.array.owner(a_idx) == B.array.owner(b_idx)) { + if (A.array.is_local(a_idx)) { // this rank is that sole owner + if (!(A.array.is_zero(a_idx) || B.array.is_zero(b_idx))) { + auto a_tile = A.array.find_local(a_idx).get(); + auto b_tile = B.array.find_local(b_idx).get(); + if (A.permutation) a_tile = a_tile.permute(A.permutation); + if (B.permutation) b_tile = b_tile.permute(B.permutation); + a_tile = a_tile.reshape(A.ei_tiled_range.tile(a_ei), batch); + b_tile = b_tile.reshape(B.ei_tiled_range.tile(b_ei), batch); + ResultTensor c_tile; + c_tile.gemm(a_tile, b_tile, + typename ResultTensor::numeric_type{1}, + gemm_helper); + store_C_tile(h, c_e, std::move(c_tile)); + } + } + continue; // all ranks skip the distributed handling for this slice + } + } + } + + // ---- distributed fallback: legacy sub-World contraction ---- auto own = A.own(h) || B.own(h); auto comm = madness::blocking_invoke(&SafeMPI::Intracomm::Split, world.mpi.comm(), own, world.rank()); worlds.push_back(std::make_unique(comm)); auto &owners = worlds.back(); if (!own) continue; - size_t batch = 1; - for (size_t i = 0; i < h.size(); ++i) { - batch *= H.batch[i].at(h[i]); - } - auto retile = [&owners, &h = std::as_const(h), batch](auto &term) { - term.local_tiles.clear(); - const Permutation &P = term.permutation; - - for (Index ei : term.tiles) { - auto idx = apply_inverse(P, h + ei); - if (!term.array.is_local(idx)) continue; - if (term.array.is_zero(idx)) continue; - // TODO no need for immediate evaluation - auto tile = term.array.find_local(idx).get(); - if (P) tile = tile.permute(P); - auto shape = term.ei_tiled_range.tile(ei); - tile = tile.reshape(shape, batch); - term.local_tiles.push_back({ei, tile}); - } - bool replicated = term.array.pmap()->is_replicated(); - term.ei = TiledArray::make_array( - *owners, term.ei_tiled_range, term.local_tiles.begin(), - term.local_tiles.end(), replicated); - }; - std::invoke(retile, std::get<0>(AB)); - std::invoke(retile, std::get<1>(AB)); + std::invoke(retile, std::get<0>(AB), *owners, h, batch); + std::invoke(retile, std::get<1>(AB), *owners, h, batch); C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners); A.ei.defer_deleter_to_next_fence(); B.ei.defer_deleter_to_next_fence(); A.ei = ArrayA(); B.ei = ArrayB(); - // why omitting this fence leads to deadlock? owners->gop.fence(); - for (Index e : C.tiles) { - if (!C.ei.is_local(e)) continue; - if (C.ei.is_zero(e)) continue; - // TODO no need for immediate evaluation - auto tile = C.ei.find_local(e).get(); - assert(tile.nbatch() == batch); - const Permutation &P = C.permutation; - auto c = apply(P, h + e); - auto shape = C.array.trange().tile(c); - shape = apply_inverse(P, shape); - tile = tile.reshape(shape); - if (P) tile = tile.permute(P); - C_local_tiles.emplace_back(std::move(c), std::move(tile)); - } - // mark for lazy deletion + std::invoke(harvest, h, C.ei); C.ei = ArrayC(); }