diff --git a/src/cuda/tile/_passes/rewrite_patterns.py b/src/cuda/tile/_passes/rewrite_patterns.py index 339ac14..a5ea360 100644 --- a/src/cuda/tile/_passes/rewrite_patterns.py +++ b/src/cuda/tile/_passes/rewrite_patterns.py @@ -95,8 +95,10 @@ def fuse_mul_addsub(op: RawBinaryArithmeticOperation, ctx: MatchContext): raise NoMatch("not an add/sub binop") if (mul_op := ctx.get_match(op.lhs, match_float_mul)) is not None: acc = op.rhs - elif op.fn == "add" and (mul_op := ctx.get_match(op.rhs, match_float_mul)) is not None: + rhs_is_mul = False + elif (mul_op := ctx.get_match(op.rhs, match_float_mul)) is not None: acc = op.lhs + rhs_is_mul = True else: raise NoMatch("no float mul operand") @@ -112,14 +114,19 @@ def fuse_mul_addsub(op: RawBinaryArithmeticOperation, ctx: MatchContext): # FIXME: fuse location new_ops = [] + mul_lhs = mul_op.lhs if op.fn == "sub": - negated_acc = ctx.make_temp_var(op.loc) - ctx.set_type(negated_acc, ctx.typeof(acc)) - new_ops.append(Unary(fn="neg", operand=acc, rounding_mode=None, flush_to_zero=False, - result_vars=(negated_acc,), loc=op.loc)) - acc = negated_acc + neg_target = mul_op.lhs if rhs_is_mul else acc + negated = ctx.make_temp_var(op.loc) + ctx.set_type(negated, ctx.typeof(neg_target)) + new_ops.append(Unary(fn="neg", operand=neg_target, rounding_mode=None, flush_to_zero=False, + result_vars=(negated,), loc=op.loc)) + if rhs_is_mul: + mul_lhs = negated + else: + acc = negated - new_ops.append(FusedMulAddOperation(lhs=mul_op.lhs, rhs=mul_op.rhs, acc=acc, + new_ops.append(FusedMulAddOperation(lhs=mul_lhs, rhs=mul_op.rhs, acc=acc, rounding_mode=rm, flush_to_zero=ftz, result_vars=(op.result_var,), loc=op.loc)) ctx.add_rewrite((mul_op, op), new_ops) diff --git a/test/test_fma.py b/test/test_fma.py index b5e5169..252fba0 100644 --- a/test/test_fma.py +++ b/test/test_fma.py @@ -56,6 +56,17 @@ def add_mul_kernel(x, y, z, output, ct.store(output, index=(bidx, 0), tile=output_tile) +def sub_mul_kernel(x, y, z, output, + TILE: ct.Constant[int], + DIM: ct.Constant[int]): + bidx = ct.bid(0) + tx = ct.load(x, index=(bidx, 0), shape=(TILE, DIM)) + ty = ct.load(y, index=(bidx, 0), shape=(TILE, DIM)) + tz = ct.load(z, index=(bidx, 0), shape=(TILE, DIM)) + output_tile = tz - tx * ty # c - a*b + ct.store(output, index=(bidx, 0), tile=output_tile) + + @ct.kernel def mul_add_same_operand_kernel(x, output, TILE: ct.Constant[int], @@ -86,6 +97,7 @@ def test_fma_skip_when_new_op_uses_deleted_var(): pytest.param(mul_add_kernel, lambda x, y, z: x * y + z), pytest.param(mul_sub_kernel, lambda x, y, z: x * y - z), pytest.param(add_mul_kernel, lambda x, y, z: z + x * y), + pytest.param(sub_mul_kernel, lambda x, y, z: z - x * y), ] ) def test_fma(kernel, kernel_ref):