Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/TiledArray/expressions/binary_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,52 @@ class BinaryEngine : public ExprEngine<Derived> {
return demand;
}

/// Deduce each child's index SET top-down from this node's target and
/// initialize the children with the deduced demands (the product-engine
/// down-pass). The index set an inner node must produce depends on what
/// its ancestors need: a shared index of the two children is fused iff
/// this node's target carries it, contracted here otherwise; an index of
/// one child that neither the sibling nor the target carries is consumed
/// entirely within that child's subtree and is not demanded of it (a
/// genuinely orphaned index -- an implicit trace, unsupported -- surfaces
/// as the leaf-level target-is-not-a-permutation error). Each child's
/// demand is ordered target-kept-first followed by the contracted-here
/// indices in leaf-availability order, then reordered by the child's own
/// preferred_layout() (a product child reorders to its canonical
/// (fused, left-free, right-free) layout -- a general product cannot host
/// a result permutation, and contraction consumers absorb any child
/// layout via the GEMM transpose forms).
void init_children_indices(const BipartiteIndexList& target_indices) {
auto const avail_l = left_.available_indices();
auto const avail_r = right_.available_indices();
auto demand = [](auto const& avail, auto const& sibling, auto const& tgt) {
container::svector<std::string> r;
for (auto&& idx : tgt)
if (avail.count(idx)) r.push_back(idx);
for (auto&& idx : avail) {
if (tgt.count(idx)) continue;
if (sibling.count(idx)) r.push_back(idx);
}
return r;
};
auto bipartite_demand = [&demand](auto const& avail, auto const& sib,
auto const& tgt) {
auto const out = demand(outer(avail), outer(sib), outer(tgt));
auto const in = demand(inner(avail), inner(sib), inner(tgt));
return BipartiteIndexList(IndexList(out.begin(), out.end()),
IndexList(in.begin(), in.end()));
};
// Materialize the demands as named lvalues: some preferred_layout()
// overloads (leaf/binary/unary) return a reference to their argument, so
// binding that to a temporary demand would be needlessly fragile.
const BipartiteIndexList left_demand =
bipartite_demand(avail_l, avail_r, target_indices);
const BipartiteIndexList right_demand =
bipartite_demand(avail_r, avail_l, target_indices);
left_.init_indices(left_.preferred_layout(left_demand));
right_.init_indices(right_.preferred_layout(right_demand));
}

/// Initialize result tensor structure

/// This function will initialize the permutation, tiled range, and shape
Expand Down
136 changes: 85 additions & 51 deletions src/TiledArray/expressions/cont_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -674,27 +674,46 @@ class ContEngine : public BinaryEngine<Derived> {
this->init_perm(target_indices);
general_repermute_ = (outer(target_indices) != outer(indices_));

// the tile op operates on the folded (fused-mode-free) shapes
const auto left_op = to_cblas_op(left_outer_permtype_);
// A product with NO external (free) outer indices (every outer index
// fused or contracted, e.g. C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b"))
// folds to a GEMM with no free modes, i.e. rank-0 tensors, which the
// tile kernels do not support. Evaluate it with a SYNTHETIC UNIT
// left-external mode instead: the folded product becomes
// (1,K) x (K) -> (1), the exact shape of the (supported) one-sided
// neB == 0 case. The unit mode lives only in the tile op's GemmHelper;
// tranges, shapes and tiles carry the true (external-free) ranks, and
// BatchedContractReduce / SparseShape::gemm_batched detect the
// synthetic mode from the one-rank mismatch and pad their folded views
// with a unit extent.
const unsigned int u = (outer_size(indices_) == nh) ? 1u : 0u;

// the tile op operates on the folded (fused-mode-free) shapes; the
// synthetic unit mode leads the folded left operand, so it is NoTrans
const auto left_op =
u ? math::blas::NoTranspose : to_cblas_op(left_outer_permtype_);
const auto right_op = to_cblas_op(right_outer_permtype_);
if constexpr (!TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
op_ = op_type(left_op, right_op, factor_, outer_size(indices_) - nh,
outer_size(left_indices_) - nh,
op_ = op_type(left_op, right_op, factor_, outer_size(indices_) - nh + u,
outer_size(left_indices_) - nh + u,
outer_size(right_indices_) - nh);
} else {
// the batched tile op must be perm-free (BatchedContractReduce cannot
// host the folded-rank result permutation); the outer perm is empty by
// the interleaved-target gate above, so only an explicit inner result
// permutation can require one
if (!implicit_permute_inner_ && bool(inner(perm_)))
// host the folded-rank result permutation); the outer perm is handled
// by the streaming re-permute (general_repermute_), so only a genuine
// (non-identity) explicit inner result permutation requires one. N.B.
// perm_ may carry a non-null identity inner component when only the
// outer modes are permuted (the bipartite perm is constructed whole).
if (!implicit_permute_inner_ && bool(inner(perm_)) &&
!inner(perm_).is_identity())
TA_EXCEPTION(
"general products of tensors-of-tensors: a non-identity inner "
"result permutation is not yet supported; reorder the inner "
"annotation of the result");

// factor_ is absorbed into element_nonreturn_op_
op_ = op_type(left_op, right_op, scalar_type(1),
outer_size(indices_) - nh, outer_size(left_indices_) - nh,
outer_size(indices_) - nh + u,
outer_size(left_indices_) - nh + u,
outer_size(right_indices_) - nh, BipartitePermutation{},
this->element_nonreturn_op_, std::move(this->arena_plan_));
// ce+e, ce+ce_right and ce+ce_left are mutually exclusive; at most one
Expand Down Expand Up @@ -734,7 +753,11 @@ class ContEngine : public BinaryEngine<Derived> {
trange_type make_trange_general() const {
const unsigned int nh = n_fused_modes_;
const unsigned int nc = op_.gemm_helper().num_contract_ranks();
const unsigned int neA = op_.gemm_helper().left_rank() - nc;
// the no-external case carries a synthetic unit left-external mode in
// the GemmHelper only (see init_struct_general); the actual tranges do
// not have it
const unsigned int u = (outer_size(indices_) == n_fused_modes_) ? 1u : 0u;
const unsigned int neA = op_.gemm_helper().left_rank() - nc - u;
const unsigned int neB = op_.gemm_helper().right_rank() - nc;

typename trange_type::Ranges ranges(nh + neA + neB);
Expand Down Expand Up @@ -786,7 +809,11 @@ class ContEngine : public BinaryEngine<Derived> {
std::shared_ptr<const pmap_interface> pmap) {
const unsigned int nh = n_fused_modes_;
const unsigned int nc = op_.gemm_helper().num_contract_ranks();
const unsigned int neA = op_.gemm_helper().left_rank() - nc;
// the no-external case carries a synthetic unit left-external mode in
// the GemmHelper only (see init_struct_general); the actual tranges do
// not have it
const unsigned int u = (outer_size(indices_) == nh) ? 1u : 0u;
const unsigned int neA = op_.gemm_helper().left_rank() - nc - u;
const unsigned int neB = op_.gemm_helper().right_rank() - nc;

// Get pointers to the argument sizes
Expand Down Expand Up @@ -1707,7 +1734,11 @@ class ContEngine : public BinaryEngine<Derived> {
TiledArray::detail::is_contraction_arena_tot_v<
result_tile_type, left_tile_type, right_tile_type>;
if constexpr (arena_eligible_scale) {
if (this->outer_product_uses_summa()) {
// the fused arena scale ops are factor-free; a non-unit
// expression-level prefactor (ScalMult) takes the fallback op,
// which absorbs it
if (this->outer_product_uses_summa() &&
this->factor_ == scalar_type(1)) {
// The inner perm handed to the plan must match how the inner
// *result* permutation is applied for this result cell type --
// and the two cell types apply it in different places:
Expand Down Expand Up @@ -1741,45 +1772,48 @@ class ContEngine : public BinaryEngine<Derived> {
// cells. The Hadamard outer product is an assignment
// `result = (perm ^ tot) * scalar`, which needs value-returning
// `scale`; only owning inner cells support it.
auto fallback_op = [perm = !this->implicit_permute_inner_
? inner(this->perm_)
: Permutation{},
outer_uses_summa =
this->outer_product_uses_summa()](
result_tile_element_type& result,
const left_tile_element_type& left,
const right_tile_element_type& right) {
if (outer_uses_summa) {
using TiledArray::axpy_to;
if constexpr (tot_x_t) {
if (left.empty()) return; // absent cell: no contribution
if (perm)
axpy_to(result, left, right, perm);
else
axpy_to(result, left, right);
} else {
if (right.empty()) return; // absent cell: no contribution
if (perm)
axpy_to(result, right, left, perm);
else
axpy_to(result, right, left);
}
} else {
if constexpr (!TiledArray::is_tensor_view_v<
result_tile_element_type>) {
using TiledArray::scale;
if constexpr (tot_x_t)
result = perm ? scale(left, right, perm) : scale(left, right);
else
result = perm ? scale(right, left, perm) : scale(right, left);
} else {
TA_EXCEPTION(
"Tensor<View> scale-inner Hadamard-outer product: a "
"view result cell cannot be value-assigned a fresh "
"scaled tensor");
}
}
};
// N.B. the expression-level scalar prefactor (factor_, != 1 for
// ScalMult expressions) multiplies the plain operand's element
auto fallback_op =
[perm = !this->implicit_permute_inner_ ? inner(this->perm_)
: Permutation{},
outer_uses_summa = this->outer_product_uses_summa(),
factor = this->factor_](result_tile_element_type& result,
const left_tile_element_type& left,
const right_tile_element_type& right) {
if (outer_uses_summa) {
using TiledArray::axpy_to;
if constexpr (tot_x_t) {
if (left.empty()) return; // absent cell: no contribution
if (perm)
axpy_to(result, left, right * factor, perm);
else
axpy_to(result, left, right * factor);
} else {
if (right.empty()) return; // absent cell: no contribution
if (perm)
axpy_to(result, right, left * factor, perm);
else
axpy_to(result, right, left * factor);
}
} else {
if constexpr (!TiledArray::is_tensor_view_v<
result_tile_element_type>) {
using TiledArray::scale;
if constexpr (tot_x_t)
result = perm ? scale(left, right * factor, perm)
: scale(left, right * factor);
else
result = perm ? scale(right, left * factor, perm)
: scale(right, left * factor);
} else {
TA_EXCEPTION(
"Tensor<View> scale-inner Hadamard-outer product: a "
"view result cell cannot be value-assigned a fresh "
"scaled tensor");
}
}
};
if constexpr (arena_eligible_scale) {
if (this->arena_plan_) {
if constexpr (tot_x_t)
Expand Down
94 changes: 33 additions & 61 deletions src/TiledArray/expressions/mult_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,48 +299,7 @@ class MultEngine : public ContEngine<MultEngine<Left, Right, Result>> {
// take the no-target init_indices() overload below, which retains the
// bottom-up contraction convention -- general products under
// reductions remain unsupported.
{
auto const avail_l = BinaryEngine_::left_.available_indices();
auto const avail_r = BinaryEngine_::right_.available_indices();
auto demand = [](auto const& avail, auto const& sibling,
auto const& tgt) {
container::svector<std::string> r;
for (auto&& idx : tgt)
if (avail.count(idx)) r.push_back(idx);
for (auto&& idx : avail) {
if (tgt.count(idx)) continue;
if (sibling.count(idx)) r.push_back(idx);
// else: the index is consumed entirely within the child's own
// subtree (contracted deeper down) -- not demanded here. A
// genuinely orphaned index (single occurrence, demanded nowhere =
// implicit trace, unsupported) surfaces as the leaf-level
// target-is-not-a-permutation error.
}
return r;
};
auto bipartite_demand = [&demand](auto const& avail, auto const& sib,
auto const& tgt) {
auto const out = demand(outer(avail), outer(sib), outer(tgt));
auto const in = demand(inner(avail), inner(sib), inner(tgt));
return BipartiteIndexList(IndexList(out.begin(), out.end()),
IndexList(in.begin(), in.end()));
};
// each child dictates the ORDER of its demand (preferred_layout): a
// product child reorders to its canonical (h, eA, eB) layout -- a
// general product cannot host a result permutation, and contraction
// consumers absorb any child layout via the GEMM transpose forms.
// Materialize the demands as named lvalues: some preferred_layout()
// overloads (leaf/binary/unary) return a reference to their argument, so
// binding that to a temporary demand would be needlessly fragile.
const BipartiteIndexList left_demand =
bipartite_demand(avail_l, avail_r, target_indices);
const BipartiteIndexList right_demand =
bipartite_demand(avail_r, avail_l, target_indices);
BinaryEngine_::left_.init_indices(
BinaryEngine_::left_.preferred_layout(left_demand));
BinaryEngine_::right_.init_indices(
BinaryEngine_::right_.preferred_layout(right_demand));
}
BinaryEngine_::init_children_indices(target_indices);

this->product_type_ = compute_product_type(
outer(BinaryEngine_::left_.indices()),
Expand Down Expand Up @@ -713,29 +672,34 @@ class ScalMultEngine

/// \param target_indices The target index list for this expression
void init_indices(const BipartiteIndexList& target_indices) {
BinaryEngine_::left_.init_indices();
BinaryEngine_::right_.init_indices();
// deduce the children's index sets top-down (see
// BinaryEngine::init_children_indices), then classify and route exactly
// as MultEngine does (the scalar factor does not affect index roles)
BinaryEngine_::init_children_indices(target_indices);

this->product_type_ = compute_product_type(
outer(BinaryEngine_::left_.indices()),
outer(BinaryEngine_::right_.indices()), outer(target_indices));
this->inner_product_type_ = compute_product_type(
inner(BinaryEngine_::left_.indices()),
inner(BinaryEngine_::right_.indices()), inner(target_indices));

if (this->inner_product_type_ == TensorProduct::General)
TA_EXCEPTION(
"ScalMultEngine: general products (fused + contracted + free "
"indices) between the inner (nested) indices of tensors-of-tensors "
"are not supported");

if (this->product_type() == TensorProduct::Hadamard) {
// since already initialized left and right arg indices assign the target
// indices
BinaryEngine_::perm_indices(target_indices);
} else if (this->product_type() == TensorProduct::General) {
// layout via GeneralPermutationOptimizer (the target determines which
// shared indices are fused vs contracted), then propagate to children
if (!this->implicit_permute()) {
BinaryEngine_::template init_indices_<TensorProduct::General>(
target_indices);
if (BinaryEngine_::left_indices_ != BinaryEngine_::left_.indices())
BinaryEngine_::left_.perm_indices(BinaryEngine_::left_indices_);
if (BinaryEngine_::right_indices_ != BinaryEngine_::right_.indices())
BinaryEngine_::right_.perm_indices(BinaryEngine_::right_indices_);
}
this->perm_indices(target_indices);
} else {
ContEngine_::init_indices(target_indices);
auto children_initialized = true;
ContEngine_::init_indices(children_initialized);
ContEngine_::perm_indices(target_indices);
}
}

Expand Down Expand Up @@ -767,12 +731,14 @@ class ScalMultEngine
/// for the result tensor.
/// \param target_indices The target index list for the result tensor
void init_struct(const BipartiteIndexList& target_indices) {
// TODO Phase B (batched Summa): evaluate general products natively
if (this->product_type() == TensorProduct::General)
TA_EXCEPTION(
"ScalMultEngine: evaluation of general products (fused + contracted "
"+ free indices) via the expression layer is not yet implemented; "
"use TiledArray::einsum() instead");
if (this->product_type() == TensorProduct::General) {
// the inner tile op (for tensors-of-tensors) must be initialized
// first; init_struct_general consumes element_nonreturn_op_ and the
// arena plan it builds
this->init_inner_tile_op(inner(target_indices));
ContEngine_::init_struct_general(target_indices);
return;
}

this->init_perm(target_indices);

Expand All @@ -794,6 +760,8 @@ class ScalMultEngine
std::shared_ptr<const pmap_interface> pmap) {
if (this->product_type() == TensorProduct::Contraction)
ContEngine_::init_distribution(world, pmap);
else if (this->product_type() == TensorProduct::General)
ContEngine_::init_distribution_general(world, pmap);
else
BinaryEngine_::init_distribution(world, pmap);
}
Expand All @@ -804,6 +772,8 @@ class ScalMultEngine
dist_eval_type make_dist_eval() const {
if (this->product_type() == TensorProduct::Contraction)
return ContEngine_::make_dist_eval();
else if (this->product_type() == TensorProduct::General)
return ContEngine_::make_dist_eval_general();
else
return BinaryEngine_::make_dist_eval();
}
Expand All @@ -814,6 +784,8 @@ class ScalMultEngine
trange_type make_trange() const {
if (this->product_type() == TensorProduct::Contraction)
return ContEngine_::make_trange();
else if (this->product_type() == TensorProduct::General)
return ContEngine_::make_trange_general();
else
return BinaryEngine_::make_trange();
}
Expand Down
Loading
Loading