Commit d063d577 authored by Megvii Engine Team

perf(functional): use fma to reduce elemwise but disable subgraph compilation

GitOrigin-RevId: c75a6e1a09b8e727a48e3b5eaabc6926aa046a46
Parent 2a063f8e
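
The rewrite below leans on two MegEngine Elemwise modes: FUSE_MUL_ADD3 computes `a * b + c` and FUSE_MUL_ADD4 computes `a * b + c * d`, each as a single fused kernel instead of separate multiply/add launches. A minimal sketch of those semantics, using NumPy as a stand-in for the fused ops:

```python
import numpy as np

def fma3(a, b, c):
    # semantics of Elemwise(mode="FUSE_MUL_ADD3"): one fused mul+add kernel
    return a * b + c

def fma4(a, b, c, d):
    # semantics of Elemwise(mode="FUSE_MUL_ADD4"): two muls and an add, fused
    return a * b + c * d

x = np.arange(4.0)
assert np.allclose(fma3(x, 2.0, 1.0), x * 2.0 + 1.0)
assert np.allclose(fma4(x, 2.0, x, 3.0), x * 2.0 + x * 3.0)
```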
@@ -242,16 +242,32 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         "-": lambda: builtin.Elemwise(mode="negate"),
     }
+    ternary_ops = {
+        "fma3": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD3"),
+    }
+    quaternary_ops = {"fma4": lambda: builtin.Elemwise(mode="FUSE_MUL_ADD4")}
 
     def decorator(func):
         builder = _SubgraphBuilder(name)
 
-        def apply_expr(op, *args):
+        def apply_expr(op, *args, nr_out=None):
             if isinstance(op, str):
                 if len(args) == 2:
                     op = binary_ops[op]()
                 elif len(args) == 1:
                     op = unary_ops[op]()
-            return builder.apply(op, args, 1)[0]
+                elif len(args) == 3:
+                    op = ternary_ops[op]()
+                elif len(args) == 4:
+                    op = quaternary_ops[op]()
+            results = builder.apply(op, args, 1 if nr_out is None else nr_out)
+            if nr_out is None:
+                assert len(results) == 1
+                return results[0]
+            else:
+                assert len(results) == nr_out
+                return results
 
         def apply_const(value, dtype=dtype, device=device):
             return builder.apply_const(value, dtype, device)
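
For illustration, a standalone model of the new `apply_expr` contract (plain-Python stand-ins, not the MegEngine `_SubgraphBuilder`): string ops are resolved by argument count, and `nr_out` switches between unwrapping a single result and returning a list of `nr_out` results. The `divmod` entry is an invented two-output op, added only to exercise the multi-output path:

```python
# Hypothetical stand-ins so the dispatch can run outside MegEngine.
unary_ops = {"-": lambda: (lambda x: -x)}
binary_ops = {
    "+": lambda: (lambda x, y: x + y),
    "divmod": lambda: (lambda x, y: divmod(x, y)),  # invented two-output op
}
ternary_ops = {"fma3": lambda: (lambda x, y, z: x * y + z)}
quaternary_ops = {"fma4": lambda: (lambda x, y, z, w: x * y + z * w)}
_BY_ARITY = {1: unary_ops, 2: binary_ops, 3: ternary_ops, 4: quaternary_ops}

def apply_expr(op, *args, nr_out=None):
    if isinstance(op, str):
        op = _BY_ARITY[len(args)][op]()  # resolve string ops by arity
    out = op(*args)
    results = list(out) if isinstance(out, tuple) else [out]
    if nr_out is None:
        assert len(results) == 1
        return results[0]        # single output: unwrap, as before
    assert len(results) == nr_out
    return results               # multi-output: return all nr_out results

assert apply_expr("fma3", 2.0, 3.0, 4.0) == 10.0        # 2*3 + 4
assert apply_expr("fma4", 2.0, 3.0, 4.0, 5.0) == 26.0   # 2*3 + 4*5
assert apply_expr("divmod", 7, 3, nr_out=2) == [2, 1]
```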
@@ -784,7 +784,7 @@ class _Hashable:
 def _get_extentedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
 ):
-    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=3)
+    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
     def extentedMatrixMulOp(inputs, f, c):
         assert len(inputs) == 2
         inp1, inp2 = inputs
@@ -884,7 +884,7 @@ def _get_extentedMatrixMulOp(
 def _get_extentedBatchedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
 ):
-    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=3)
+    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
     def extentedBatchedMatrixMulOp(inputs, f, c):
         assert len(inputs) == 2
         inp1, inp2 = inputs
@@ -1174,7 +1174,7 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
         reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
         return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)
 
-    @subgraph("SyncBnStage1", dtype, device, 7, gopt_level=3)
+    @subgraph("SyncBnStage1", dtype, device, 7)
     def syncbn_stage1(inputs, f, c):
         input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
         weight, bias = inputs[5:7]
@@ -1187,12 +1187,12 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
         inv_var_wt = f("*", invsqrt_channel_var, weight)
         neg_channel_mean = f("-", channel_mean)
         outvar =\
-            f("+", f("*", input, inv_var_wt),
+            f("fma3", input, inv_var_wt,
                 f("+", f("*", neg_channel_mean, inv_var_wt),
                     bias))
         return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)
 
-    @subgraph("SyncBnStage1Inference", dtype, device, 6, gopt_level=3)
+    @subgraph("SyncBnStage1Inference", dtype, device, 6)
     def syncbn_stage1_inference(inputs, f, c):
         input, channel_mean, channel_var, eps = inputs[0:4]
         weight, bias = inputs[4:6]
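
In `syncbn_stage1`, the outer `f("+", f("*", ...), ...)` pair becomes a single `f("fma3", ...)`, since the normalize-and-shift step has exactly the `x * s + t` shape. A NumPy check of the identity being exploited (variable names are local to this sketch):

```python
import numpy as np

def fma3(a, b, c):  # semantics of FUSE_MUL_ADD3
    return a * b + c

rng = np.random.default_rng(0)
x, mean, inv_var_wt, bias = (rng.standard_normal(8) for _ in range(4))

shift = (-mean) * inv_var_wt + bias   # neg_channel_mean * inv_var_wt + bias
old = x * inv_var_wt + shift          # previous separate "*" then "+"
new = fma3(x, inv_var_wt, shift)      # fused form used by the commit
assert np.allclose(old, new)
```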
@@ -1205,36 +1205,36 @@ def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
                     bias))
         return (outvar,), (True,)
 
-    @subgraph("SyncBnStage2", dtype, device, 7, gopt_level=3)
+    @subgraph("SyncBnStage2", dtype, device, 7)
     def syncbn_stage2(inputs, f, c):
         running_mean, running_var, momentum = inputs[0:3]
         reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
-        running_mean = f("*", running_mean, momentum)
-        running_mean =\
-            f("+", running_mean,
-                f("*", f("-", c(1), momentum),
-                    channel_mean))
+        c1_minus_momentum = f("-", c(1), momentum)
+        reduce_size_minus_c1 = f("-", reduce_size, c(1))
+        running_mean = f("fma4",
+            running_mean, momentum,
+            c1_minus_momentum, channel_mean,
+        )
         channel_variance_unbiased =\
             f("+", f("/", f("**", channel_x1s, c(2)),
                 f("*", f("-", reduce_size),
-                    f("-", reduce_size, c(1)))),
+                    reduce_size_minus_c1)),
             f("/", channel_x2s,
-                f("-", reduce_size, c(1))))
-        running_var = f("*", running_var, momentum)
-        running_var =\
-            f("+", running_var,
-                f("*", f("-", c(1), momentum),
-                    channel_variance_unbiased))
+                reduce_size_minus_c1))
+        running_var = f("fma4",
+            running_var, momentum,
+            c1_minus_momentum, channel_variance_unbiased
+        )
         return (running_mean, running_var), (True, True)
 
-    @subgraph("SyncBnConcatStats", dtype, device, 3, gopt_level=3)
+    @subgraph("SyncBnConcatStats", dtype, device, 3)
     def syncbn_concat_stats(inputs, f, c):
         reduce_size, channel_x1s, channel_x2s = inputs[0:3]
         reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
         stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
         return (stats,), (True,)
 
-    @subgraph("SyncBnSplitStats", dtype, device, 1, gopt_level=3)
+    @subgraph("SyncBnSplitStats", dtype, device, 1)
     def syncbn_split_stats(inputs, f, c):
         stats = inputs[0]
         c_1 = c(1, dtype="int32")
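
In `syncbn_stage2`, each running-statistic update is the exponential moving average `new = old * momentum + (1 - momentum) * batch_stat`, which is exactly the `a*b + c*d` shape of fma4; the commit also hoists the shared subexpressions `1 - momentum` and `reduce_size - 1` into `c1_minus_momentum` and `reduce_size_minus_c1`. A NumPy check of the equivalence (names local to this sketch):

```python
import numpy as np

def fma4(a, b, c, d):  # semantics of FUSE_MUL_ADD4
    return a * b + c * d

rng = np.random.default_rng(0)
running_mean = rng.standard_normal(4)
channel_mean = rng.standard_normal(4)
momentum = 0.9

old = running_mean * momentum + (1.0 - momentum) * channel_mean   # mul, then mul+add
new = fma4(running_mean, momentum, 1.0 - momentum, channel_mean)  # one fused kernel
assert np.allclose(old, new)
```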