未验证 提交 45acb717 编写于 作者: J Jiabin Yang 提交者: GitHub

【Prim】Optimize composite rule by making scalar shape as 1 (#51960)

* optimize composite rule by making scalar shape as [1]

* fix shape usage for 0D

* fix rules

* fix 0D error

* fix flatten 0D error

* fix bn eval mode

* fix bn test

* fix flatten
上级 864b50c3
...@@ -158,6 +158,15 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None): ...@@ -158,6 +158,15 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
'x4', shape=weight.shape, dtype=str(weight.dtype) 'x4', shape=weight.shape, dtype=str(weight.dtype)
) )
x5 = paddle.static.data('x5', shape=bias.shape, dtype=str(bias.dtype)) x5 = paddle.static.data('x5', shape=bias.shape, dtype=str(bias.dtype))
if attrs.use_global_stats is None:
attrs.use_global_stats = not attrs.training
trainable_statistics = False
else:
trainable_statistics = not attrs.use_global_stats
use_run_stat = (
(not attrs.training) and (not trainable_statistics)
) or attrs.use_global_stats
y = fn( y = fn(
x1, x1,
x2, x2,
...@@ -177,16 +186,27 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None): ...@@ -177,16 +186,27 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
blocks[0].ops[0].output_names, blocks[0].ops[0].output_arg_names blocks[0].ops[0].output_names, blocks[0].ops[0].output_arg_names
) )
) )
vars_list = [
names[key] if not use_run_stat:
for key in [ vars_list = [
"Y", names[key]
"MeanOut", for key in [
"VarianceOut", "Y",
"SavedMean", "MeanOut",
"SavedVariance", "VarianceOut",
"SavedMean",
"SavedVariance",
]
]
else:
vars_list = [
names[key]
for key in [
"Y",
"MeanOut",
"VarianceOut",
]
] ]
]
fwd_ops = [op.type for op in blocks[0].ops] fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that batch_norm in original block # Ensure that batch_norm in original block
...@@ -202,21 +222,36 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None): ...@@ -202,21 +222,36 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
exe.run(startup_program) exe.run(startup_program)
# indeed SavedVariance is 1/sqrt(batch_var+eps) # indeed SavedVariance is 1/sqrt(batch_var+eps)
Y, MeanOut, VarianceOut, SavedMean, SavedVariance = exe.run( if not use_run_stat:
main_program, Y, MeanOut, VarianceOut, SavedMean, SavedVariance = exe.run(
feed={ main_program,
'x1': inputs, feed={
'x2': running_mean, 'x1': inputs,
'x3': running_variance, 'x2': running_mean,
'x4': weight, 'x3': running_variance,
'x5': bias, 'x4': weight,
}, 'x5': bias,
fetch_list=vars_list, },
) fetch_list=vars_list,
)
else:
Y, MeanOut, VarianceOut = exe.run(
main_program,
feed={
'x1': inputs,
'x2': running_mean,
'x3': running_variance,
'x4': weight,
'x5': bias,
},
fetch_list=vars_list,
)
paddle.disable_static() paddle.disable_static()
core._set_prim_all_enabled(False) core._set_prim_all_enabled(False)
if not use_run_stat:
return Y, MeanOut, VarianceOut, SavedMean, SavedVariance return Y, MeanOut, VarianceOut, SavedMean, SavedVariance
else:
return Y, MeanOut, VarianceOut
class TestCompositeBatchNorm(unittest.TestCase): class TestCompositeBatchNorm(unittest.TestCase):
......
...@@ -134,8 +134,10 @@ def composite_batchnorm( ...@@ -134,8 +134,10 @@ def composite_batchnorm(
# reserve_space is not needed in composite rule, but still return None to keep the same as phi op definition. # reserve_space is not needed in composite rule, but still return None to keep the same as phi op definition.
reserve_space = None reserve_space = None
if not use_run_stat:
return y, run_mean_, run_var_, batch_mean_, inv_std_, reserve_space return y, run_mean_, run_var_, batch_mean_, inv_std_, reserve_space
else:
return y, run_mean_, run_var_, None, None, reserve_space
@REGISTER_COMPOSITE('layer_norm') @REGISTER_COMPOSITE('layer_norm')
...@@ -183,12 +185,13 @@ def gelu_composite(x, approximate): ...@@ -183,12 +185,13 @@ def gelu_composite(x, approximate):
0.70710678118654752440 # /* 1/sqrt(2) */ copy from gelu-kernel.cc 0.70710678118654752440 # /* 1/sqrt(2) */ copy from gelu-kernel.cc
) )
M_2_SQRTPI = 1.12837916709551257390 # /* 2/sqrt(pi) */ M_2_SQRTPI = 1.12837916709551257390 # /* 2/sqrt(pi) */
one = ones(x.shape, x.dtype) full_shape = x.shape if len(x.shape) == 0 else [1]
half = full(x.shape, 0.5, x.dtype) one = ones(full_shape, x.dtype)
half = full(full_shape, 0.5, x.dtype)
if approximate: if approximate:
# gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3}))) # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
kAlpha = full(x.shape, M_2_SQRTPI * M_SQRT1_2, x.dtype) kAlpha = full(full_shape, M_2_SQRTPI * M_SQRT1_2, x.dtype)
GELU_CONSTANT = full(x.shape, 0.044715, x.dtype) GELU_CONSTANT = full(full_shape, 0.044715, x.dtype)
tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x)) tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x))
out = x * half * (one + tanh_out) out = x * half * (one + tanh_out)
return out return out
...@@ -210,7 +213,7 @@ def mean_composite(x, axis, keepdim): ...@@ -210,7 +213,7 @@ def mean_composite(x, axis, keepdim):
operator.mul, [x.shape[axis] for axis in axes] operator.mul, [x.shape[axis] for axis in axes]
) )
norm = fill_constant( norm = fill_constant(
shape=sum_x.shape, shape=x.shape if len(x.shape) == 0 else [1],
value=value_to_fill, value=value_to_fill,
dtype=sum_x.dtype, dtype=sum_x.dtype,
) )
...@@ -316,7 +319,9 @@ def flatten_contiguous_range_composite(x, start_axis, stop_axis): ...@@ -316,7 +319,9 @@ def flatten_contiguous_range_composite(x, start_axis, stop_axis):
start_dim = start_axis if len(shape_in) != 0 else 0 start_dim = start_axis if len(shape_in) != 0 else 0
end_dim = stop_axis if len(shape_in) != 0 else 0 end_dim = stop_axis if len(shape_in) != 0 else 0
assert start_dim <= end_dim assert start_dim <= end_dim
if len(shape_in) == 0 or start_dim == end_dim: if len(shape_in) == 0:
return reshape(x, shape=[1]), None
if start_dim == end_dim:
return reshape(x, shape=shape_in), None return reshape(x, shape=shape_in), None
slice_numel = 1 slice_numel = 1
for i in range(start_dim, end_dim + 1): for i in range(start_dim, end_dim + 1):
...@@ -372,7 +377,7 @@ def bernoulli(shape, dtype, p, seed=0): ...@@ -372,7 +377,7 @@ def bernoulli(shape, dtype, p, seed=0):
return cast( return cast(
greater_equal( greater_equal(
uniform(shape, new_dtype, min=0.0, max=1.0, seed=seed), uniform(shape, new_dtype, min=0.0, max=1.0, seed=seed),
fill_constant(shape, new_dtype, p), fill_constant(shape if len(shape) == 0 else [1], new_dtype, p),
), ),
dtype, dtype,
) )
...@@ -389,16 +394,17 @@ def hard_swish_composite(x): ...@@ -389,16 +394,17 @@ def hard_swish_composite(x):
offset = 3.0 offset = 3.0
threshold = 6.0 threshold = 6.0
scale = 6.0 scale = 6.0
full_shape = x.shape if len(x.shape) == 0 else [1]
res = ( res = (
minimum( minimum(
maximum( maximum(
x + full(x.shape, offset, dtype=x.dtype), x + full(full_shape, offset, dtype=x.dtype),
full(x.shape, 0.0, dtype=x.dtype), full(full_shape, 0.0, dtype=x.dtype),
), ),
full(x.shape, threshold, dtype=x.dtype), full(full_shape, threshold, dtype=x.dtype),
) )
* x * x
/ full(x.shape, scale, dtype=x.dtype) / full(full_shape, scale, dtype=x.dtype)
) )
return res return res
...@@ -499,7 +505,7 @@ def sqrt_composite(x): ...@@ -499,7 +505,7 @@ def sqrt_composite(x):
define composite rule of op sqrt define composite rule of op sqrt
res = pow(x, 0.5) res = pow(x, 0.5)
""" """
y = full(x.shape, 0.5, x.dtype) y = full(x.shape if len(x.shape) == 0 else [1], 0.5, x.dtype)
res = pow(x, y) res = pow(x, y)
return res return res
...@@ -511,7 +517,7 @@ def pow_composite(x, y): ...@@ -511,7 +517,7 @@ def pow_composite(x, y):
res = x^y res = x^y
""" """
if isinstance(y, (int, float)): if isinstance(y, (int, float)):
y = full([1], y, x.dtype) y = full(x.shape if len(x.shape) == 0 else [1], y, x.dtype)
res = pow(x, y) res = pow(x, y)
return res return res
...@@ -520,7 +526,10 @@ def pow_composite(x, y): ...@@ -520,7 +526,10 @@ def pow_composite(x, y):
def relu_composite(x): def relu_composite(x):
"""define composite rule of op relu.""" """define composite rule of op relu."""
# relu(x) = max(x, 0) # relu(x) = max(x, 0)
return maximum(x, zeros_like(x)) if len(x.shape) == 0:
return maximum(x, full(x.shape, 0.0, x.dtype))
else:
return maximum(x, full([1], 0.0, x.dtype))
@REGISTER_COMPOSITE('unsqueeze2') @REGISTER_COMPOSITE('unsqueeze2')
...@@ -547,5 +556,5 @@ def unsqueeze_composite(x, axis): ...@@ -547,5 +556,5 @@ def unsqueeze_composite(x, axis):
def rsqrt_composite(x): def rsqrt_composite(x):
"""define composite rule of op rsqrt.""" """define composite rule of op rsqrt."""
# rsqrt(x) = x^(-0.5) # rsqrt(x) = x^(-0.5)
y = full(x.shape, -0.5, x.dtype) y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
return pow(x, y) return pow(x, y)
...@@ -633,14 +633,13 @@ def _lower_composite( ...@@ -633,14 +633,13 @@ def _lower_composite(
f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, ' f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, '
f'but orig_out: {orig_out.name}.dtype={orig_out.dtype} and new_out: {new_out.name}.dtype={new_out.dtype}' f'but orig_out: {orig_out.name}.dtype={orig_out.dtype} and new_out: {new_out.name}.dtype={new_out.dtype}'
) )
if orig_out.shape and new_out.shape: assert (
assert ( -1 not in new_out.shape
-1 not in new_out.shape ), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
), f'when replace origin op {op_name} with composite rule, composite out shape has -1.' assert orig_out.shape == new_out.shape, (
assert orig_out.shape == new_out.shape, ( f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, ' f'but orig_out: {orig_out.name}.shape={orig_out.shape} and new_out: {new_out.name}.shape={new_out.shape}'
f'but orig_out: {orig_out.name}.shape={orig_out.shape} and new_out: {new_out.name}.shape={new_out.shape}' )
)
assert not (orig_out is None) ^ ( assert not (orig_out is None) ^ (
new_out is None new_out is None
), "orig_out and new_out should match." ), "orig_out and new_out should match."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册