Unverified commit b0ec8efb authored by Xiaoxu Chen, committed by GitHub

add batch_norm orig2prim transform rule (#46446)

* Support both use_calc_stream and sync_op in send recv APIs (#46023)

* add batch_norm orig2prim rule
Co-authored-by: Wen Sun <35923278+HermitSun@users.noreply.github.com>
Parent 23c50648
...@@ -900,5 +900,81 @@ class TestRSqrtOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestBatchnormOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'batch_norm'
x = paddle.static.data(name='X', shape=[5, 8], dtype='float')
m = paddle.static.data(name='Mean', shape=[8], dtype='float')
v = paddle.static.data(name='Variance', shape=[8], dtype='float')
w = paddle.static.data(name='Scale', shape=[8], dtype='float')
b = paddle.static.data(name='Bias', shape=[8], dtype='float')
self.input = {
"X": [x],
"Scale": [w],
"Bias": [b],
"Mean": [m],
"Variance": [v]
}
saved_variance = self.layer_help.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
batch_norm_out = self.layer_help.create_variable_for_type_inference(
x.dtype)
saved_mean = self.layer_help.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
self.output = {
"Y": [batch_norm_out],
"MeanOut": [m],
"VarianceOut": [v],
"SavedMean": [saved_mean],
"SavedVariance": [saved_variance]
}
self.attrs = {
"momentum": 0.9,
"epsilon": 1e-5,
"is_test": False,
"data_layout": 'NCHW',
"use_mkldnn": False,
"fuse_with_relu": False,
"use_global_stats": False,
"trainable_statistics": False,
}
self.orig2prim_args = (b, m, None, w, v, x)
self.all_ops = [
'add_p', 'add_p', 'add_p', 'add_p', 'batch_norm', 'broadcast_p',
'broadcast_p', 'broadcast_p', 'broadcast_p', 'broadcast_p', 'div_p',
'div_p', 'div_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'fill_constant_p', 'fill_constant_p',
'fill_constant_p', 'mul_p', 'mul_p', 'mul_p', 'mul_p', 'mul_p',
'pow_p', 'reduce_sum_p', 'reduce_sum_p', 'reshape_p', 'reshape_p',
'reshape_p', 'reshape_p', 'sqrt_p', 'sub_p', 'sub_p', 'sub_p',
'sub_p'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {}
class TestFillConstantOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'fill_constant'
self.attrs = {'value': 1., 'shape': (2, 3), 'dtype': paddle.float32}
self.input = {}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(
dtype=paddle.float32)
}
self.orig2prim_args = (None, None, None)
self.all_ops = ['fill_constant', 'fill_constant_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
if __name__ == '__main__':
unittest.main()
...@@ -318,7 +318,39 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
lambda x: paddle.var(x, axis=1, unbiased=False),
(np.random.rand(10, 20, 30), ), None, 'float32'),
('var_with_keepdim', lambda x: paddle.var(x, axis=1, keepdim=True),
(np.random.rand(10, 20, 30), ), None, 'float32'),
('bn', lambda x, w, b: paddle.nn.functional.batch_norm(
x, paddle.ones((10, )), paddle.ones(
(10, )), w, b), (np.random.rand(10, 10), np.random.rand(10),
np.random.rand(10)), None, 'float32'),
('bn_train', lambda x, w, b: paddle.nn.functional.batch_norm(
x, paddle.ones((10, )), paddle.ones((10, )), w, b, training=True),
(np.random.rand(
10, 10), np.random.rand(10), np.random.rand(10)), None, 'float32'),
('bn_nhwc', lambda x, w, b: paddle.nn.functional.batch_norm(
x,
paddle.ones((10, )) + 1,
paddle.ones((10, )),
w,
b,
training=True,
data_format='NHWC',
), (np.random.rand(
10, 10), np.random.rand(10), np.random.rand(10)), None, 'float32'),
('bn_global_stat',
lambda x, w, b: paddle.nn.functional.batch_norm(x,
paddle.ones(
(10, )) + 3.2,
paddle.ones(
(10, )) + 6.7,
w,
b,
training=True,
data_format='NHWC',
use_global_stats=True),
(np.random.rand(
10, 10), np.random.rand(10), np.random.rand(10)), None, 'float32'),
))
class TestGrad(unittest.TestCase):
def setUp(self):
...
...@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator
import paddle
from paddle.fluid.layer_helper import LayerHelper
...@@ -92,6 +95,97 @@ def set_value(x, y, axis, starts, ends, strides, out):
return out
def mean(x, axis=None, keepdim=False):
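# Arithmetic mean over the given axes (all axes when axis is None): reduce_sum divided by the count of reduced elements.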
axes = axis or tuple(range(0, len(x.shape)))
sum = reduce_sum(x, axis=axes, keepdim=keepdim)
norm = fill_const(shape=sum.shape,
value=functools.reduce(operator.mul,
[x.shape[axis] for axis in axes]),
dtype=sum.dtype)
return div(sum, norm)
def ones(shape, dtype):
return fill_const(1, shape, dtype)
def zeros(shape, dtype):
return fill_const(0, shape, dtype)
def batch_norm(x,
axis,
gamma,
beta,
run_mean,
run_var,
eps=1e-5,
momentum=0.9,
use_run_stat=False,
reserve_space=None):
"""batch normalizer.
Args:
x (Tensor): A tensor to be normalized.
axis (int): The features axis.
gamma (Tensor): The scale factor.
beta (Tensor): The shift factor.
run_mean (Tensor): Running mean.
run_var (Tensor): Running variance.
eps (float, optional): A value added to the denominator for numerical
stability. Defaults to 1e-5.
momentum (float, optional): The value used for the running_mean and
running_var computation. Can be set to None for cumulative moving
average (i.e. simple average). Defaults to 0.9.
use_run_stat (bool, optional): Whether or not to use running statistics.
Defaults to False.
"""
reduce_axes = tuple(i for i in range(len(x.shape)) if i != axis)
stats_shape = tuple(1 if i in reduce_axes else s
for i, s in enumerate(x.shape))
batch_mean = zeros(run_mean.shape, run_mean.dtype)
batch_var = zeros(run_var.shape, run_var.dtype)
if not use_run_stat:
batch_mean = mean(x, reduce_axes, keepdim=True)
batch_var = mean(square(sub(x, broadcast(batch_mean, x.shape))),
reduce_axes,
keepdim=True)
x_hat = div(
sub(x, broadcast(batch_mean, x.shape)),
sqrt(
add(broadcast(batch_var, x.shape),
fill_const(eps, x.shape, batch_var.dtype))))
momentum = fill_const(momentum, run_mean.shape, run_mean.dtype)
run_mean = add(
mul(momentum, run_mean),
mul(sub(ones(run_mean.shape, run_mean.dtype), momentum),
reshape(batch_mean, run_mean.shape)))
run_var = add(
mul(momentum, run_var),
mul(sub(ones(run_var.shape, run_var.dtype), momentum),
reshape(batch_var, run_var.shape)))
else:
x_hat = div(
sub(x, broadcast(reshape(run_mean, stats_shape), x.shape)),
sqrt(
add(broadcast(reshape(run_var, stats_shape), x.shape),
fill_const(eps, x.shape, x.dtype))))
y = add(mul(broadcast(reshape(gamma, stats_shape), x_hat.shape), x_hat),
broadcast(reshape(beta, stats_shape), x_hat.shape))
if reserve_space:
return run_mean, reserve_space, batch_mean, batch_var, run_var, y
else:
return run_mean, batch_mean, batch_var, run_var, y
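# NOTE (reviewer sketch): a minimal NumPy reference of the arithmetic that the primitive
# composition in batch_norm() above expresses, covering only the training-mode path
# (use_run_stat=False, feature axis 1). The name `_np_batch_norm_ref` and the sample
# shapes below are illustrative and are not part of this change.
def _np_batch_norm_ref(x, gamma, beta, run_mean, run_var, eps=1e-5, momentum=0.9):
    import numpy as np
    # Reduce over every axis except the feature axis (axis 1), keeping dims for broadcasting.
    reduce_axes = tuple(i for i in range(x.ndim) if i != 1)
    batch_mean = x.mean(axis=reduce_axes, keepdims=True)
    batch_var = ((x - batch_mean) ** 2).mean(axis=reduce_axes, keepdims=True)
    # Normalize, then scale and shift with statistics reshaped to a broadcastable shape.
    x_hat = (x - batch_mean) / np.sqrt(batch_var + eps)
    stats_shape = batch_mean.shape
    y = gamma.reshape(stats_shape) * x_hat + beta.reshape(stats_shape)
    # Update running statistics the same way the primitive graph does.
    new_run_mean = momentum * run_mean + (1 - momentum) * batch_mean.reshape(run_mean.shape)
    new_run_var = momentum * run_var + (1 - momentum) * batch_var.reshape(run_var.shape)
    return y, new_run_mean, new_run_var
# Example call with hypothetical shapes:
#   y, rm, rv = _np_batch_norm_ref(np.random.rand(4, 8, 6, 6).astype('float32'),
#                                  np.ones(8, 'float32'), np.zeros(8, 'float32'),
#                                  np.zeros(8, 'float32'), np.ones(8, 'float32'))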
def square(x):
return pow(x, fill_const(2., x.shape, x.dtype))
@REGISTER_FN('add_p', 'X', 'Y', 'Z')
def add(x, y, out=None):
return _simple_binop(LayerHelper('add_p', **locals()))
...
...@@ -215,6 +215,20 @@ def fill_any_like_orig2prim(op, x):
convert_dtype(INT_DTYPE_2_STRING[op.attr('dtype')])))
@REGISTER_ORIG2PRIM('fill_constant')
def fill_const_orig2prim(op,
shape_tensor=None,
shape_tensor_list=None,
value_tensor=None):
if shape_tensor or shape_tensor_list or value_tensor:
raise TypeError(
'fill_const_orig2prim currently does not support Tensor inputs for shape and value.'
)
return fill_const(value=op.attr('value'),
shape=op.attr('shape'),
dtype=paddle.dtype(op.attr('dtype')))
@REGISTER_ORIG2PRIM('sum')
def sum_orig2prim(op, xs):
x0 = xs[0]
...@@ -391,7 +405,7 @@ def pow_orig2prim(op, x, y):
@REGISTER_ORIG2PRIM('square')
def square_orig2prim(op, x):
return primops.square(x)
@REGISTER_ORIG2PRIM('elementwise_max')
...@@ -436,12 +450,35 @@ def reduce_sum_orig2prim(op, x):
def reduce_mean_orig2prim(op, x):
axes = tuple(range(0, len(
x.shape))) if op.attr('reduce_all') else op.attr('dim')
return primops.mean(x, axes, op.attr('keep_dim'))
@REGISTER_ORIG2PRIM('batch_norm')
def batch_norm_orig2prim(op, bias, run_mean, momentum_tensor, scale, run_var,
x):
momentum = op.attr('momentum')
eps = op.attr('epsilon')
is_test = op.attr('is_test')
data_layout = op.attr('data_layout')
use_global_stats = op.attr('use_global_stats')
trainable_statistics = op.attr('trainable_statistics')
reserve_space = None if len(
op.output_names) == 5 else get_output_var_list(op)[1]
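# The feature (channel) axis is 1 for channels-first layouts and the last axis otherwise.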
feature_axis = 1 if data_layout in ('NC', 'NCL', 'NCHW',
'NCHWD') else len(x.shape) - 1
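# Running statistics are used at inference time (unless trainable_statistics is set) or when use_global_stats is enabled.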
use_run_stat = (is_test and (not trainable_statistics)) or use_global_stats
return primops.batch_norm(x,
feature_axis,
scale,
bias,
run_mean,
run_var,
eps=eps,
momentum=momentum,
use_run_stat=use_run_stat,
reserve_space=reserve_space)
@REGISTER_ORIG2PRIM('size')
...