From e5693a41820b12a830c05ba96a0918f4cc7e9138 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 21:03:39 -0400 Subject: [PATCH 1/4] expressions: deduce inner-node index sets top-down from the assignment target (Phase E) An index shared by the two children of a product is fused iff the node's target carries it, contracted otherwise; an index neither the sibling nor the target carries is consumed within the child subtree and not demanded of it. available_indices() (per-subtree leaf-annotation union, valid before init) supplies the up-pass; each child dictates the ORDER of its demand via preferred_layout() (canonical (fused, left-free, right-free) for products, pass-through elsewhere). Expressions consumed without a target (reductions) retain the bottom-up contraction convention. --- src/TiledArray/expressions/binary_engine.h | 29 +++++ src/TiledArray/expressions/leaf_engine.h | 12 +++ src/TiledArray/expressions/mult_engine.h | 120 ++++++++++++++------- src/TiledArray/expressions/unary_engine.h | 10 ++ 4 files changed, 135 insertions(+), 36 deletions(-) diff --git a/src/TiledArray/expressions/binary_engine.h b/src/TiledArray/expressions/binary_engine.h index 7d0e3ff0d2..c3d2e69da5 100644 --- a/src/TiledArray/expressions/binary_engine.h +++ b/src/TiledArray/expressions/binary_engine.h @@ -229,6 +229,35 @@ class BinaryEngine : public ExprEngine { right_inner_permtype_ == PermutationType::general); } + /// \return the indices this subtree can supply (Phase E up-pass): the + /// first-occurrence-ordered union of the children's available indices, + /// outer and inner lists separately. Valid before init: it depends only on + /// the leaf annotations, never on resolved (post-init) index sets. + BipartiteIndexList available_indices() const { + auto union_ = [](auto const& a, auto const& b) { + container::svector r(a.begin(), a.end()); + for (auto&& idx : b) + if (!a.count(idx)) r.push_back(idx); + return r; + }; + auto const l = left_.available_indices(); + auto const r = right_.available_indices(); + auto const out = union_(outer(l), outer(r)); + auto const in = union_(inner(l), inner(r)); + return BipartiteIndexList(IndexList(out.begin(), out.end()), + IndexList(in.begin(), in.end())); + } + + /// \return the layout this subtree prefers for producing the index set of + /// \p demand: element-wise binary ops (Add/Subt) impose no layout of their + /// own (both children must produce the same set and are aligned by + /// permutation), so the demand is returned unchanged. MultEngine overrides + /// this with its canonical product layout. + const BipartiteIndexList& preferred_layout( + const BipartiteIndexList& demand) const { + return demand; + } + /// Initialize result tensor structure /// This function will initialize the permutation, tiled range, and shape diff --git a/src/TiledArray/expressions/leaf_engine.h b/src/TiledArray/expressions/leaf_engine.h index 8804989d6f..112581c71f 100644 --- a/src/TiledArray/expressions/leaf_engine.h +++ b/src/TiledArray/expressions/leaf_engine.h @@ -127,6 +127,18 @@ class LeafEngine : public ExprEngine { /// This function is a noop since the index list is fixed. void init_indices() {} + /// \return the indices this subtree can supply (Phase E up-pass): for a + /// leaf, simply its annotation (set at construction, valid before init) + const BipartiteIndexList& available_indices() const { return indices_; } + + /// \return the layout this subtree prefers for producing the index set of + /// \p demand: a leaf accepts any ordering (the consumer aligns against the + /// fixed annotation by permutation), so the demand is returned unchanged + const BipartiteIndexList& preferred_layout( + const BipartiteIndexList& demand) const { + return demand; + } + void init_distribution(World* world, const std::shared_ptr& pmap) { ExprEngine_::init_distribution(world, (pmap ? pmap : array_.pmap())); diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index d5db1bf093..3f74d4c7cd 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -281,44 +281,58 @@ class MultEngine : public ContEngine> { /// \param target_indices The target index list for this expression void init_indices(const BipartiteIndexList& target_indices) { - // to decide what type of product this is must initialize indices down - // the tree. - // N.B. since this may be a contraction we do not know the target indices - // for the left and right, hence do target-neutral initialization - BinaryEngine_::left_.init_indices(); - BinaryEngine_::right_.init_indices(); - - // Validate that the (bottom-up resolved) child indices are consistent - // with the target: every outer index of each child must appear in the - // other child or in the target (as a free, fused, or contracted index). - // A violation usually means a *general* product (fused + contracted + - // free indices) appears at an INNER node of the expression tree, where - // the role of a shared index cannot be deduced bottom-up; e.g. in the - // THC-like g("p,q,r,s") = X("p,r1") * X("q,r1") * Z("r1,r2") * ... the - // index r1 is fused in X*X but contracted downstream, while the - // bottom-up convention contracts it in X*X, orphaning the r1 of Z. - // Resolving this requires pushing the needed-index sets down the - // expression tree; until then materialize such inner products into - // explicit intermediates, so that every general product appears as the - // root of its own assignment (where the target determines the index - // roles). + // Deduce each child's index SET top-down (Phase E). The index set an + // inner node must produce depends on what its ancestors need: a shared + // index of the two children is fused iff this node's target carries it, + // contracted here otherwise; an index of one child that neither the + // sibling nor the target carries is consumed entirely within that + // child's subtree and is not demanded of it. + // Each child's demand is ordered target-kept-first (the indices this + // node's result keeps, in target order) followed by the + // contracted-here indices in the child's leaf-availability order -- + // fused indices thus lead, matching the canonical layouts of + // GeneralPermutationOptimizer (h leading) so the nbatch fold stays + // zero-copy. E.g. in the THC-like g("p,q,r,s") = X("p,r1") * X("q,r1") + // * Z("r1,r2") * ... the index r1 is demanded of X*X by its consumer + // (fused there), and contracted where Z meets the X*X subtree. + // N.B. expressions consumed WITHOUT a target (reductions, e.g. dot) + // take the no-target init_indices() overload below, which retains the + // bottom-up contraction convention -- general products under + // reductions remain unsupported. { - auto const& left_outer = outer(BinaryEngine_::left_.indices()); - auto const& right_outer = outer(BinaryEngine_::right_.indices()); - auto const& target_outer = outer(target_indices); - auto validate = [&](const IndexList& a, const IndexList& b) { - for (auto&& idx : a) - if (!b.count(idx) && !target_outer.count(idx)) - TA_EXCEPTION( - "MultEngine: an argument index appears in neither the other " - "argument nor the target. If a general product (fused + " - "contracted + free indices) appears at an inner node of the " - "expression tree, its index roles cannot be deduced " - "bottom-up; materialize it into an explicit intermediate so " - "that it appears as the root of its own assignment"); + auto const avail_l = BinaryEngine_::left_.available_indices(); + auto const avail_r = BinaryEngine_::right_.available_indices(); + auto demand = [](auto const& avail, auto const& sibling, + auto const& tgt) { + container::svector r; + for (auto&& idx : tgt) + if (avail.count(idx)) r.push_back(idx); + for (auto&& idx : avail) { + if (tgt.count(idx)) continue; + if (sibling.count(idx)) r.push_back(idx); + // else: the index is consumed entirely within the child's own + // subtree (contracted deeper down) -- not demanded here. A + // genuinely orphaned index (single occurrence, demanded nowhere = + // implicit trace, unsupported) surfaces as the leaf-level + // target-is-not-a-permutation error. + } + return r; }; - validate(left_outer, right_outer); - validate(right_outer, left_outer); + auto bipartite_demand = [&demand](auto const& avail, auto const& sib, + auto const& tgt) { + auto const out = demand(outer(avail), outer(sib), outer(tgt)); + auto const in = demand(inner(avail), inner(sib), inner(tgt)); + return BipartiteIndexList(IndexList(out.begin(), out.end()), + IndexList(in.begin(), in.end())); + }; + // each child dictates the ORDER of its demand (preferred_layout): a + // product child reorders to its canonical (h, eA, eB) layout -- a + // general product cannot host a result permutation, and contraction + // consumers absorb any child layout via the GEMM transpose forms + BinaryEngine_::left_.init_indices(BinaryEngine_::left_.preferred_layout( + bipartite_demand(avail_l, avail_r, target_indices))); + BinaryEngine_::right_.init_indices(BinaryEngine_::right_.preferred_layout( + bipartite_demand(avail_r, avail_l, target_indices))); } this->product_type_ = compute_product_type( @@ -377,6 +391,40 @@ class MultEngine : public ContEngine> { } } + /// \return the layout this product prefers for producing the index set of + /// \p demand: the canonical general-product result layout (h, eA, eB) -- + /// indices supplied by both children (fused here) lead, then indices + /// supplied by the left child only, then by the right child only, each + /// group in demand order. h-leading is required when this node resolves to + /// a general product (general results cannot host a result permutation and + /// the nbatch fold needs the fused modes leading); for a pure contraction + /// (empty h) the (eA, eB) order matches the GEMM result, and for a pure + /// Hadamard the demand order is preserved. + BipartiteIndexList preferred_layout(const BipartiteIndexList& demand) const { + auto const avail_l = BinaryEngine_::left_.available_indices(); + auto const avail_r = BinaryEngine_::right_.available_indices(); + auto canonical = [](auto const& d, auto const& al, auto const& ar) { + container::svector h, ea, eb; + for (auto&& idx : d) { + bool const inl = al.count(idx); + bool const inr = ar.count(idx); + if (inl && inr) + h.push_back(idx); + else if (inl) + ea.push_back(idx); + else + eb.push_back(idx); + } + h.insert(h.end(), ea.begin(), ea.end()); + h.insert(h.end(), eb.begin(), eb.end()); + return h; + }; + auto const out = canonical(outer(demand), outer(avail_l), outer(avail_r)); + auto const in = canonical(inner(demand), inner(avail_l), inner(avail_r)); + return BipartiteIndexList(IndexList(out.begin(), out.end()), + IndexList(in.begin(), in.end())); + } + /// Initialize result tensor structure /// This function will initialize the permutation, tiled range, and shape diff --git a/src/TiledArray/expressions/unary_engine.h b/src/TiledArray/expressions/unary_engine.h index 631fca8fed..2b34cc716c 100644 --- a/src/TiledArray/expressions/unary_engine.h +++ b/src/TiledArray/expressions/unary_engine.h @@ -120,6 +120,16 @@ class UnaryEngine : ExprEngine { indices_ = arg_.indices(); } + /// \return the indices this subtree can supply (Phase E up-pass): a unary + /// op supplies exactly what its argument supplies (valid before init) + decltype(auto) available_indices() const { return arg_.available_indices(); } + + /// \return the layout this subtree prefers for producing the index set of + /// \p demand: a unary op defers to its argument's preference + decltype(auto) preferred_layout(const BipartiteIndexList& demand) const { + return arg_.preferred_layout(demand); + } + /// Initialize result tensor structure /// This function will initialize the permutation, tiled range, and shape From e9980d186549a036244c4f6840680e3222502bf6 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 21:03:52 -0400 Subject: [PATCH 2/4] expressions: general products honor arbitrary result layouts via a streaming re-permute A target that differs from the canonical (fused..., left-free..., right-free...) result layout cannot be folded into the batched tile op (BatchedContractReduce must be perm-free); evaluate canonically (Summa over a slab-replicated pmap) and re-permute to the target with a streaming UnaryEvalImpl. Honors the implicit-permute contract: when the consumer fuses the permutation into its own operation (transposed GEMM), only the tile ordinals/trange are remapped and contents stay canonical. Replaces the interleaved-target gate, enabling general products at inner expression-tree nodes and non-canonical root targets. --- src/TiledArray/expressions/cont_engine.h | 95 ++++++++++++++++++++---- 1 file changed, 81 insertions(+), 14 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index dfab3ee664..d79d3a27db 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -27,6 +27,7 @@ #define TILEDARRAY_EXPRESSIONS_CONT_ENGINE_H__INCLUDED #include +#include #include #include #include @@ -36,6 +37,7 @@ #include #include #include +#include namespace TiledArray { namespace expressions { @@ -179,6 +181,10 @@ class ContEngine : public BinaryEngine { // General (fused + contracted + free indices) products only: unsigned int n_fused_modes_ = 0; ///< # of leading fused (outer) modes size_type n_slabs_ = 1; ///< # of fused-index tile slabs + bool general_repermute_ = false; ///< whether the target layout differs + ///< from the canonical result layout, so + ///< the evaluated result is re-permuted + ///< by a streaming unary eval static unsigned int find(const BipartiteIndexList& indices, const std::string& index_label, unsigned int i, @@ -660,16 +666,13 @@ class ContEngine : public BinaryEngine { TA_ASSERT(nh > 0u); // else this is a pure contraction n_fused_modes_ = nh; - // initialize perm_; an interleaved target (a result permutation that - // mixes fused and free modes) is not yet supported -- the canonical - // result layout must equal the target + // initialize perm_; a target that differs from the canonical (fused..., + // left-free..., right-free...) result layout cannot be folded into the + // batched tile op (BatchedContractReduce must be perm-free), so the + // product is evaluated in its canonical layout and re-permuted to the + // target by a streaming unary eval (see make_dist_eval_general) this->init_perm(target_indices); - if (outer(target_indices) != outer(indices_)) - TA_EXCEPTION( - "general products (fused + contracted + free indices): targets " - "that interleave fused and free indices are not yet supported; " - "reorder the result annotation to (fused..., left-free..., " - "right-free...)"); + general_repermute_ = (outer(target_indices) != outer(indices_)); // the tile op operates on the folded (fused-mode-free) shapes const auto left_op = to_cblas_op(left_outer_permtype_); @@ -712,6 +715,12 @@ class ContEngine : public BinaryEngine { trange_ = make_trange_general(); shape_ = make_shape_general(); + if (general_repermute_) { + // consumers see the target layout; the canonical structures are + // recomputed in make_dist_eval_general for the inner Summa + trange_ = outer(perm_) * trange_; + shape_ = shape_.perm(outer(perm_)); + } if (ExprEngine_::override_ptr_ && ExprEngine_::override_ptr_->shape) { shape_ = shape_.mask(*ExprEngine_::override_ptr_->shape); @@ -834,9 +843,32 @@ class ContEngine : public BinaryEngine { } } + /// Streaming tile re-permute op for general products whose target layout + /// differs from the canonical (fused..., free...) result layout: the + /// batched tile op must stay perm-free, so the consumer-side unary eval + /// applies the result permutation per tile instead + struct GeneralRepermuteOp { + typedef value_type result_type; + typedef value_type argument_type; + static constexpr bool is_consumable = false; + BipartitePermutation perm; + /// false when the consumer fuses the permutation into its own operation + /// (implicit permute, e.g. a transposed GEMM): then only the tile + /// ordinals/trange are remapped (by the host UnaryEvalImpl) and the tile + /// contents are delivered in the canonical layout + bool permute_contents = true; + result_type operator()(const argument_type& tile) const { + if (!permute_contents) return tile; + TiledArray::detail::Noop noop; + return noop(tile, perm); + } + }; + /// Construct the distributed evaluator of a general product - /// \return The batched-Summa distributed evaluator for this expression + /// \return The batched-Summa distributed evaluator for this expression, + /// wrapped in a streaming re-permute when the target layout differs from + /// the canonical result layout dist_eval_type make_dist_eval_general() const { typedef TiledArray::detail::BatchedContractReduce batched_op_type; typedef TiledArray::detail::Summa { typename left_type::dist_eval_type left = left_.make_dist_eval(); typename right_type::dist_eval_type right = right_.make_dist_eval(); - std::shared_ptr pimpl = std::make_shared( - left, right, *world_, trange_, shape_, pmap_, perm_, - batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_); + if (!general_repermute_) { + std::shared_ptr pimpl = std::make_shared( + left, right, *world_, trange_, shape_, pmap_, perm_, + batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_); + return dist_eval_type(pimpl); + } - return dist_eval_type(pimpl); + // evaluate in the canonical layout (Summa with perm-free op), then + // re-permute tiles to the target layout with a streaming unary eval; + // trange_/shape_ hold the target-layout structures (see + // init_struct_general), the canonical ones are recomputed here + auto const canonical_trange = make_trange_general(); + auto const canonical_shape = [this]() { + auto s = make_shape_general(); + if (ExprEngine_::override_ptr_ && ExprEngine_::override_ptr_->shape) { + // the consumer-supplied mask is expressed in the target layout + auto const inv_perm = outer(perm_).inv(); + s = s.mask(ExprEngine_::override_ptr_->shape->perm(inv_perm)); + } + return s; + }(); + // the inner Summa's result placement must be slab-replicated (the owner + // of a tile independent of its slab index), regardless of the + // (target-layout) pmap the consumer supplied for this node + auto canonical_pmap = std::make_shared( + *world_, proc_grid_.make_pmap(), n_slabs_); + std::shared_ptr pimpl = std::make_shared( + left, right, *world_, canonical_trange, canonical_shape, canonical_pmap, + BipartitePermutation{}, batched_op_type(op_, n_fused_modes_), K_, + proc_grid_, n_slabs_); + dist_eval_type canonical(pimpl); + + typedef TiledArray::detail::UnaryEvalImpl< + dist_eval_type, GeneralRepermuteOp, typename Derived::policy> + repermute_impl_type; + std::shared_ptr wrapper = + std::make_shared( + canonical, *world_, trange_, shape_, pmap_, perm_, + GeneralRepermuteOp{perm_, !this->implicit_permute_outer_}); + return dist_eval_type(wrapper); } /// Expression identification tag From 7c5d1c2b70fcff5b2221ac08a43b557a377ecbaa Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 21:03:53 -0400 Subject: [PATCH 3/4] tests: inner-node general products (THC, depth-2, non-canonical root target) --- tests/general_product.cpp | 74 +++++++++++++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 15 deletions(-) diff --git a/tests/general_product.cpp b/tests/general_product.cpp index b0d622a31d..8ce13debe5 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -235,26 +235,70 @@ BOOST_AUTO_TEST_CASE(expression_general_product_dense_batched_outer) { BOOST_CHECK_SMALL(diff_norm(c, c_ref, "b,i,k"), 1e-10); } -BOOST_AUTO_TEST_CASE(expression_general_product_inner_node_gated) { - // THC-style reconstruction: +BOOST_AUTO_TEST_CASE(expression_general_product_noncanonical_root_target) { + // a root-level general product with a NON-canonical target layout: the + // product evaluates canonically (r1,p,q) and is re-permuted to the target + // by the streaming unary eval + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // orbital x auxiliary + auto x = make_patterned_array(world, tr_x, 1.0); + + TA::TArrayD w, i1, w_ref; + BOOST_REQUIRE_NO_THROW(w("p,q,r1") = x("p,r1") * x("q,r1")); + i1("r1,p,q") = x("p,r1") * x("q,r1"); // canonical evaluation + w_ref("p,q,r1") = i1("r1,p,q"); // plain permute assignment + + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "p,q,r1"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_inner_node_depth2) { + // minimal inner-node case: a general product (fused r1) feeding a + // contraction over r1 -- the general child evaluates canonically + // (r1,p,q) and is re-permuted on the fly to the consumer's GEMM layout + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // orbital x auxiliary + TA::TiledRange tr_z{{0, 3, 5}, {0, 3, 5}}; // auxiliary x auxiliary + auto x = make_patterned_array(world, tr_x, 1.0); + auto z = make_patterned_array(world, tr_z, 2.0); + + TA::TArrayD w; + BOOST_REQUIRE_NO_THROW(w("p,q,r2") = (x("p,r1") * x("q,r1")) * z("r1,r2")); + + TA::TArrayD i1, w_ref; + i1("r1,p,q") = x("p,r1") * x("q,r1"); + w_ref("p,q,r2") = i1("r1,p,q") * z("r1,r2"); + + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "p,q,r2"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_inner_node_thc) { + // THC-style reconstruction in ONE expression: // g("p,q,r,s") = X("p,r1") * X("q,r1") * Z("r1,r2") * X("r,r2") * X("s,r2") - // r1 is fused in X("p,r1") * X("q,r1") but contracted downstream. The - // first product is an INNER node of the expression tree, where the role of - // r1 cannot be deduced bottom-up (the target reaches only the root); - // resolving this requires top-down index-set deduction (deferred). Until - // then: an informative error, not garbage (bottom-up, X*X would contract - // r1, orphaning the r1 of Z). + // r1 is fused in X("p,r1") * X("q,r1") but contracted downstream, so the + // first product is a general product at an INNER node of the (left-deep) + // expression tree. The top-down index-set deduction demands r1 of the X*X + // node (its consumer carries it) and contracts it where Z meets that + // subtree; higher up, r1 is dropped from the demand (consumed entirely + // within). Verified against the same chain staged through explicit + // intermediates (the pre-deduction recipe, itself differential-tested + // against einsum in expression_general_product_thc_intermediates). auto& world = TA::get_default_world(); TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // orbital x auxiliary TA::TiledRange tr_z{{0, 3, 5}, {0, 3, 5}}; // auxiliary x auxiliary - TA::TArrayD x(world, tr_x); - TA::TArrayD z(world, tr_z); - x.fill(1.0); - z.fill(1.0); + auto x = make_patterned_array(world, tr_x, 1.0); + auto z = make_patterned_array(world, tr_z, 2.0); + TA::TArrayD g; - BOOST_CHECK_THROW( - g("p,q,r,s") = x("p,r1") * x("q,r1") * z("r1,r2") * x("r,r2") * x("s,r2"), - TiledArray::Exception); + BOOST_REQUIRE_NO_THROW(g("p,q,r,s") = x("p,r1") * x("q,r1") * z("r1,r2") * + x("r,r2") * x("s,r2")); + + TA::TArrayD i1, i2, i3, g_ref; + i1("r1,p,q") = x("p,r1") * x("q,r1"); + i2("p,q,r2") = i1("r1,p,q") * z("r1,r2"); + i3("r2,p,q,r") = i2("p,q,r2") * x("r,r2"); + g_ref("p,q,r,s") = i3("r2,p,q,r") * x("s,r2"); + + BOOST_CHECK_SMALL(diff_norm(g, g_ref, "p,q,r,s"), 1e-10); } BOOST_AUTO_TEST_CASE(expression_general_product_thc_intermediates) { From f8f4090f54e1a5b98b01c867fc545b4c9f2f3807 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Fri, 12 Jun 2026 10:00:28 -0400 Subject: [PATCH 4/4] expressions: address PR review (Phase E) - GeneralRepermuteOp: store/apply only the OUTER result-layout permutation. The streaming re-permute wrapper exists to reorder a general product's outer (result) layout; inner (within-cell) permutation of ToT results is handled separately (init_struct_general / implicit_permute_inner_). Using the full bipartite perm_ would also permute inner cells -- a no-op when the inner perm is identity, but a latent double-apply if an inner perm is ever deferred to a downstream op. Pass outer(perm_) into the op; the host UnaryEvalImpl still receives full perm_ for ordinal/trange remap. - MultEngine down-pass: materialize each child demand as a named lvalue before preferred_layout(), which returns a reference to its argument for leaf/binary engines -- avoids a needlessly fragile bind-to-temporary. --- src/TiledArray/expressions/cont_engine.h | 9 +++++++-- src/TiledArray/expressions/mult_engine.h | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index d79d3a27db..2db08ea23a 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -851,7 +851,12 @@ class ContEngine : public BinaryEngine { typedef value_type result_type; typedef value_type argument_type; static constexpr bool is_consumable = false; - BipartitePermutation perm; + /// Only the *outer* (result-layout) permutation is applied here; inner + /// (within-cell) permutation of tensor-of-tensor results is handled + /// separately (see init_struct_general / implicit_permute_inner_), so this + /// op stores a plain outer Permutation to avoid accidentally permuting + /// inner contents. + Permutation perm; /// false when the consumer fuses the permutation into its own operation /// (implicit permute, e.g. a transposed GEMM): then only the tile /// ordinals/trange are remapped (by the host UnaryEvalImpl) and the tile @@ -917,7 +922,7 @@ class ContEngine : public BinaryEngine { std::shared_ptr wrapper = std::make_shared( canonical, *world_, trange_, shape_, pmap_, perm_, - GeneralRepermuteOp{perm_, !this->implicit_permute_outer_}); + GeneralRepermuteOp{outer(perm_), !this->implicit_permute_outer_}); return dist_eval_type(wrapper); } diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index 3f74d4c7cd..3e345b15e5 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -328,11 +328,18 @@ class MultEngine : public ContEngine> { // each child dictates the ORDER of its demand (preferred_layout): a // product child reorders to its canonical (h, eA, eB) layout -- a // general product cannot host a result permutation, and contraction - // consumers absorb any child layout via the GEMM transpose forms - BinaryEngine_::left_.init_indices(BinaryEngine_::left_.preferred_layout( - bipartite_demand(avail_l, avail_r, target_indices))); - BinaryEngine_::right_.init_indices(BinaryEngine_::right_.preferred_layout( - bipartite_demand(avail_r, avail_l, target_indices))); + // consumers absorb any child layout via the GEMM transpose forms. + // Materialize the demands as named lvalues: some preferred_layout() + // overloads (leaf/binary/unary) return a reference to their argument, so + // binding that to a temporary demand would be needlessly fragile. + const BipartiteIndexList left_demand = + bipartite_demand(avail_l, avail_r, target_indices); + const BipartiteIndexList right_demand = + bipartite_demand(avail_r, avail_l, target_indices); + BinaryEngine_::left_.init_indices( + BinaryEngine_::left_.preferred_layout(left_demand)); + BinaryEngine_::right_.init_indices( + BinaryEngine_::right_.preferred_layout(right_demand)); } this->product_type_ = compute_product_type(