From d063d5774f538a49b308997b46f3fe556bbbb4ed Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 11 Aug 2021 20:15:02 +0800 Subject: [PATCH] perf(functional): use fma to reduce elemwise but disable subgraph compilation GitOrigin-RevId: c75a6e1a09b8e727a48e3b5eaabc6926aa046a46 --- .../python/megengine/core/tensor/utils.py | 20 +++++++++-- .../python/megengine/functional/math.py | 4 +-- imperative/python/megengine/functional/nn.py | 36 +++++++++---------- 3 files changed, 38 insertions(+), 22 deletions(-) diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index b6a1f0c26..738427091 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -242,16 +242,32 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None): "-": lambda: builtin.Elemwise(mode="negate"), } + ternary_ops = { + "fma3": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD3"), + } + + quaternary_ops = {"fma4": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD4")} + def decorator(func): builder = _SubgraphBuilder(name) - def apply_expr(op, *args): + def apply_expr(op, *args, nr_out=None): if isinstance(op, str): if len(args) == 2: op = binary_ops[op]() elif len(args) == 1: op = unary_ops[op]() - return builder.apply(op, args, 1)[0] + elif len(args) == 3: + op = ternary_ops[op]() + elif len(args) == 4: + op = quaternary_ops[op]() + results = builder.apply(op, args, 1 if nr_out is None else nr_out) + if nr_out is None: + assert len(results) == 1 + return results[0] + else: + assert len(results) == nr_out + return results def apply_const(value, dtype=dtype, device=device): return builder.apply_const(value, dtype, device) diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index e8117ab13..12614e4f6 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -784,7 +784,7 @@ class _Hashable: def _get_extentedMatrixMulOp( device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy, ): - @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=3) + @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2) def extentedMatrixMulOp(inputs, f, c): assert len(inputs) == 2 inp1, inp2 = inputs @@ -884,7 +884,7 @@ def _get_extentedMatrixMulOp( def _get_extentedBatchedMatrixMulOp( device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy, ): - @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=3) + @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2) def extentedBatchedMatrixMulOp(inputs, f, c): assert len(inputs) == 2 inp1, inp2 = inputs diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 2f27f858c..47d474fe7 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1174,7 +1174,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels): reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size) return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True) - @subgraph("SyncBnStage1", dtype, device, 7, gopt_level=3) + @subgraph("SyncBnStage1", dtype, device, 7) def syncbn_stage1(inputs, f, c): input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5] weight, bias = inputs[5:7] @@ -1187,12 +1187,12 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels): inv_var_wt = f("*", invsqrt_channel_var, weight) neg_channel_mean = f("-", channel_mean) outvar =\ - f("+", f("*", input, inv_var_wt), + f("fma3", input, inv_var_wt, f("+", f("*", neg_channel_mean, inv_var_wt), bias)) return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False) - @subgraph("SyncBnStage1Inference", dtype, device, 6, gopt_level=3) + @subgraph("SyncBnStage1Inference", dtype, device, 6) def syncbn_stage1_inference(inputs, f, c): input, channel_mean, channel_var, eps = inputs[0:4] weight, bias = inputs[4:6] @@ -1205,36 +1205,36 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels): bias)) return (outvar,), (True,) - @subgraph("SyncBnStage2", dtype, device, 7, gopt_level=3) + @subgraph("SyncBnStage2", dtype, device, 7) def syncbn_stage2(inputs, f, c): running_mean, running_var, momentum = inputs[0:3] reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7] - running_mean = f("*", running_mean, momentum) - running_mean =\ - f("+", running_mean, - f("*", f("-", c(1), momentum), - channel_mean)) + c1_minus_momentum = f("-", c(1), momentum) + reduce_size_minus_c1 = f("-", reduce_size, c(1)) + running_mean = f("fma4", + running_mean, momentum, + c1_minus_momentum, channel_mean, + ) channel_variance_unbiased =\ f("+", f("/", f("**", channel_x1s, c(2)), f("*", f("-", reduce_size), - f("-", reduce_size, c(1)))), + reduce_size_minus_c1)), f("/", channel_x2s, - f("-", reduce_size, c(1)))) - running_var = f("*", running_var, momentum) - running_var =\ - f("+", running_var, - f("*", f("-", c(1), momentum), - channel_variance_unbiased)) + reduce_size_minus_c1)) + running_var = f("fma4", + running_var, momentum, + c1_minus_momentum, channel_variance_unbiased + ) return (running_mean, running_var), (True, True) - @subgraph("SyncBnConcatStats", dtype, device, 3, gopt_level=3) + @subgraph("SyncBnConcatStats", dtype, device, 3) def syncbn_concat_stats(inputs, f, c): reduce_size, channel_x1s, channel_x2s = inputs[0:3] reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32")) stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s) return (stats,), (True,) - @subgraph("SyncBnSplitStats", dtype, device, 1, gopt_level=3) + @subgraph("SyncBnSplitStats", dtype, device, 1) def syncbn_split_stats(inputs, f, c): stats = inputs[0] c_1 = c(1, dtype="int32") -- GitLab