diff --git a/src/TiledArray/expressions/binary_engine.h b/src/TiledArray/expressions/binary_engine.h index c3d2e69da5..856170803f 100644 --- a/src/TiledArray/expressions/binary_engine.h +++ b/src/TiledArray/expressions/binary_engine.h @@ -258,6 +258,52 @@ class BinaryEngine : public ExprEngine { return demand; } + /// Deduce each child's index SET top-down from this node's target and + /// initialize the children with the deduced demands (the product-engine + /// down-pass). The index set an inner node must produce depends on what + /// its ancestors need: a shared index of the two children is fused iff + /// this node's target carries it, contracted here otherwise; an index of + /// one child that neither the sibling nor the target carries is consumed + /// entirely within that child's subtree and is not demanded of it (a + /// genuinely orphaned index -- an implicit trace, unsupported -- surfaces + /// as the leaf-level target-is-not-a-permutation error). Each child's + /// demand is ordered target-kept-first followed by the contracted-here + /// indices in leaf-availability order, then reordered by the child's own + /// preferred_layout() (a product child reorders to its canonical + /// (fused, left-free, right-free) layout -- a general product cannot host + /// a result permutation, and contraction consumers absorb any child + /// layout via the GEMM transpose forms). 
+ void init_children_indices(const BipartiteIndexList& target_indices) { + auto const avail_l = left_.available_indices(); + auto const avail_r = right_.available_indices(); + auto demand = [](auto const& avail, auto const& sibling, auto const& tgt) { + container::svector r; + for (auto&& idx : tgt) + if (avail.count(idx)) r.push_back(idx); + for (auto&& idx : avail) { + if (tgt.count(idx)) continue; + if (sibling.count(idx)) r.push_back(idx); + } + return r; + }; + auto bipartite_demand = [&demand](auto const& avail, auto const& sib, + auto const& tgt) { + auto const out = demand(outer(avail), outer(sib), outer(tgt)); + auto const in = demand(inner(avail), inner(sib), inner(tgt)); + return BipartiteIndexList(IndexList(out.begin(), out.end()), + IndexList(in.begin(), in.end())); + }; + // Materialize the demands as named lvalues: some preferred_layout() + // overloads (leaf/binary/unary) return a reference to their argument, so + // binding that to a temporary demand would be needlessly fragile. 
+ const BipartiteIndexList left_demand = + bipartite_demand(avail_l, avail_r, target_indices); + const BipartiteIndexList right_demand = + bipartite_demand(avail_r, avail_l, target_indices); + left_.init_indices(left_.preferred_layout(left_demand)); + right_.init_indices(right_.preferred_layout(right_demand)); + } + /// Initialize result tensor structure /// This function will initialize the permutation, tiled range, and shape diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 2db08ea23a..fdf5b0da82 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -674,19 +674,37 @@ class ContEngine : public BinaryEngine { this->init_perm(target_indices); general_repermute_ = (outer(target_indices) != outer(indices_)); - // the tile op operates on the folded (fused-mode-free) shapes - const auto left_op = to_cblas_op(left_outer_permtype_); + // A product with NO external (free) outer indices (every outer index + // fused or contracted, e.g. C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b")) + // folds to a GEMM with no free modes, i.e. rank-0 tensors, which the + // tile kernels do not support. Evaluate it with a SYNTHETIC UNIT + // left-external mode instead: the folded product becomes + // (1,K) x (K) -> (1), the exact shape of the (supported) one-sided + // neB == 0 case. The unit mode lives only in the tile op's GemmHelper; + // tranges, shapes and tiles carry the true (external-free) ranks, and + // BatchedContractReduce / SparseShape::gemm_batched detect the + // synthetic mode from the one-rank mismatch and pad their folded views + // with a unit extent. + const unsigned int u = (outer_size(indices_) == nh) ? 1u : 0u; + + // the tile op operates on the folded (fused-mode-free) shapes; the + // synthetic unit mode leads the folded left operand, so it is NoTrans + const auto left_op = + u ? 
math::blas::NoTranspose : to_cblas_op(left_outer_permtype_); const auto right_op = to_cblas_op(right_outer_permtype_); if constexpr (!TiledArray::detail::is_tensor_of_tensor_v) { - op_ = op_type(left_op, right_op, factor_, outer_size(indices_) - nh, - outer_size(left_indices_) - nh, + op_ = op_type(left_op, right_op, factor_, outer_size(indices_) - nh + u, + outer_size(left_indices_) - nh + u, outer_size(right_indices_) - nh); } else { // the batched tile op must be perm-free (BatchedContractReduce cannot - // host the folded-rank result permutation); the outer perm is empty by - // the interleaved-target gate above, so only an explicit inner result - // permutation can require one - if (!implicit_permute_inner_ && bool(inner(perm_))) + // host the folded-rank result permutation); the outer perm is handled + // by the streaming re-permute (general_repermute_), so only a genuine + // (non-identity) explicit inner result permutation requires one. N.B. + // perm_ may carry a non-null identity inner component when only the + // outer modes are permuted (the bipartite perm is constructed whole). 
+ if (!implicit_permute_inner_ && bool(inner(perm_)) && + !inner(perm_).is_identity()) TA_EXCEPTION( "general products of tensors-of-tensors: a non-identity inner " "result permutation is not yet supported; reorder the inner " @@ -694,7 +712,8 @@ class ContEngine : public BinaryEngine { // factor_ is absorbed into element_nonreturn_op_ op_ = op_type(left_op, right_op, scalar_type(1), - outer_size(indices_) - nh, outer_size(left_indices_) - nh, + outer_size(indices_) - nh + u, + outer_size(left_indices_) - nh + u, outer_size(right_indices_) - nh, BipartitePermutation{}, this->element_nonreturn_op_, std::move(this->arena_plan_)); // ce+e, ce+ce_right and ce+ce_left are mutually exclusive; at most one @@ -734,7 +753,11 @@ class ContEngine : public BinaryEngine { trange_type make_trange_general() const { const unsigned int nh = n_fused_modes_; const unsigned int nc = op_.gemm_helper().num_contract_ranks(); - const unsigned int neA = op_.gemm_helper().left_rank() - nc; + // the no-external case carries a synthetic unit left-external mode in + // the GemmHelper only (see init_struct_general); the actual tranges do + // not have it + const unsigned int u = (outer_size(indices_) == n_fused_modes_) ? 1u : 0u; + const unsigned int neA = op_.gemm_helper().left_rank() - nc - u; const unsigned int neB = op_.gemm_helper().right_rank() - nc; typename trange_type::Ranges ranges(nh + neA + neB); @@ -786,7 +809,11 @@ class ContEngine : public BinaryEngine { std::shared_ptr pmap) { const unsigned int nh = n_fused_modes_; const unsigned int nc = op_.gemm_helper().num_contract_ranks(); - const unsigned int neA = op_.gemm_helper().left_rank() - nc; + // the no-external case carries a synthetic unit left-external mode in + // the GemmHelper only (see init_struct_general); the actual tranges do + // not have it + const unsigned int u = (outer_size(indices_) == nh) ? 
1u : 0u; + const unsigned int neA = op_.gemm_helper().left_rank() - nc - u; const unsigned int neB = op_.gemm_helper().right_rank() - nc; // Get pointers to the argument sizes @@ -1707,7 +1734,11 @@ class ContEngine : public BinaryEngine { TiledArray::detail::is_contraction_arena_tot_v< result_tile_type, left_tile_type, right_tile_type>; if constexpr (arena_eligible_scale) { - if (this->outer_product_uses_summa()) { + // the fused arena scale ops are factor-free; a non-unit + // expression-level prefactor (ScalMult) takes the fallback op, + // which absorbs it + if (this->outer_product_uses_summa() && + this->factor_ == scalar_type(1)) { // The inner perm handed to the plan must match how the inner // *result* permutation is applied for this result cell type -- // and the two cell types apply it in different places: @@ -1741,45 +1772,48 @@ class ContEngine : public BinaryEngine { // cells. The Hadamard outer product is an assignment // `result = (perm ^ tot) * scalar`, which needs value-returning // `scale`; only owning inner cells support it. - auto fallback_op = [perm = !this->implicit_permute_inner_ - ? inner(this->perm_) - : Permutation{}, - outer_uses_summa = - this->outer_product_uses_summa()]( - result_tile_element_type& result, - const left_tile_element_type& left, - const right_tile_element_type& right) { - if (outer_uses_summa) { - using TiledArray::axpy_to; - if constexpr (tot_x_t) { - if (left.empty()) return; // absent cell: no contribution - if (perm) - axpy_to(result, left, right, perm); - else - axpy_to(result, left, right); - } else { - if (right.empty()) return; // absent cell: no contribution - if (perm) - axpy_to(result, right, left, perm); - else - axpy_to(result, right, left); - } - } else { - if constexpr (!TiledArray::is_tensor_view_v< - result_tile_element_type>) { - using TiledArray::scale; - if constexpr (tot_x_t) - result = perm ? scale(left, right, perm) : scale(left, right); - else - result = perm ? 
scale(right, left, perm) : scale(right, left); - } else { - TA_EXCEPTION( - "Tensor scale-inner Hadamard-outer product: a " - "view result cell cannot be value-assigned a fresh " - "scaled tensor"); - } - } - }; + // N.B. the expression-level scalar prefactor (factor_, != 1 for + // ScalMult expressions) multiplies the plain operand's element + auto fallback_op = + [perm = !this->implicit_permute_inner_ ? inner(this->perm_) + : Permutation{}, + outer_uses_summa = this->outer_product_uses_summa(), + factor = this->factor_](result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + if (outer_uses_summa) { + using TiledArray::axpy_to; + if constexpr (tot_x_t) { + if (left.empty()) return; // absent cell: no contribution + if (perm) + axpy_to(result, left, right * factor, perm); + else + axpy_to(result, left, right * factor); + } else { + if (right.empty()) return; // absent cell: no contribution + if (perm) + axpy_to(result, right, left * factor, perm); + else + axpy_to(result, right, left * factor); + } + } else { + if constexpr (!TiledArray::is_tensor_view_v< + result_tile_element_type>) { + using TiledArray::scale; + if constexpr (tot_x_t) + result = perm ? scale(left, right * factor, perm) + : scale(left, right * factor); + else + result = perm ? 
scale(right, left * factor, perm) + : scale(right, left * factor); + } else { + TA_EXCEPTION( + "Tensor scale-inner Hadamard-outer product: a " + "view result cell cannot be value-assigned a fresh " + "scaled tensor"); + } + } + }; if constexpr (arena_eligible_scale) { if (this->arena_plan_) { if constexpr (tot_x_t) diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index 3e345b15e5..1ad68119c5 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -299,48 +299,7 @@ class MultEngine : public ContEngine> { // take the no-target init_indices() overload below, which retains the // bottom-up contraction convention -- general products under // reductions remain unsupported. - { - auto const avail_l = BinaryEngine_::left_.available_indices(); - auto const avail_r = BinaryEngine_::right_.available_indices(); - auto demand = [](auto const& avail, auto const& sibling, - auto const& tgt) { - container::svector r; - for (auto&& idx : tgt) - if (avail.count(idx)) r.push_back(idx); - for (auto&& idx : avail) { - if (tgt.count(idx)) continue; - if (sibling.count(idx)) r.push_back(idx); - // else: the index is consumed entirely within the child's own - // subtree (contracted deeper down) -- not demanded here. A - // genuinely orphaned index (single occurrence, demanded nowhere = - // implicit trace, unsupported) surfaces as the leaf-level - // target-is-not-a-permutation error. 
- } - return r; - }; - auto bipartite_demand = [&demand](auto const& avail, auto const& sib, - auto const& tgt) { - auto const out = demand(outer(avail), outer(sib), outer(tgt)); - auto const in = demand(inner(avail), inner(sib), inner(tgt)); - return BipartiteIndexList(IndexList(out.begin(), out.end()), - IndexList(in.begin(), in.end())); - }; - // each child dictates the ORDER of its demand (preferred_layout): a - // product child reorders to its canonical (h, eA, eB) layout -- a - // general product cannot host a result permutation, and contraction - // consumers absorb any child layout via the GEMM transpose forms. - // Materialize the demands as named lvalues: some preferred_layout() - // overloads (leaf/binary/unary) return a reference to their argument, so - // binding that to a temporary demand would be needlessly fragile. - const BipartiteIndexList left_demand = - bipartite_demand(avail_l, avail_r, target_indices); - const BipartiteIndexList right_demand = - bipartite_demand(avail_r, avail_l, target_indices); - BinaryEngine_::left_.init_indices( - BinaryEngine_::left_.preferred_layout(left_demand)); - BinaryEngine_::right_.init_indices( - BinaryEngine_::right_.preferred_layout(right_demand)); - } + BinaryEngine_::init_children_indices(target_indices); this->product_type_ = compute_product_type( outer(BinaryEngine_::left_.indices()), @@ -713,29 +672,34 @@ class ScalMultEngine /// \param target_indices The target index list for this expression void init_indices(const BipartiteIndexList& target_indices) { - BinaryEngine_::left_.init_indices(); - BinaryEngine_::right_.init_indices(); + // deduce the children's index sets top-down (see + // BinaryEngine::init_children_indices), then classify and route exactly + // as MultEngine does (the scalar factor does not affect index roles) + BinaryEngine_::init_children_indices(target_indices); + this->product_type_ = compute_product_type( outer(BinaryEngine_::left_.indices()), outer(BinaryEngine_::right_.indices()), 
outer(target_indices)); + this->inner_product_type_ = compute_product_type( + inner(BinaryEngine_::left_.indices()), + inner(BinaryEngine_::right_.indices()), inner(target_indices)); + + if (this->inner_product_type_ == TensorProduct::General) + TA_EXCEPTION( + "ScalMultEngine: general products (fused + contracted + free " + "indices) between the inner (nested) indices of tensors-of-tensors " + "are not supported"); if (this->product_type() == TensorProduct::Hadamard) { // since already initialized left and right arg indices assign the target // indices BinaryEngine_::perm_indices(target_indices); } else if (this->product_type() == TensorProduct::General) { - // layout via GeneralPermutationOptimizer (the target determines which - // shared indices are fused vs contracted), then propagate to children - if (!this->implicit_permute()) { - BinaryEngine_::template init_indices_( - target_indices); - if (BinaryEngine_::left_indices_ != BinaryEngine_::left_.indices()) - BinaryEngine_::left_.perm_indices(BinaryEngine_::left_indices_); - if (BinaryEngine_::right_indices_ != BinaryEngine_::right_.indices()) - BinaryEngine_::right_.perm_indices(BinaryEngine_::right_indices_); - } + this->perm_indices(target_indices); } else { - ContEngine_::init_indices(target_indices); + auto children_initialized = true; + ContEngine_::init_indices(children_initialized); + ContEngine_::perm_indices(target_indices); } } @@ -767,12 +731,14 @@ class ScalMultEngine /// for the result tensor. 
/// \param target_indices The target index list for the result tensor void init_struct(const BipartiteIndexList& target_indices) { - // TODO Phase B (batched Summa): evaluate general products natively - if (this->product_type() == TensorProduct::General) - TA_EXCEPTION( - "ScalMultEngine: evaluation of general products (fused + contracted " - "+ free indices) via the expression layer is not yet implemented; " - "use TiledArray::einsum() instead"); + if (this->product_type() == TensorProduct::General) { + // the inner tile op (for tensors-of-tensors) must be initialized + // first; init_struct_general consumes element_nonreturn_op_ and the + // arena plan it builds + this->init_inner_tile_op(inner(target_indices)); + ContEngine_::init_struct_general(target_indices); + return; + } this->init_perm(target_indices); @@ -794,6 +760,8 @@ class ScalMultEngine std::shared_ptr pmap) { if (this->product_type() == TensorProduct::Contraction) ContEngine_::init_distribution(world, pmap); + else if (this->product_type() == TensorProduct::General) + ContEngine_::init_distribution_general(world, pmap); else BinaryEngine_::init_distribution(world, pmap); } @@ -804,6 +772,8 @@ class ScalMultEngine dist_eval_type make_dist_eval() const { if (this->product_type() == TensorProduct::Contraction) return ContEngine_::make_dist_eval(); + else if (this->product_type() == TensorProduct::General) + return ContEngine_::make_dist_eval_general(); else return BinaryEngine_::make_dist_eval(); } @@ -814,6 +784,8 @@ class ScalMultEngine trange_type make_trange() const { if (this->product_type() == TensorProduct::Contraction) return ContEngine_::make_trange(); + else if (this->product_type() == TensorProduct::General) + return ContEngine_::make_trange_general(); else return BinaryEngine_::make_trange(); } diff --git a/src/TiledArray/sparse_shape.h b/src/TiledArray/sparse_shape.h index 7af41176c7..5ab12a3dee 100644 --- a/src/TiledArray/sparse_shape.h +++ b/src/TiledArray/sparse_shape.h @@ -1720,10 
+1720,19 @@ class SparseShape { const auto* left_extent = tile_norms_.range().extent_data(); const auto* right_extent = other.tile_norms_.range().extent_data(); + // a no-external product carries a SYNTHETIC unit left-external mode in + // the GemmHelper only (see ContEngine::init_struct_general); detect it + // from the one-rank mismatch with the actual norm tensor and pad the + // folded left/result views with a unit extent + const bool unit_external = + (tile_norms_.range().rank() + 1u == nfused + gemm_helper.left_rank()); + const unsigned int u = unit_external ? 1u : 0u; + // check that the ranks match the folded gemm ranks plus the fused modes, // and that the fused and contracted mode extents of the two shapes are // congruent - TA_ASSERT(tile_norms_.range().rank() == nfused + gemm_helper.left_rank()); + TA_ASSERT(tile_norms_.range().rank() + u == + nfused + gemm_helper.left_rank()); TA_ASSERT(other.tile_norms_.range().rank() == nfused + gemm_helper.right_rank()); for (unsigned int d = 0u; d < nfused; ++d) @@ -1731,31 +1740,34 @@ class SparseShape { for (unsigned int i = gemm_helper.left_inner_begin(), j = gemm_helper.right_inner_begin(); i < gemm_helper.left_inner_end(); ++i, ++j) - TA_ASSERT(left_extent[nfused + i] == right_extent[nfused + j]); + TA_ASSERT(left_extent[nfused + i - u] == right_extent[nfused + j]); integer H = 1, M = 1, N = 1, K = 1; for (unsigned int d = 0u; d < nfused; ++d) H *= left_extent[d]; - for (unsigned int i = gemm_helper.left_outer_begin(); - i < gemm_helper.left_outer_end(); ++i) - M *= left_extent[nfused + i]; + if (!unit_external) + for (unsigned int i = gemm_helper.left_outer_begin(); + i < gemm_helper.left_outer_end(); ++i) + M *= left_extent[nfused + i]; for (unsigned int i = gemm_helper.left_inner_begin(); i < gemm_helper.left_inner_end(); ++i) - K *= left_extent[nfused + i]; + K *= left_extent[nfused + i - u]; for (unsigned int i = gemm_helper.right_outer_begin(); i < gemm_helper.right_outer_end(); ++i) N *= 
right_extent[nfused + i]; // result size vectors: fused modes (from this), then the left and right - // outer modes - const unsigned int result_rank = nfused + gemm_helper.result_rank(); + // outer modes (the synthetic unit left-external mode is absent from the + // actual result) + const unsigned int result_rank = nfused + gemm_helper.result_rank() - u; std::shared_ptr result_size_vectors( new vector_type[result_rank], std::default_delete()); unsigned int x = 0ul; for (unsigned int i = 0u; i < nfused; ++i, ++x) result_size_vectors.get()[x] = size_vectors_.get()[i]; - for (unsigned int i = gemm_helper.left_outer_begin(); - i < gemm_helper.left_outer_end(); ++i, ++x) - result_size_vectors.get()[x] = size_vectors_.get()[nfused + i]; + if (!unit_external) + for (unsigned int i = gemm_helper.left_outer_begin(); + i < gemm_helper.left_outer_end(); ++i, ++x) + result_size_vectors.get()[x] = size_vectors_.get()[nfused + i]; for (unsigned int i = gemm_helper.right_outer_begin(); i < gemm_helper.right_outer_end(); ++i, ++x) result_size_vectors.get()[x] = other.size_vectors_.get()[nfused + i]; @@ -1770,11 +1782,12 @@ class SparseShape { lobounds.push_back(tile_norms_.range().lobound_data()[d]); upbounds.push_back(tile_norms_.range().upbound_data()[d]); } - for (unsigned int i = gemm_helper.left_outer_begin(); - i < gemm_helper.left_outer_end(); ++i) { - lobounds.push_back(tile_norms_.range().lobound_data()[nfused + i]); - upbounds.push_back(tile_norms_.range().upbound_data()[nfused + i]); - } + if (!unit_external) + for (unsigned int i = gemm_helper.left_outer_begin(); + i < gemm_helper.left_outer_end(); ++i) { + lobounds.push_back(tile_norms_.range().lobound_data()[nfused + i]); + upbounds.push_back(tile_norms_.range().upbound_data()[nfused + i]); + } for (unsigned int i = gemm_helper.right_outer_begin(); i < gemm_helper.right_outer_end(); ++i) { lobounds.push_back(other.tile_norms_.range().lobound_data()[nfused + i]); @@ -1784,10 +1797,13 @@ class SparseShape { // the 
range spanned by modes [nfused, rank) of \p r, rebased to zero // lobounds (scratch view for the slab-batched norm GEMM) - auto fold_range = [nfused](const range_type& r) { + auto fold_range = [nfused](const range_type& r, + const bool prepend_unit = false) { const auto* extent = r.extent_data(); - container::svector extents(extent + nfused, - extent + r.rank()); + container::svector extents; + extents.reserve(r.rank() - nfused + (prepend_unit ? 1u : 0u)); + if (prepend_unit) extents.push_back(1); + extents.insert(extents.end(), extent + nfused, extent + r.rank()); return range_type(extents); }; @@ -1797,10 +1813,11 @@ class SparseShape { if (k_rank > 0u) { // the contracted-mode size vector; identical for every slab since the - // contracted modes follow the fused modes + // contracted modes follow the fused modes (helper coordinates carry + // the synthetic unit mode, the actual size vectors do not) const vector_type k_sizes = recursive_outer_product( - size_vectors_.get() + nfused + gemm_helper.left_inner_begin(), k_rank, - [](const vector_type& size_vector) -> const vector_type& { + size_vectors_.get() + nfused + gemm_helper.left_inner_begin() - u, + k_rank, [](const vector_type& size_vector) -> const vector_type& { return size_vector; }); @@ -1828,10 +1845,11 @@ class SparseShape { // slab-batched norm GEMM: fold the fused modes into the tensor batch // dimension by zero-copy reshape; result_folded shares result_norms' // buffer, so the accumulation lands in place - auto left_folded = left.reshape(fold_range(left.range()), H); + auto left_folded = + left.reshape(fold_range(left.range(), unit_external), H); auto right_folded = right.reshape(fold_range(right.range()), H); - auto result_folded = - result_norms.reshape(fold_range(result_norms.range()), H); + auto result_folded = result_norms.reshape( + fold_range(result_norms.range(), unit_external), H); result_folded.gemm(left_folded, right_folded, abs_factor, gemm_helper); // Hard zero tiles that are below the 
zero threshold. diff --git a/src/TiledArray/tile_op/batched_contract_reduce.h b/src/TiledArray/tile_op/batched_contract_reduce.h index 570a7398c8..48b028e267 100644 --- a/src/TiledArray/tile_op/batched_contract_reduce.h +++ b/src/TiledArray/tile_op/batched_contract_reduce.h @@ -62,12 +62,16 @@ class BatchedContractReduce { /// \return the range spanned by modes [nfused_, rank) of \p r, rebased to /// zero lobounds (the folded view is a GEMM scratch view; only extents - /// matter) + /// matter); \p prepend_unit prepends a unit extent (the synthetic + /// left-external mode of a no-external product, see + /// ContEngine::init_struct_general) template - Range_ fold_range(const Range_& r) const { + Range_ fold_range(const Range_& r, const bool prepend_unit = false) const { const auto* extent = r.extent_data(); - container::svector extents(extent + nfused_, - extent + r.rank()); + container::svector extents; + extents.reserve(r.rank() - nfused_ + (prepend_unit ? 1u : 0u)); + if (prepend_unit) extents.push_back(1); + extents.insert(extents.end(), extent + nfused_, extent + r.rank()); return Range_(extents); } @@ -150,7 +154,13 @@ class BatchedContractReduce { const auto& gh = op_.gemm_helper(); const unsigned int nc = gh.num_contract_ranks(); - const unsigned int neA = gh.left_rank() - nc; + // a no-external product carries a SYNTHETIC unit left-external mode in + // the GemmHelper only (see ContEngine::init_struct_general); detect it + // from the one-rank mismatch with the actual left tile and pad the + // folded left/result views with a unit extent + const bool unit_external = + (left.range().rank() + 1u == nfused_ + gh.left_rank()); + const unsigned int neA = gh.left_rank() - nc - (unit_external ? 
1u : 0u); const unsigned int neB = gh.right_rank() - nc; // both args must carry the fused modes as their leading modes, with @@ -163,7 +173,8 @@ class BatchedContractReduce { TA_ASSERT(batch == fused_volume(right.range())); // folded, zero-copy argument views - auto left_folded = left.reshape(fold_range(left.range()), batch); + auto left_folded = + left.reshape(fold_range(left.range(), unit_external), batch); auto right_folded = right.reshape(fold_range(right.range()), batch); if (empty(result)) { @@ -195,7 +206,8 @@ class BatchedContractReduce { } else { // accumulate through a folded, zero-copy view of the result const auto full_range = result.range(); - auto result_folded = result.reshape(fold_range(full_range), batch); + auto result_folded = + result.reshape(fold_range(full_range, unit_external), batch); op_(result_folded, left_folded, right_folded); // the wrapped op may REBIND the result instead of writing in place: // the arena grow-to-cover path (a later K-panel touching inner cells diff --git a/tests/general_product.cpp b/tests/general_product.cpp index 8ce13debe5..58976f87dc 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -528,6 +528,269 @@ BOOST_AUTO_TEST_CASE(expression_general_product_t_times_tot) { BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); } +BOOST_AUTO_TEST_CASE(expression_mixed_t_tot_depth2_chains) { + // mixed T/ToT products at INNER nodes of the expression tree: + // left-deep: w("i,k;x") = (s("i,j") * t("j,m")) * c("m,k;x") + // right-deep: w("i,k;x") = s("i,j") * (t("j,m") * c("m,k;x")) + // (plain contraction at every node, inner Scale where a plain factor + // meets the nested one) + auto& world = TA::get_default_world(); + TA::TiledRange tr_s{{0, 2, 4}, {0, 2, 5}}; // i, j + TA::TiledRange tr_t{{0, 2, 5}, {0, 3, 4}}; // j, m + TA::TiledRange tr_c{{0, 3, 4}, {0, 2, 3}}; // m, k + auto s = make_patterned_array(world, tr_s, 1.0); + auto t = make_patterned_array(world, tr_t, 2.0); + auto c = 
make_patterned_tot_array(world, tr_c, {3}, 3.0); + + // staged reference + TArrayToT i1, w_ref; + i1("j,k;x") = t("j,m") * c("m,k;x"); + w_ref("i,k;x") = s("i,j") * i1("j,k;x"); + + TArrayToT w_l, w_r; + BOOST_REQUIRE_NO_THROW(w_l("i,k;x") = (s("i,j") * t("j,m")) * c("m,k;x")); + BOOST_CHECK_SMALL(tot_max_abs_diff(w_l, w_ref), 1e-10); + BOOST_REQUIRE_NO_THROW(w_r("i,k;x") = s("i,j") * (t("j,m") * c("m,k;x"))); + BOOST_CHECK_SMALL(tot_max_abs_diff(w_r, w_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_mixed_t_tot_inner_general) { + // a mixed T x ToT GENERAL product at an INNER node: + // w("i,j;x") = (g("b,i") * c("b,j;x")) * h("b") + // b is fused where g meets c (demanded by the h factor above) and + // contracted at the root + auto& world = TA::get_default_world(); + TA::TiledRange tr_g{{0, 3, 5}, {0, 2, 4}}; // b, i + TA::TiledRange tr_c{{0, 3, 5}, {0, 2, 3}}; // b, j + TA::TiledRange tr_h{{0, 3, 5}}; // b + auto g = make_patterned_array(world, tr_g, 1.0); + auto c = make_patterned_tot_array(world, tr_c, {3}, 2.0); + TA::TArrayD h(world, tr_h); + h.fill(1.5); + + TArrayToT i1, w_ref; + i1("b,i,j;x") = g("b,i") * c("b,j;x"); // depth-1 mixed general + w_ref("i,j;x") = i1("b,i,j;x") * h("b"); + + TArrayToT w; + try { + w("i,j;x") = (g("b,i") * c("b,j;x")) * h("b"); + } catch (std::exception& e) { + BOOST_FAIL(std::string("threw: ") + e.what()); + } + BOOST_CHECK_SMALL(tot_max_abs_diff(w, w_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_mixed_t_tot_scaled) { + // scaled mixed T x ToT general product (ScalMultEngine path): + // w("b,i,k;x") = 2 * (a("b,i,j") * c("b,j,k;x")) + auto& world = TA::get_default_world(); + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j + TA::TiledRange tr_c{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k + auto a = make_patterned_array(world, tr_a, 1.0); + auto c = make_patterned_tot_array(world, tr_c, {3}, 2.0); + + TArrayToT w_ref0, w_ref, w; + w_ref0("b,i,k;x") = a("b,i,j") * c("b,j,k;x"); + w_ref("b,i,k;x") = 
2.0 * w_ref0("b,i,k;x"); + BOOST_REQUIRE_NO_THROW(w("b,i,k;x") = 2.0 * (a("b,i,j") * c("b,j,k;x"))); + BOOST_CHECK_SMALL(tot_max_abs_diff(w, w_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_block_operands) { + // general product of BLOCK views: C("b,i,k") = A.block * B.block with + // b fused, j contracted, i/k free; the blocks restrict b and i/k + auto& world = TA::get_default_world(); + TA::TiledRange tr{{0, 2, 4}, {0, 2, 4}, {0, 2, 4}}; // (b,i,j) / (b,j,k) + auto a = make_patterned_array(world, tr, 1.0); + auto t = make_patterned_array(world, tr, 2.0); + + // materialized blocks as the reference operands + TA::TArrayD ab, tb, w_ref, w; + ab("b,i,j") = a("b,i,j").block({0, 0, 0}, {1, 2, 2}); + tb("b,j,k") = t("b,j,k").block({0, 0, 0}, {1, 2, 1}); + w_ref("b,i,k") = ab("b,i,j") * tb("b,j,k"); + + // expression under test: the block views are multiplied directly, without + // first materializing them + BOOST_REQUIRE_NO_THROW(w("b,i,k") = a("b,i,j").block({0, 0, 0}, {1, 2, 2}) * + t("b,j,k").block({0, 0, 0}, {1, 2, 1})); + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "b,i,k"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_into_block) { + // general product assigned INTO a block view of the result: + // W.block = A * B (b fused, j contracted) + auto& world = TA::get_default_world(); + TA::TiledRange tr_a{{0, 2}, {0, 2}, {0, 2, 4}}; // b, i, j + TA::TiledRange tr_b{{0, 2}, {0, 2, 4}, {0, 2}}; // b, j, k + TA::TiledRange tr_w{{0, 2, 4}, {0, 2, 4}, {0, 2, 4}}; + auto a = make_patterned_array(world, tr_a, 1.0); + auto t = make_patterned_array(world, tr_b, 2.0); + + // reference: evaluate the product into a standalone array, then assign it + // into the same block of a zero-filled array + TA::TArrayD prod; + prod("b,i,k") = a("b,i,j") * t("b,j,k"); + + TA::TArrayD w(world, tr_w), w_ref(world, tr_w); + w.fill(0.0); + w_ref.fill(0.0); + w_ref("b,i,k").block({0, 0, 0}, {1, 1, 1}) = prod("b,i,k"); + + BOOST_REQUIRE_NO_THROW(w("b,i,k").block({0, 0, 0}, {1, 1, 1}) = + a("b,i,j") * t("b,j,k")); + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "b,i,k"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_block_in_tree) { + // a BLOCK leaf under an inner general node of a 
deeper tree: + // w("p,q,r2") = (x.block("p,r1") * y("q,r1")) * z("r1,r2") + // r1 is fused at the inner (general) node, contracted at the root + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // p x r1 + TA::TiledRange tr_z{{0, 3, 5}, {0, 3, 5}}; // r1 x r2 + auto x = make_patterned_array(world, tr_x, 1.0); + auto y = make_patterned_array(world, tr_x, 1.5); + auto z = make_patterned_array(world, tr_z, 2.0); + + // reference: materialize the block, then evaluate the two products one at a + // time through the intermediate i1 + TA::TArrayD xb, i1, w_ref, w; + xb("p,r1") = x("p,r1").block({0, 0}, {1, 2}); + i1("r1,p,q") = xb("p,r1") * y("q,r1"); + w_ref("p,q,r2") = i1("r1,p,q") * z("r1,r2"); + + BOOST_REQUIRE_NO_THROW( + w("p,q,r2") = (x("p,r1").block({0, 0}, {1, 2}) * y("q,r1")) * z("r1,r2")); + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "p,q,r2"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_repermute_into_block) { + // a general product with a NON-canonical target layout (streaming + // re-permute) assigned INTO a block view of the result + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2}, {0, 3, 5}}; // p x r1 + TA::TiledRange tr_w{{0, 2, 4}, {0, 2, 4}, {0, 3, 5}}; // p, q, r1 + auto x = make_patterned_array(world, tr_x, 1.0); + auto y = make_patterned_array(world, tr_x, 1.5); + + // w and w_ref are both zero-filled before their block is assigned + // (presumably so that elements outside the block compare equal) + TA::TArrayD i1, w(world, tr_w), w_ref(world, tr_w); + w.fill(0.0); + w_ref.fill(0.0); + i1("r1,p,q") = x("p,r1") * y("q,r1"); // canonical evaluation + w_ref("p,q,r1").block({0, 0, 0}, {1, 1, 2}) = i1("r1,p,q"); + + BOOST_REQUIRE_NO_THROW(w("p,q,r1").block({0, 0, 0}, {1, 1, 2}) = + x("p,r1") * y("q,r1")); + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "p,q,r1"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_sum_under_product) { + // a SUM nested under a product, with a general product as one summand: + // F("i,j") = A("x,i") * (B("x,k") * C("x,k,j") + D("x,j")) + // x is fused inside B*C (demanded by the sum's consumer), k is contracted + // within the summand (never escapes), and x is contracted at the root. 
+ // The down-pass prunes k from the sum's demand automatically (it appears + // in neither the target nor A) and hands (j,x) to BOTH summands. + auto& world = TA::get_default_world(); + TA::TiledRange tr_a{{0, 3, 5}, {0, 2, 4}}; // x, i + TA::TiledRange tr_b{{0, 3, 5}, {0, 2, 3}}; // x, k + TA::TiledRange tr_c{{0, 3, 5}, {0, 2, 3}, {0, 2, 4}}; // x, k, j + TA::TiledRange tr_d{{0, 3, 5}, {0, 2, 4}}; // x, j + auto a = make_patterned_array(world, tr_a, 1.0); + auto b = make_patterned_array(world, tr_b, 2.0); + auto c = make_patterned_array(world, tr_c, 3.0); + auto d = make_patterned_array(world, tr_d, 4.0); + + // reference: evaluate the product summand, the sum, and the root + // contraction as three separate assignments + TA::TArrayD i1, s, f_ref, f; + i1("x,j") = b("x,k") * c("x,k,j"); // general: x fused, k contracted + s("x,j") = i1("x,j") + d("x,j"); + f_ref("i,j") = a("x,i") * s("x,j"); + + BOOST_REQUIRE_NO_THROW(f("i,j") = + a("x,i") * (b("x,k") * c("x,k,j") + d("x,j"))); + BOOST_CHECK_SMALL(diff_norm(f, f_ref, "i,j"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_kitchen_sink) { + // one expression combining a THC-like batching index (x: fused at the + // inner node, contracted at the root), a mixed T x ToT general product, + // a ToT x ToT general product with an inner outer-product, and a scalar + // prefactor (ScalMult at the root): + // W("i,j,m;a,b") = 2 * ((g("x,i") * cv("x,j;a")) * dv("x,i,m;b")) + auto& world = TA::get_default_world(); + TA::TiledRange tr_g{{0, 3, 5}, {0, 2, 4}}; // x, i + TA::TiledRange tr_cv{{0, 3, 5}, {0, 2, 3}}; // x, j + TA::TiledRange tr_dv{{0, 3, 5}, {0, 2, 4}, {0, 2, 3}}; // x, i, m + auto g = make_patterned_array(world, tr_g, 1.0); + auto cv = make_patterned_tot_array(world, tr_cv, {2}, 2.0); + auto dv = make_patterned_tot_array(world, tr_dv, {3}, 3.0); + + // reference: the same computation staged through the intermediates i1, i2 + TArrayToT i1, i2, w_ref, w; + i1("x,i,j;a") = g("x,i") * cv("x,j;a"); // T x ToT general + i2("i,j,m;a,b") = i1("x,i,j;a") * dv("x,i,m;b"); // ToT x ToT, inner outer + w_ref("i,j,m;a,b") = 2.0 * i2("i,j,m;a,b"); + + BOOST_REQUIRE_NO_THROW(w("i,j,m;a,b") = + 2.0 * ((g("x,i") 
* cv("x,j;a")) * dv("x,i,m;b"))); + BOOST_CHECK_SMALL(tot_max_abs_diff(w, w_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_tot_no_externals) { + // a ToT x ToT general product with NO external (free) outer indices -- + // every outer index fused or contracted. The folded product has no free + // modes, so it is evaluated with a synthetic unit left-external mode + // carried by the tile op's GemmHelper only (this shape used to segfault). + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr{{0, 3, 5}, {0, 2, 4}, {0, 2, 3}}; // x, i, j + auto a = make_patterned_tot_array(world, tr, {2}, 1.0); + auto b = make_patterned_tot_array(world, tr, {3}, 2.0); + + TArrayToT c; + BOOST_REQUIRE_NO_THROW(c("i,j;a,b") = a("x,i,j;a") * b("x,i,j;b")); + // reference: TA::einsum on the same operands and target annotation + auto c_ref = TA::einsum(a("x,i,j;a"), b("x,i,j;b"), "i,j;a,b"); + BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); + + // the same no-external root with the left operand produced by a general + // T x ToT product at an INNER node (the original motivating expression) + TA::TiledRange tr_g{{0, 3, 5}, {0, 2, 4}}; // x, i + TA::TiledRange tr_cv{{0, 3, 5}, {0, 2, 3}}; // x, j + auto g = make_patterned_array(world, tr_g, 1.0); + auto cv = make_patterned_tot_array(world, tr_cv, {2}, 2.0); + // reference: stage the inner T x ToT product through the intermediate i1 + TArrayToT i1, w, w_ref; + i1("x,i,j;a") = g("x,i") * cv("x,j;a"); + w_ref("i,j;a,b") = i1("x,i,j;a") * b("x,i,j;b"); + BOOST_REQUIRE_NO_THROW(w("i,j;a,b") = + (g("x,i") * cv("x,j;a")) * b("x,i,j;b")); + BOOST_CHECK_SMALL(tot_max_abs_diff(w, w_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_no_externals) { + // the plain-tensor analogue of the no-external general product (the + // Hadamard-reduction shape): C("i") = A("i,j") * B("i,j") + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr{{0, 2, 4}, {0, 3, 5}}; // i, j + auto a = 
make_patterned_array(world, tr, 1.0); + auto b = make_patterned_array(world, tr, 2.0); + + TA::TArrayD c; + BOOST_REQUIRE_NO_THROW(c("i") = a("i,j") * b("i,j")); + auto c_ref = TA::einsum(a("i,j"), b("i,j"), "i"); + BOOST_CHECK_SMALL(diff_norm(c, c_ref, "i"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_sparse_no_externals) { + // block-sparse no-external general product: exercises the synthetic + // unit-mode handling in SparseShape::gemm_batched + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr{{0, 2, 5}, {0, 3, 4}, {0, 2, 6, 7}}; // b, i, j + auto a = make_patterned_sparse_array(world, tr, 1.0, 3); + auto b = make_patterned_sparse_array(world, tr, 2.0, 4); + + TA::TSpArrayD c; + BOOST_REQUIRE_NO_THROW(c("b,i") = a("b,i,j") * b("b,i,j")); + // reference: TA::einsum on the same operands and target annotation + auto c_ref = TA::einsum(a("b,i,j"), b("b,i,j"), "b,i"); + BOOST_CHECK_SMALL(diff_norm_sp(c, c_ref, "b,i"), 1e-10); +} + BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_outer_product) { // the PNO-CC PPL building-block shape: ToT x ToT with an EMPTY right // outer-external set and an inner OUTER-product: