Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/TiledArray/expressions/binary_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,35 @@ class BinaryEngine : public ExprEngine<Derived> {
right_inner_permtype_ == PermutationType::general);
}

/// Collects the indices this subtree can supply to its consumer (Phase E
/// up-pass).
///
/// The result is the first-occurrence-ordered union of the two children's
/// available indices: every index of the left child, followed by those
/// indices of the right child that the left child does not carry. The outer
/// and inner lists are merged independently of each other.
///
/// Valid before init: it depends only on the children's available_indices()
/// (ultimately the leaf annotations), never on resolved (post-init) index
/// sets.
///
/// \return the merged outer and inner index lists as a BipartiteIndexList
BipartiteIndexList available_indices() const {
  // ordered union: all of `first`, then the entries of `second` that
  // `first` does not already contain
  auto ordered_union = [](auto const& first, auto const& second) {
    container::svector<std::string> merged(first.begin(), first.end());
    for (auto&& idx : second)
      if (!first.count(idx)) merged.push_back(idx);
    return merged;
  };
  // Bind by const reference: a child may return its index list either by
  // value (the temporary's lifetime is then extended to this scope) or by
  // const reference to a member (no copy is made at all).
  auto const& left_avail = left_.available_indices();
  auto const& right_avail = right_.available_indices();
  auto const out = ordered_union(outer(left_avail), outer(right_avail));
  auto const in = ordered_union(inner(left_avail), inner(right_avail));
  return BipartiteIndexList(IndexList(out.begin(), out.end()),
                            IndexList(in.begin(), in.end()));
}

/// Reports the layout in which this subtree prefers to produce the index set
/// named by \p demand.
///
/// Element-wise binary ops (Add/Subt) impose no layout of their own: both
/// children must produce the same index set and are aligned by permutation.
/// The demand is therefore handed back untouched. MultEngine provides its own
/// preferred_layout() that returns its canonical product layout instead.
///
/// \param demand the index set (in the consumer's order) requested of this
/// subtree
/// \return a reference to \p demand itself; the caller must keep the argument
/// alive for as long as the result is used
const BipartiteIndexList& preferred_layout(
    const BipartiteIndexList& demand) const {
  // no reordering is performed at this level
  return demand;
}

/// Initialize result tensor structure

/// This function will initialize the permutation, tiled range, and shape
Expand Down
100 changes: 86 additions & 14 deletions src/TiledArray/expressions/cont_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#define TILEDARRAY_EXPRESSIONS_CONT_ENGINE_H__INCLUDED

#include <TiledArray/dist_eval/contraction_eval.h>
#include <TiledArray/dist_eval/unary_eval.h>
#include <TiledArray/expressions/binary_engine.h>
#include <TiledArray/expressions/permopt.h>
#include <TiledArray/pmap/slabbed_pmap.h>
Expand All @@ -36,6 +37,7 @@
#include <TiledArray/tile_op/batched_contract_reduce.h>
#include <TiledArray/tile_op/contract_reduce.h>
#include <TiledArray/tile_op/mult.h>
#include <TiledArray/tile_op/noop.h>

namespace TiledArray {
namespace expressions {
Expand Down Expand Up @@ -179,6 +181,10 @@ class ContEngine : public BinaryEngine<Derived> {
// General (fused + contracted + free indices) products only:
unsigned int n_fused_modes_ = 0; ///< # of leading fused (outer) modes
size_type n_slabs_ = 1; ///< # of fused-index tile slabs
bool general_repermute_ = false; ///< whether the target layout differs
///< from the canonical result layout, so
///< the evaluated result is re-permuted
///< by a streaming unary eval

static unsigned int find(const BipartiteIndexList& indices,
const std::string& index_label, unsigned int i,
Expand Down Expand Up @@ -660,16 +666,13 @@ class ContEngine : public BinaryEngine<Derived> {
TA_ASSERT(nh > 0u); // else this is a pure contraction
n_fused_modes_ = nh;

// initialize perm_; an interleaved target (a result permutation that
// mixes fused and free modes) is not yet supported -- the canonical
// result layout must equal the target
// initialize perm_; a target that differs from the canonical (fused...,
// left-free..., right-free...) result layout cannot be folded into the
// batched tile op (BatchedContractReduce must be perm-free), so the
// product is evaluated in its canonical layout and re-permuted to the
// target by a streaming unary eval (see make_dist_eval_general)
this->init_perm(target_indices);
if (outer(target_indices) != outer(indices_))
TA_EXCEPTION(
"general products (fused + contracted + free indices): targets "
"that interleave fused and free indices are not yet supported; "
"reorder the result annotation to (fused..., left-free..., "
"right-free...)");
general_repermute_ = (outer(target_indices) != outer(indices_));

// the tile op operates on the folded (fused-mode-free) shapes
const auto left_op = to_cblas_op(left_outer_permtype_);
Expand Down Expand Up @@ -712,6 +715,12 @@ class ContEngine : public BinaryEngine<Derived> {

trange_ = make_trange_general();
shape_ = make_shape_general();
if (general_repermute_) {
// consumers see the target layout; the canonical structures are
// recomputed in make_dist_eval_general for the inner Summa
trange_ = outer(perm_) * trange_;
shape_ = shape_.perm(outer(perm_));
}

if (ExprEngine_::override_ptr_ && ExprEngine_::override_ptr_->shape) {
shape_ = shape_.mask(*ExprEngine_::override_ptr_->shape);
Expand Down Expand Up @@ -834,9 +843,37 @@ class ContEngine : public BinaryEngine<Derived> {
}
}

/// Per-tile re-permutation op used for general products whose target layout
/// is not the canonical (fused..., free...) result layout. The batched tile
/// op has to stay perm-free, so the result permutation is applied tile by
/// tile by the consumer-side unary eval through this op.
struct GeneralRepermuteOp {
  using result_type = value_type;
  using argument_type = value_type;
  static constexpr bool is_consumable = false;

  /// The *outer* (result-layout) permutation only. Inner (within-cell)
  /// permutation of tensor-of-tensor results is dealt with elsewhere (see
  /// init_struct_general / implicit_permute_inner_); storing a plain outer
  /// Permutation here guards against accidentally permuting inner contents.
  Permutation perm;

  /// Set to false when the consumer fuses the permutation into its own
  /// operation (implicit permute, e.g. a transposed GEMM). In that case only
  /// the tile ordinals/trange are remapped (by the host UnaryEvalImpl) and
  /// the tile contents are passed through in the canonical layout.
  bool permute_contents = true;

  /// \param tile a result tile in the canonical layout
  /// \return \p tile permuted by \c perm when \c permute_contents is set,
  /// otherwise \p tile as given
  result_type operator()(const argument_type& tile) const {
    if (permute_contents) {
      TiledArray::detail::Noop<value_type, value_type, false> permuting_noop;
      return permuting_noop(tile, perm);
    }
    return tile;
  }
};

/// Construct the distributed evaluator of a general product

/// \return The batched-Summa distributed evaluator for this expression
/// \return The batched-Summa distributed evaluator for this expression,
/// wrapped in a streaming re-permute when the target layout differs from
/// the canonical result layout
dist_eval_type make_dist_eval_general() const {
typedef TiledArray::detail::BatchedContractReduce<op_type> batched_op_type;
typedef TiledArray::detail::Summa<typename left_type::dist_eval_type,
Expand All @@ -847,11 +884,46 @@ class ContEngine : public BinaryEngine<Derived> {
typename left_type::dist_eval_type left = left_.make_dist_eval();
typename right_type::dist_eval_type right = right_.make_dist_eval();

std::shared_ptr<impl_type> pimpl = std::make_shared<impl_type>(
left, right, *world_, trange_, shape_, pmap_, perm_,
batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_);
if (!general_repermute_) {
std::shared_ptr<impl_type> pimpl = std::make_shared<impl_type>(
left, right, *world_, trange_, shape_, pmap_, perm_,
batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_);
return dist_eval_type(pimpl);
}

return dist_eval_type(pimpl);
// evaluate in the canonical layout (Summa with perm-free op), then
// re-permute tiles to the target layout with a streaming unary eval;
// trange_/shape_ hold the target-layout structures (see
// init_struct_general), the canonical ones are recomputed here
auto const canonical_trange = make_trange_general();
auto const canonical_shape = [this]() {
auto s = make_shape_general();
if (ExprEngine_::override_ptr_ && ExprEngine_::override_ptr_->shape) {
// the consumer-supplied mask is expressed in the target layout
auto const inv_perm = outer(perm_).inv();
s = s.mask(ExprEngine_::override_ptr_->shape->perm(inv_perm));
}
return s;
}();
// the inner Summa's result placement must be slab-replicated (the owner
// of a tile independent of its slab index), regardless of the
// (target-layout) pmap the consumer supplied for this node
auto canonical_pmap = std::make_shared<TiledArray::detail::SlabbedPmap>(
*world_, proc_grid_.make_pmap(), n_slabs_);
std::shared_ptr<impl_type> pimpl = std::make_shared<impl_type>(
left, right, *world_, canonical_trange, canonical_shape, canonical_pmap,
BipartitePermutation{}, batched_op_type(op_, n_fused_modes_), K_,
proc_grid_, n_slabs_);
dist_eval_type canonical(pimpl);

typedef TiledArray::detail::UnaryEvalImpl<
dist_eval_type, GeneralRepermuteOp, typename Derived::policy>
repermute_impl_type;
std::shared_ptr<repermute_impl_type> wrapper =
std::make_shared<repermute_impl_type>(
canonical, *world_, trange_, shape_, pmap_, perm_,
GeneralRepermuteOp{outer(perm_), !this->implicit_permute_outer_});
return dist_eval_type(wrapper);
}

/// Expression identification tag
Expand Down
12 changes: 12 additions & 0 deletions src/TiledArray/expressions/leaf_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ class LeafEngine : public ExprEngine<Derived> {
/// This function is a noop since the index list is fixed.
void init_indices() {}

/// Reports the indices this subtree can supply (Phase E up-pass). A leaf
/// supplies exactly its own annotation, which is set at construction and is
/// therefore usable before init.
///
/// \return a reference to this leaf's index list
const BipartiteIndexList& available_indices() const {
  return indices_;
}

/// Reports the layout in which this subtree prefers to produce the index set
/// named by \p demand. A leaf accepts any ordering, since the consumer aligns
/// against the fixed annotation by permutation, so the demand is handed back
/// untouched.
///
/// \param demand the index set (in the consumer's order) requested of this
/// leaf
/// \return a reference to \p demand itself; the caller must keep the argument
/// alive for as long as the result is used
const BipartiteIndexList& preferred_layout(
    const BipartiteIndexList& demand) const {
  // a leaf has no layout preference of its own
  return demand;
}

void init_distribution(World* world,
const std::shared_ptr<const pmap_interface>& pmap) {
ExprEngine_::init_distribution(world, (pmap ? pmap : array_.pmap()));
Expand Down
127 changes: 91 additions & 36 deletions src/TiledArray/expressions/mult_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,44 +281,65 @@ class MultEngine : public ContEngine<MultEngine<Left, Right, Result>> {

/// \param target_indices The target index list for this expression
void init_indices(const BipartiteIndexList& target_indices) {
// to decide what type of product this is must initialize indices down
// the tree.
// N.B. since this may be a contraction we do not know the target indices
// for the left and right, hence do target-neutral initialization
BinaryEngine_::left_.init_indices();
BinaryEngine_::right_.init_indices();

// Validate that the (bottom-up resolved) child indices are consistent
// with the target: every outer index of each child must appear in the
// other child or in the target (as a free, fused, or contracted index).
// A violation usually means a *general* product (fused + contracted +
// free indices) appears at an INNER node of the expression tree, where
// the role of a shared index cannot be deduced bottom-up; e.g. in the
// THC-like g("p,q,r,s") = X("p,r1") * X("q,r1") * Z("r1,r2") * ... the
// index r1 is fused in X*X but contracted downstream, while the
// bottom-up convention contracts it in X*X, orphaning the r1 of Z.
// Resolving this requires pushing the needed-index sets down the
// expression tree; until then materialize such inner products into
// explicit intermediates, so that every general product appears as the
// root of its own assignment (where the target determines the index
// roles).
// Deduce each child's index SET top-down (Phase E). The index set an
// inner node must produce depends on what its ancestors need: a shared
// index of the two children is fused iff this node's target carries it,
// contracted here otherwise; an index of one child that neither the
// sibling nor the target carries is consumed entirely within that
// child's subtree and is not demanded of it.
// Each child's demand is ordered target-kept-first (the indices this
// node's result keeps, in target order) followed by the
// contracted-here indices in the child's leaf-availability order --
// fused indices thus lead, matching the canonical layouts of
// GeneralPermutationOptimizer (h leading) so the nbatch fold stays
// zero-copy. E.g. in the THC-like g("p,q,r,s") = X("p,r1") * X("q,r1")
// * Z("r1,r2") * ... the index r1 is demanded of X*X by its consumer
// (fused there), and contracted where Z meets the X*X subtree.
// N.B. expressions consumed WITHOUT a target (reductions, e.g. dot)
// take the no-target init_indices() overload below, which retains the
// bottom-up contraction convention -- general products under
// reductions remain unsupported.
{
auto const& left_outer = outer(BinaryEngine_::left_.indices());
auto const& right_outer = outer(BinaryEngine_::right_.indices());
auto const& target_outer = outer(target_indices);
auto validate = [&](const IndexList& a, const IndexList& b) {
for (auto&& idx : a)
if (!b.count(idx) && !target_outer.count(idx))
TA_EXCEPTION(
"MultEngine: an argument index appears in neither the other "
"argument nor the target. If a general product (fused + "
"contracted + free indices) appears at an inner node of the "
"expression tree, its index roles cannot be deduced "
"bottom-up; materialize it into an explicit intermediate so "
"that it appears as the root of its own assignment");
auto const avail_l = BinaryEngine_::left_.available_indices();
auto const avail_r = BinaryEngine_::right_.available_indices();
auto demand = [](auto const& avail, auto const& sibling,
auto const& tgt) {
container::svector<std::string> r;
for (auto&& idx : tgt)
if (avail.count(idx)) r.push_back(idx);
for (auto&& idx : avail) {
if (tgt.count(idx)) continue;
if (sibling.count(idx)) r.push_back(idx);
// else: the index is consumed entirely within the child's own
// subtree (contracted deeper down) -- not demanded here. A
// genuinely orphaned index (single occurrence, demanded nowhere =
// implicit trace, unsupported) surfaces as the leaf-level
// target-is-not-a-permutation error.
}
return r;
};
validate(left_outer, right_outer);
validate(right_outer, left_outer);
auto bipartite_demand = [&demand](auto const& avail, auto const& sib,
auto const& tgt) {
auto const out = demand(outer(avail), outer(sib), outer(tgt));
auto const in = demand(inner(avail), inner(sib), inner(tgt));
return BipartiteIndexList(IndexList(out.begin(), out.end()),
IndexList(in.begin(), in.end()));
};
// each child dictates the ORDER of its demand (preferred_layout): a
// product child reorders to its canonical (h, eA, eB) layout -- a
// general product cannot host a result permutation, and contraction
// consumers absorb any child layout via the GEMM transpose forms.
// Materialize the demands as named lvalues: some preferred_layout()
// overloads (leaf/binary/unary) return a reference to their argument, so
// binding that to a temporary demand would be needlessly fragile.
const BipartiteIndexList left_demand =
bipartite_demand(avail_l, avail_r, target_indices);
const BipartiteIndexList right_demand =
bipartite_demand(avail_r, avail_l, target_indices);
BinaryEngine_::left_.init_indices(
BinaryEngine_::left_.preferred_layout(left_demand));
BinaryEngine_::right_.init_indices(
BinaryEngine_::right_.preferred_layout(right_demand));
}

this->product_type_ = compute_product_type(
Expand Down Expand Up @@ -377,6 +398,40 @@ class MultEngine : public ContEngine<MultEngine<Left, Right, Result>> {
}
}

/// Reports the layout in which this product prefers to produce the index set
/// named by \p demand: the canonical general-product result layout
/// (h, eA, eB).
///
/// Indices supplied by both children (fused here) lead, followed by indices
/// supplied by the left child only, followed by all remaining indices (those
/// supplied by the right child only); each group keeps the order the indices
/// have in \p demand. The outer and inner lists are reordered independently.
///
/// h-leading is required when this node resolves to a general product
/// (general results cannot host a result permutation and the nbatch fold
/// needs the fused modes leading). For a pure contraction (empty h) the
/// (eA, eB) order matches the GEMM result, and for a pure Hadamard the demand
/// order is preserved.
///
/// \param demand the index set (in the consumer's order) requested of this
/// product
/// \return the reordered index set, by value
BipartiteIndexList preferred_layout(const BipartiteIndexList& demand) const {
  // Bind by const reference: a child may return its index list either by
  // value (the temporary's lifetime is then extended to this scope) or by
  // const reference to a member (no copy is made at all).
  auto const& avail_l = BinaryEngine_::left_.available_indices();
  auto const& avail_r = BinaryEngine_::right_.available_indices();
  // stable three-way partition of `d` into (both, left-only, the rest)
  auto canonical = [](auto const& d, auto const& al, auto const& ar) {
    container::svector<std::string> h, ea, eb;
    for (auto&& idx : d) {
      bool const in_left = al.count(idx) != 0;
      bool const in_right = ar.count(idx) != 0;
      if (in_left && in_right)
        h.push_back(idx);
      else if (in_left)
        ea.push_back(idx);
      else
        // right-only; an index found in neither child also lands here
        eb.push_back(idx);
    }
    h.insert(h.end(), ea.begin(), ea.end());
    h.insert(h.end(), eb.begin(), eb.end());
    return h;
  };
  auto const out = canonical(outer(demand), outer(avail_l), outer(avail_r));
  auto const in = canonical(inner(demand), inner(avail_l), inner(avail_r));
  return BipartiteIndexList(IndexList(out.begin(), out.end()),
                            IndexList(in.begin(), in.end()));
}

/// Initialize result tensor structure

/// This function will initialize the permutation, tiled range, and shape
Expand Down
10 changes: 10 additions & 0 deletions src/TiledArray/expressions/unary_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ class UnaryEngine : ExprEngine<Derived> {
indices_ = arg_.indices();
}

/// Reports the indices this subtree can supply (Phase E up-pass). A unary op
/// supplies exactly what its argument supplies, so the query is forwarded to
/// the argument unchanged (usable before init).
///
/// \return whatever the argument's available_indices() returns, with its
/// value category preserved (hence decltype(auto))
decltype(auto) available_indices() const {
  return arg_.available_indices();
}

/// Reports the layout in which this subtree prefers to produce the index set
/// named by \p demand. A unary op has no preference of its own and defers to
/// its argument.
///
/// \param demand the index set (in the consumer's order) requested of this
/// subtree
/// \return whatever the argument's preferred_layout() returns, with its value
/// category preserved (hence decltype(auto)); this may be a reference to
/// \p demand itself
decltype(auto) preferred_layout(const BipartiteIndexList& demand) const {
  // forwarded verbatim; the argument decides the ordering
  return arg_.preferred_layout(demand);
}

/// Initialize result tensor structure

/// This function will initialize the permutation, tiled range, and shape
Expand Down
Loading
Loading