Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,18 @@ def guard(clusters):
# Separate out the indirect ConditionalDimensions, which only serve
# the purpose of protecting from OOB accesses
cds = [d for d in cds if not d.indirect]
modes = [cd.relation for cd in cds]
if modes.count('strict') > 1:
print(modes, cds, {m == 'strict' for m in modes})
raise CompilationError("Only one `strict` condition"
"can be used in an equation")
elif 'strict' in modes:
mode = 'strict'
else:
mode = sympy.And if sympy.And in modes else sympy.Or

# Chain together all `cds` conditions from all expressions in `c`
guards = {}
mode = sympy.Or
for cd in cds:
# `BOTTOM` parent implies a guard that lives outside of
# any iteration space, which corresponds to the placeholder None
Expand All @@ -279,7 +287,6 @@ def guard(clusters):

# Pull `cd` from any expr
condition = guards.setdefault(k, [])
mode = mode and cd.relation
for e in exprs:
try:
condition.append(e.conditionals[cd])
Expand All @@ -296,7 +303,10 @@ def guard(clusters):

# Combination `mode` is And by default.
# If all conditions are Or then Or combination `mode` is used.
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
if mode == 'strict':
guards = {d: v[0] for d, v in guards.items()}
else:
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}

# Construct a guarded Cluster
processed.append(c.rebuild(exprs=exprs, guards=guards))
Expand Down
32 changes: 26 additions & 6 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
)
from devito.symbolics import IntDiv, limits_mapper, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min
from devito.types import (
Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min, relational_shift
)

__all__ = [
'ClusterizedEq',
Expand Down Expand Up @@ -222,7 +224,7 @@ def __new__(cls, *args, **kwargs):
relations=ordering.relations, mode='partial')
ispace = IterationSpace(intervals, iterators)

# Construct the conditionals and replace the ConditionalDimensions in `expr`
# Construct the conditionals
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should place this whole block of code, which constructs/lowers the conditionals, into its own separate functions, and a docstring with some examples

conditionals = {}
for d in ordering:
if not d.is_Conditional:
Expand All @@ -234,13 +236,31 @@ def __new__(cls, *args, **kwargs):
if d._factor is not None:
cond = d.relation(cond, GuardFactor(d))
conditionals[d] = cond

# Merge conditionals when possible. E.g if we have an implicit_dim
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw this block imho deserves its own function

# and there is a dimension with the same parent, we ca merged
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dimension

"ca merged"

"their conditions"

you could also make the example a bit more practical

# its condition
for d in input_expr.implicit_dims:
if d not in conditionals:
continue
for cd in dict(conditionals):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list(...) is fine

if cd.parent == d.parent and cd != d:
cond = conditionals.pop(d)
if d.relation == 'strict':
conditionals[cd] = conditionals[d] = cond
else:
mode = cd.relation and d.relation
conditionals[cd] = mode(cond, conditionals[cd])
break

# Replace the ConditionalDimensions in `expr`
for d, cond in conditionals.items():
# Replace dimension with index
index = d.index
if d.condition is not None and d in expr.free_symbols:
index = index - relational_min(d.condition, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor)})

conditionals = frozendict(conditionals)
index = index - relational_min(cond, d.parent)
shift = relational_shift(cond, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})

# Lower all Differentiable operations into SymPy operations
rhs = diff2sympy(expr.rhs)
Expand Down
3 changes: 3 additions & 0 deletions devito/ir/support/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __lt__(self, other):
return True
elif q_positive(i):
return False

raise TypeError("Non-comparable index functions") from e

return False
Expand Down Expand Up @@ -164,6 +165,7 @@ def __gt__(self, other):
return True
elif q_negative(i):
return False

raise TypeError("Non-comparable index functions") from e

return False
Expand Down Expand Up @@ -203,6 +205,7 @@ def __le__(self, other):
return True
elif q_positive(i):
return False

raise TypeError("Non-comparable index functions") from e

# Note: unlike `__lt__`, if we end up here, then *it is* <=. For example,
Expand Down
9 changes: 5 additions & 4 deletions devito/passes/clusters/asynchrony.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict

from sympy import true
from sympy import Mod, true

from devito.ir import (
Backward, Forward, GuardBoundNext, PrefetchUpdate, Queue, ReleaseLock, SyncArray,
Expand Down Expand Up @@ -78,7 +78,8 @@ def callback(self, clusters, prefix):
d = self.key0(c0)
if d is not dim:
continue

if d in c0.guards and not c0.guards[d].has(Mod):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

searching for Mod is a bit meh, I'd rather add a special guard to ir/support/guards.py and look for that instead (there's quite a few already in there!)

continue
protected = self._schedule_waitlocks(c0, d, clusters, locks, syncs)
self._schedule_withlocks(c0, d, protected, locks, syncs)

Expand Down Expand Up @@ -193,7 +194,7 @@ def memcpy_prefetch(clusters, key0, sregistry):
if c.properties.is_prefetchable(d._defines):
_actions_from_update_memcpy(c, d, clusters, actions, sregistry)
elif d.is_Custom and is_integer(c.ispace[d].size):
_actions_from_init(c, d, actions)
_actions_from_init(c, d, clusters, actions)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover, I guess


# Attach the computed Actions
processed = []
Expand All @@ -214,7 +215,7 @@ def memcpy_prefetch(clusters, key0, sregistry):
return processed


def _actions_from_init(c, d, actions):
def _actions_from_init(c, d, clusters, actions):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover, I guess

e = c.exprs[0]
function = e.rhs.function
target = e.lhs.function
Expand Down
Loading
Loading