Unverified commit 45acb717, authored by Jiabin Yang, committed by GitHub

【Prim】Optimize composite rule by making scalar shape as 1 (#51960)

* optimize composite rule by making scalar shape as [1]

* fix shape usage for 0D

* fix rules

* fix 0D error

* fix flatten 0D error

* fix bn eval mode

* fix bn test

* fix flatten
Parent 864b50c3
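Most hunks below apply one recurring pattern: constants that previously took `x.shape` now take `x.shape if len(x.shape) == 0 else [1]`. A minimal sketch of the idea (the helper name `scalar_like` is illustrative, not from the diff; assumes a Paddle version with 0D tensor support):

```python
import paddle

def scalar_like(x, value):
    # A 0D tensor has shape [], so len(x.shape) == 0: keep the constant
    # 0D so the result stays 0D; otherwise shape [1] broadcasts against
    # any non-scalar x without materializing a full-size constant.
    full_shape = x.shape if len(x.shape) == 0 else [1]
    return paddle.full(full_shape, value, x.dtype)

x0d = paddle.to_tensor(2.0)           # shape []
x2d = paddle.ones([3, 4])             # shape [3, 4]
print(scalar_like(x0d, 0.5).shape)    # []
print(scalar_like(x2d, 0.5).shape)    # [1]
```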
@@ -158,6 +158,15 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
'x4', shape=weight.shape, dtype=str(weight.dtype)
)
x5 = paddle.static.data('x5', shape=bias.shape, dtype=str(bias.dtype))
if attrs.use_global_stats is None:
attrs.use_global_stats = not attrs.training
trainable_statistics = False
else:
trainable_statistics = not attrs.use_global_stats
use_run_stat = (
(not attrs.training) and (not trainable_statistics)
) or attrs.use_global_stats
y = fn(
x1,
x2,
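For reference, the `use_run_stat` resolution added above, extracted as a small pure-Python function with a few spot checks (the function name is illustrative):

```python
def resolve_use_run_stat(training, use_global_stats):
    # Mirrors the logic above: default use_global_stats to eval mode,
    # then use running statistics in eval mode or whenever global
    # statistics are explicitly requested.
    if use_global_stats is None:
        use_global_stats = not training
        trainable_statistics = False
    else:
        trainable_statistics = not use_global_stats
    return ((not training) and (not trainable_statistics)) or use_global_stats

assert resolve_use_run_stat(training=False, use_global_stats=None)     # eval
assert not resolve_use_run_stat(training=True, use_global_stats=None)  # train
assert resolve_use_run_stat(training=True, use_global_stats=True)      # forced
```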
@@ -177,16 +186,27 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
blocks[0].ops[0].output_names, blocks[0].ops[0].output_arg_names
)
)
vars_list = [
    names[key]
    for key in [
        "Y",
        "MeanOut",
        "VarianceOut",
        "SavedMean",
        "SavedVariance",
    ]
]
if not use_run_stat:
vars_list = [
names[key]
for key in [
"Y",
"MeanOut",
"VarianceOut",
"SavedMean",
"SavedVariance",
]
]
else:
vars_list = [
names[key]
for key in [
"Y",
"MeanOut",
"VarianceOut",
]
]
fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that batch_norm is in the original block
@@ -202,21 +222,36 @@ def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
exe.run(startup_program)
# indeed SavedVariance is 1/sqrt(batch_var+eps)
Y, MeanOut, VarianceOut, SavedMean, SavedVariance = exe.run(
main_program,
feed={
'x1': inputs,
'x2': running_mean,
'x3': running_variance,
'x4': weight,
'x5': bias,
},
fetch_list=vars_list,
)
if not use_run_stat:
Y, MeanOut, VarianceOut, SavedMean, SavedVariance = exe.run(
main_program,
feed={
'x1': inputs,
'x2': running_mean,
'x3': running_variance,
'x4': weight,
'x5': bias,
},
fetch_list=vars_list,
)
else:
Y, MeanOut, VarianceOut = exe.run(
main_program,
feed={
'x1': inputs,
'x2': running_mean,
'x3': running_variance,
'x4': weight,
'x5': bias,
},
fetch_list=vars_list,
)
paddle.disable_static()
core._set_prim_all_enabled(False)
return Y, MeanOut, VarianceOut, SavedMean, SavedVariance
if not use_run_stat:
return Y, MeanOut, VarianceOut, SavedMean, SavedVariance
else:
return Y, MeanOut, VarianceOut
class TestCompositeBatchNorm(unittest.TestCase):
@@ -134,8 +134,10 @@ def composite_batchnorm(
# reserve_space is not needed in the composite rule, but we still return None to keep the signature the same as the phi op definition.
reserve_space = None
return y, run_mean_, run_var_, batch_mean_, inv_std_, reserve_space
if not use_run_stat:
return y, run_mean_, run_var_, batch_mean_, inv_std_, reserve_space
else:
return y, run_mean_, run_var_, None, None, reserve_space
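When `use_run_stat` holds, normalization uses the running statistics and no batch statistics exist to save, hence the two `None` slots. A minimal NumPy sketch of that eval-mode path (NCHW layout and the eps value are assumptions):

```python
import numpy as np

def bn_eval(x, run_mean, run_var, weight, bias, eps=1e-5):
    # Broadcast per-channel statistics over an NCHW input.
    c = (1, -1, 1, 1)
    x_hat = (x - run_mean.reshape(c)) / np.sqrt(run_var.reshape(c) + eps)
    return x_hat * weight.reshape(c) + bias.reshape(c)

x = np.random.rand(2, 3, 4, 4).astype(np.float32)
y = bn_eval(x, np.zeros(3), np.ones(3), np.ones(3), np.zeros(3))
print(y.shape)  # (2, 3, 4, 4)
```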
@REGISTER_COMPOSITE('layer_norm')
@@ -183,12 +185,13 @@ def gelu_composite(x, approximate):
0.70710678118654752440 # /* 1/sqrt(2) */ copy from gelu-kernel.cc
)
M_2_SQRTPI = 1.12837916709551257390 # /* 2/sqrt(pi) */
one = ones(x.shape, x.dtype)
half = full(x.shape, 0.5, x.dtype)
full_shape = x.shape if len(x.shape) == 0 else [1]
one = ones(full_shape, x.dtype)
half = full(full_shape, 0.5, x.dtype)
if approximate:
# gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
kAlpha = full(x.shape, M_2_SQRTPI * M_SQRT1_2, x.dtype)
GELU_CONSTANT = full(x.shape, 0.044715, x.dtype)
kAlpha = full(full_shape, M_2_SQRTPI * M_SQRT1_2, x.dtype)
GELU_CONSTANT = full(full_shape, 0.044715, x.dtype)
tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x))
out = x * half * (one + tanh_out)
return out
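As a sanity check, the tanh approximation spelled out in the comment above, written in NumPy and evaluated at the 0D case this hunk fixes:

```python
import numpy as np

def gelu_tanh(x):
    # gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    k_alpha = np.sqrt(2.0 / np.pi)  # == M_2_SQRTPI * M_SQRT1_2
    return 0.5 * x * (1.0 + np.tanh(k_alpha * (x + 0.044715 * x**3)))

print(gelu_tanh(np.float32(1.0)))  # ≈ 0.8412, matches erf-based gelu to ~1e-3
```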
@@ -210,7 +213,7 @@ def mean_composite(x, axis, keepdim):
operator.mul, [x.shape[axis] for axis in axes]
)
norm = fill_constant(
shape=sum_x.shape,
shape=x.shape if len(x.shape) == 0 else [1],
value=value_to_fill,
dtype=sum_x.dtype,
)
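`mean_composite` divides the sum by the element count of the reduced axes; the `value_to_fill` computation above, spot-checked in plain Python:

```python
import functools
import operator

shape, axes = [2, 3, 4], [1, 2]
# mean(x, axes) == sum(x, axes) / prod(shape[axis] for axis in axes)
value_to_fill = functools.reduce(operator.mul, [shape[axis] for axis in axes])
print(value_to_fill)  # 12 == 3 * 4
```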
@@ -316,7 +319,9 @@ def flatten_contiguous_range_composite(x, start_axis, stop_axis):
start_dim = start_axis if len(shape_in) != 0 else 0
end_dim = stop_axis if len(shape_in) != 0 else 0
assert start_dim <= end_dim
if len(shape_in) == 0 or start_dim == end_dim:
if len(shape_in) == 0:
return reshape(x, shape=[1]), None
if start_dim == end_dim:
return reshape(x, shape=shape_in), None
slice_numel = 1
for i in range(start_dim, end_dim + 1):
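The hunk above splits the 0D case out so it reshapes to `[1]` rather than back to the input shape. The expected shapes, assuming Paddle's eager-mode `flatten` follows the same 0D convention:

```python
import paddle

print(paddle.flatten(paddle.to_tensor(3.14)).shape)        # [1]: 0D flattens to 1D
print(paddle.flatten(paddle.ones([2, 3, 4]), 1, 2).shape)  # [2, 12]
```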
@@ -372,7 +377,7 @@ def bernoulli(shape, dtype, p, seed=0):
return cast(
greater_equal(
uniform(shape, new_dtype, min=0.0, max=1.0, seed=seed),
fill_constant(shape, new_dtype, p),
fill_constant(shape if len(shape) == 0 else [1], new_dtype, p),
),
dtype,
)
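The helper samples a Bernoulli mask by thresholding a uniform draw; the same technique in NumPy (note the mask is 1 with probability 1 - p):

```python
import numpy as np

def bernoulli_mask(shape, p, seed=0):
    # u >= p holds with probability 1 - p for u ~ U[0, 1).
    u = np.random.default_rng(seed).uniform(0.0, 1.0, size=shape)
    return (u >= p).astype(np.float32)

print(bernoulli_mask((100000,), p=0.3).mean())  # ≈ 0.7
```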
@@ -389,16 +394,17 @@ def hard_swish_composite(x):
offset = 3.0
threshold = 6.0
scale = 6.0
full_shape = x.shape if len(x.shape) == 0 else [1]
res = (
minimum(
maximum(
x + full(x.shape, offset, dtype=x.dtype),
full(x.shape, 0.0, dtype=x.dtype),
x + full(full_shape, offset, dtype=x.dtype),
full(full_shape, 0.0, dtype=x.dtype),
),
full(x.shape, threshold, dtype=x.dtype),
full(full_shape, threshold, dtype=x.dtype),
)
* x
/ full(x.shape, scale, dtype=x.dtype)
/ full(full_shape, scale, dtype=x.dtype)
)
return res
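The expression above computes hard_swish(x) = x * min(max(x + 3, 0), 6) / 6; a quick NumPy check around the kink points:

```python
import numpy as np

def hard_swish(x):
    return x * np.minimum(np.maximum(x + 3.0, 0.0), 6.0) / 6.0

print(hard_swish(np.array([-4.0, 0.0, 1.0, 4.0])))  # [-0., 0., 0.6667, 4.]
```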
@@ -499,7 +505,7 @@ def sqrt_composite(x):
define composite rule of op sqrt
res = pow(x, 0.5)
"""
y = full(x.shape, 0.5, x.dtype)
y = full(x.shape if len(x.shape) == 0 else [1], 0.5, x.dtype)
res = pow(x, y)
return res
@@ -511,7 +517,7 @@ def pow_composite(x, y):
res = x^y
"""
if isinstance(y, (int, float)):
y = full([1], y, x.dtype)
y = full(x.shape if len(x.shape) == 0 else [1], y, x.dtype)
res = pow(x, y)
return res
@@ -520,7 +526,10 @@ def pow_composite(x, y):
def relu_composite(x):
"""define composite rule of op relu."""
# relu(x) = max(x, 0)
return maximum(x, zeros_like(x))
if len(x.shape) == 0:
return maximum(x, full(x.shape, 0.0, x.dtype))
else:
return maximum(x, full([1], 0.0, x.dtype))
@REGISTER_COMPOSITE('unsqueeze2')
@@ -547,5 +556,5 @@ def unsqueeze_composite(x, axis):
def rsqrt_composite(x):
"""define composite rule of op rsqrt."""
# rsqrt(x) = x^(-0.5)
y = full(x.shape, -0.5, x.dtype)
y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
return pow(x, y)
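`sqrt_composite`, `pow_composite`, and `rsqrt_composite` all reduce to a single `pow` with a scalar exponent shaped by the same 0D rule; the underlying identities, checked in NumPy on a 0D input:

```python
import numpy as np

x = np.float64(4.0)  # 0D, the case these rules now handle
assert np.isclose(np.power(x, 0.5), np.sqrt(x))          # sqrt_composite
assert np.isclose(np.power(x, -0.5), 1.0 / np.sqrt(x))   # rsqrt_composite
print("ok")
```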
@@ -633,14 +633,13 @@ def _lower_composite(
f'when replace origin op {op_name} with composite rule, origin out dtype should be equal to new out dtype, '
f'but orig_out: {orig_out.name}.dtype={orig_out.dtype} and new_out: {new_out.name}.dtype={new_out.dtype}'
)
if orig_out.shape and new_out.shape:
assert (
-1 not in new_out.shape
), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
assert orig_out.shape == new_out.shape, (
f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
f'but orig_out: {orig_out.name}.shape={orig_out.shape} and new_out: {new_out.name}.shape={new_out.shape}'
)
assert (
-1 not in new_out.shape
), f'when replace origin op {op_name} with composite rule, composite out shape has -1.'
assert orig_out.shape == new_out.shape, (
f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
f'but orig_out: {orig_out.name}.shape={orig_out.shape} and new_out: {new_out.name}.shape={new_out.shape}'
)
assert not (orig_out is None) ^ (
new_out is None
), "orig_out and new_out should match."