Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions src/cuda/tile/_passes/rewrite_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ def fuse_mul_addsub(op: RawBinaryArithmeticOperation, ctx: MatchContext):
raise NoMatch("not an add/sub binop")
if (mul_op := ctx.get_match(op.lhs, match_float_mul)) is not None:
acc = op.rhs
elif op.fn == "add" and (mul_op := ctx.get_match(op.rhs, match_float_mul)) is not None:
rhs_is_mul = False
elif (mul_op := ctx.get_match(op.rhs, match_float_mul)) is not None:
acc = op.lhs
rhs_is_mul = True
else:
raise NoMatch("no float mul operand")

Expand All @@ -112,14 +114,19 @@ def fuse_mul_addsub(op: RawBinaryArithmeticOperation, ctx: MatchContext):

# FIXME: fuse location
new_ops = []
mul_lhs = mul_op.lhs
if op.fn == "sub":
negated_acc = ctx.make_temp_var(op.loc)
ctx.set_type(negated_acc, ctx.typeof(acc))
new_ops.append(Unary(fn="neg", operand=acc, rounding_mode=None, flush_to_zero=False,
result_vars=(negated_acc,), loc=op.loc))
acc = negated_acc
neg_target = mul_op.lhs if rhs_is_mul else acc
negated = ctx.make_temp_var(op.loc)
ctx.set_type(negated, ctx.typeof(neg_target))
new_ops.append(Unary(fn="neg", operand=neg_target, rounding_mode=None, flush_to_zero=False,
result_vars=(negated,), loc=op.loc))
if rhs_is_mul:
mul_lhs = negated
else:
acc = negated

new_ops.append(FusedMulAddOperation(lhs=mul_op.lhs, rhs=mul_op.rhs, acc=acc,
new_ops.append(FusedMulAddOperation(lhs=mul_lhs, rhs=mul_op.rhs, acc=acc,
rounding_mode=rm, flush_to_zero=ftz,
result_vars=(op.result_var,), loc=op.loc))
ctx.add_rewrite((mul_op, op), new_ops)
Expand Down
12 changes: 12 additions & 0 deletions test/test_fma.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ def add_mul_kernel(x, y, z, output,
ct.store(output, index=(bidx, 0), tile=output_tile)


def sub_mul_kernel(x, y, z, output,
                   TILE: ct.Constant[int],
                   DIM: ct.Constant[int]):
    # Test kernel for the FMA-fusion rewrite pass: exercises the
    # "accumulator minus product" shape (c - a*b), i.e. the case where the
    # float-mul is the RHS operand of a subtraction — the newly supported
    # rhs_is_mul path in fuse_mul_addsub, which must negate the product
    # (fma(-a, b, c)) rather than the accumulator.
    # NOTE(review): sibling kernels in this file are decorated with
    # @ct.kernel (see mul_add_same_operand_kernel below); this one is not —
    # confirm whether the decorator was intentionally omitted.
    bidx = ct.bid(0)
    # Each block loads its (TILE, DIM) slice of the three inputs.
    tx = ct.load(x, index=(bidx, 0), shape=(TILE, DIM))
    ty = ct.load(y, index=(bidx, 0), shape=(TILE, DIM))
    tz = ct.load(z, index=(bidx, 0), shape=(TILE, DIM))
    output_tile = tz - tx * ty  # c - a*b
    ct.store(output, index=(bidx, 0), tile=output_tile)


@ct.kernel
def mul_add_same_operand_kernel(x, output,
TILE: ct.Constant[int],
Expand Down Expand Up @@ -86,6 +97,7 @@ def test_fma_skip_when_new_op_uses_deleted_var():
pytest.param(mul_add_kernel, lambda x, y, z: x * y + z),
pytest.param(mul_sub_kernel, lambda x, y, z: x * y - z),
pytest.param(add_mul_kernel, lambda x, y, z: z + x * y),
pytest.param(sub_mul_kernel, lambda x, y, z: z - x * y),
]
)
def test_fma(kernel, kernel_ref):
Expand Down