diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py
index b6a1f0c266b46acae992e5d3c85a506bc2a85768..73842709177ef488c9cbd2bc2b89ac85e8ab1777 100644
--- a/imperative/python/megengine/core/tensor/utils.py
+++ b/imperative/python/megengine/core/tensor/utils.py
@@ -242,16 +242,32 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         "-": lambda: builtin.Elemwise(mode="negate"),
     }
 
+    ternary_ops = {
+        "fma3": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD3"),
+    }
+
+    quaternary_ops = {"fma4": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD4")}
+
     def decorator(func):
         builder = _SubgraphBuilder(name)
 
-        def apply_expr(op, *args):
+        def apply_expr(op, *args, nr_out=None):
             if isinstance(op, str):
                 if len(args) == 2:
                     op = binary_ops[op]()
                 elif len(args) == 1:
                     op = unary_ops[op]()
-            return builder.apply(op, args, 1)[0]
+                elif len(args) == 3:
+                    op = ternary_ops[op]()
+                elif len(args) == 4:
+                    op = quaternary_ops[op]()
+            results = builder.apply(op, args, 1 if nr_out is None else nr_out)
+            if nr_out is None:
+                assert len(results) == 1
+                return results[0]
+            else:
+                assert len(results) == nr_out
+                return results
 
         def apply_const(value, dtype=dtype, device=device):
             return builder.apply_const(value, dtype, device)
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index e8117ab131ff57c0fc70e375a1b9475dc69ba928..12614e4f62b6a8a310bea8613bbef45a56fa2dbd 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -784,7 +784,7 @@ class _Hashable:
 def _get_extentedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
 ):
-    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=3)
+    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
     def extentedMatrixMulOp(inputs, f, c):
         assert len(inputs) == 2
         inp1, inp2 = inputs
@@ -884,7 +884,7 @@ def _get_extentedMatrixMulOp(
 def _get_extentedBatchedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
 ):
-    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=3)
+    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
     def extentedBatchedMatrixMulOp(inputs, f, c):
         assert len(inputs) == 2
         inp1, inp2 = inputs
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 2f27f858c7ef63c3f153d3e292b9f8fc842aa800..47d474fe76c49701bd0ed04578294d2bc85d63a5 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -1174,7 +1174,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
         reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
         return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)
 
-    @subgraph("SyncBnStage1", dtype, device, 7, gopt_level=3)
+    @subgraph("SyncBnStage1", dtype, device, 7)
     def syncbn_stage1(inputs, f, c):
         input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
         weight, bias = inputs[5:7]
@@ -1187,12 +1187,12 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
         inv_var_wt = f("*", invsqrt_channel_var, weight)
         neg_channel_mean = f("-", channel_mean)
         outvar =\
-            f("+",  f("*", input, inv_var_wt),
+            f("fma3", input, inv_var_wt,
                     f("+",  f("*", neg_channel_mean, inv_var_wt),
                             bias))
         return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)
 
-    @subgraph("SyncBnStage1Inference", dtype, device, 6, gopt_level=3)
+    @subgraph("SyncBnStage1Inference", dtype, device, 6)
     def syncbn_stage1_inference(inputs, f, c):
         input, channel_mean, channel_var, eps = inputs[0:4]
         weight, bias = inputs[4:6]
@@ -1205,36 +1205,36 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
                             bias))
         return (outvar,), (True,)
 
-    @subgraph("SyncBnStage2", dtype, device, 7, gopt_level=3)
+    @subgraph("SyncBnStage2", dtype, device, 7)
     def syncbn_stage2(inputs, f, c):
         running_mean, running_var, momentum = inputs[0:3]
         reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
-        running_mean = f("*", running_mean, momentum)
-        running_mean =\
-            f("+",  running_mean,
-                    f("*",  f("-", c(1), momentum),
-                            channel_mean))
+        c1_minus_momentum = f("-", c(1), momentum)
+        reduce_size_minus_c1 = f("-", reduce_size, c(1))
+        running_mean = f("fma4",
+            running_mean, momentum,
+            c1_minus_momentum, channel_mean,
+        )
         channel_variance_unbiased =\
             f("+",  f("/",  f("**", channel_x1s, c(2)),
                             f("*",  f("-", reduce_size),
-                                    f("-", reduce_size, c(1)))),
+                                    reduce_size_minus_c1)),
                     f("/",  channel_x2s,
-                            f("-", reduce_size, c(1))))
+                            reduce_size_minus_c1))
-        running_var = f("*", running_var, momentum)
-        running_var =\
-            f("+",  running_var,
-                    f("*",  f("-", c(1), momentum),
-                            channel_variance_unbiased))
+        running_var = f("fma4",
+            running_var, momentum,
+            c1_minus_momentum, channel_variance_unbiased
+        )
         return (running_mean, running_var), (True, True)
 
-    @subgraph("SyncBnConcatStats", dtype, device, 3, gopt_level=3)
+    @subgraph("SyncBnConcatStats", dtype, device, 3)
     def syncbn_concat_stats(inputs, f, c):
         reduce_size, channel_x1s, channel_x2s = inputs[0:3]
         reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
         stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
         return (stats,), (True,)
 
-    @subgraph("SyncBnSplitStats", dtype, device, 1, gopt_level=3)
+    @subgraph("SyncBnSplitStats", dtype, device, 1)
     def syncbn_split_stats(inputs, f, c):
         stats = inputs[0]
         c_1 = c(1, dtype="int32")