Unverified commit f47a5f7f, authored by cyber-pioneer, committed by GitHub

[prim] simplify batch_norm composite rule (#51827)

* simplify batch_norm composite rule

* polish code
Parent bef4e9f7
@@ -1192,7 +1192,7 @@ void batch_norm_grad(const Tensor& x,
       auto eps =
           full<T>(phi::vectorize(run_var.dims()), epsilon, run_var.dtype());
       mean_data = run_mean;
-      rsqrt_var = 1 / (run_var + eps).pow(0.5);
+      rsqrt_var = (run_var + eps).pow(-0.5);
     } else {
       mean_data = saved_mean;
       rsqrt_var = saved_variance;
...
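The change above replaces 1 / (run_var + eps).pow(0.5) with the equivalent (run_var + eps).pow(-0.5), folding the division into the power. A minimal NumPy sketch (illustrative only, not the Paddle kernel; the example values are arbitrary) showing the two forms agree:

import numpy as np

run_var = np.array([0.25, 1.0, 4.0], dtype=np.float32)  # arbitrary example variances
eps = np.float32(1e-5)

old_form = 1.0 / np.power(run_var + eps, 0.5)   # previous expression: 1 / sqrt(var + eps)
new_form = np.power(run_var + eps, -0.5)        # simplified expression: (var + eps)^(-0.5)

assert np.allclose(old_form, new_form)          # both compute the same inverse std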
@@ -71,7 +71,11 @@ def composite_batchnorm(
     use_global_stats,
     trainable_statistics,
 ):
-    """define composite rule of op batch_norm"""
+    """
+    define composite rule of op batch_norm
+    As the same with op kernel, the position of savedvariance indeed return inverse std.
+    """
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype
@@ -94,38 +98,42 @@ def composite_batchnorm(
         1 if i in reduce_axes else s for i, s in enumerate(x.shape)
     )
-    batch_mean = zeros(run_mean.shape, run_mean.dtype)
-    batch_var = zeros(run_var.shape, run_var.dtype)
+    half = -0.5
+
     if not use_run_stat:
-        batch_mean = mean(x, reduce_axes, keepdim=True)
-        temp = mean(x * x, reduce_axes, keepdim=True)
+        batch_mean = mean(x, reduce_axes)
+        temp = mean(x * x, reduce_axes)
         batch_var = temp - batch_mean * batch_mean
-        x_hat = (x - reshape(batch_mean, stats_shape)) / sqrt(
-            reshape(batch_var, stats_shape) + epsilon
-        )
-        run_mean = momentum * run_mean + (1 - momentum) * reshape(
-            batch_mean, run_mean.shape
-        )
-        run_var = momentum * run_var + (1 - momentum) * reshape(
-            batch_var, run_var.shape
-        )
+        inv_std = pow((batch_var + epsilon), half)
+        if data_layout == "NHWC":
+            x_hat = (x - batch_mean) * inv_std
+        else:
+            x_hat = (x - reshape(batch_mean, stats_shape)) * reshape(
+                inv_std, stats_shape
+            )
+        run_mean = momentum * run_mean + (1 - momentum) * batch_mean
+        run_var = momentum * run_var + (1 - momentum) * batch_var
     else:
-        x_hat = (x - reshape(run_mean, stats_shape)) / sqrt(
-            reshape(run_var, stats_shape) + epsilon
-        )
-    y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)
+        batch_mean = zeros(run_mean.shape, run_mean.dtype)
+        batch_var = zeros(run_var.shape, run_var.dtype)
+        inv_std = pow((batch_var + epsilon), half)
+        if data_layout == "NHWC":
+            x_hat = (x - run_mean) * pow((run_var + epsilon), half)
+        else:
+            x_hat = (x - reshape(run_mean, stats_shape)) * pow(
+                (reshape(run_var, stats_shape) + epsilon), half
+            )
+    if data_layout == "NHWC":
+        y = scale * x_hat + bias
+    else:
+        y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)
     if is_amp:
         y = cast(y, "float16")
-    # As the same with op kernel, indeed return inverse std
-    inv_std = 1.0 / sqrt(batch_var + epsilon)
     # add op assign to detach tensor in void unsafe change outside the rule.
-    batch_mean_ = assign(reshape(batch_mean, run_mean.shape))
-    inv_std_ = assign(reshape(inv_std, run_var.shape))
+    batch_mean_ = assign(batch_mean)
+    inv_std_ = assign(inv_std)
     run_mean_ = assign(run_mean)
     run_var_ = assign(run_var)
...
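For readers unfamiliar with the composite rule, the simplified forward pass above follows the standard batch-norm formulation: batch statistics over the non-channel axes, an inverse standard deviation computed as (var + eps)^(-0.5) (which is also what the saved-variance output returns, matching the op kernel), normalization with scale and bias, and an exponential-moving-average update of the running statistics. A rough NumPy sketch of the NCHW training path (plain NumPy, not Paddle primitives; names and shapes are assumptions for illustration):

import numpy as np

def composite_batchnorm_sketch(x, scale, bias, run_mean, run_var,
                               momentum=0.9, epsilon=1e-5):
    # Illustrative sketch of the simplified rule for NCHW training mode.
    half = -0.5
    reduce_axes = (0, 2, 3)              # reduce over N, H, W
    stats_shape = (1, x.shape[1], 1, 1)  # per-channel stats, broadcastable to x

    batch_mean = x.mean(axis=reduce_axes)
    batch_var = (x * x).mean(axis=reduce_axes) - batch_mean * batch_mean
    inv_std = np.power(batch_var + epsilon, half)   # (var + eps)^(-0.5), the "saved variance"

    x_hat = (x - batch_mean.reshape(stats_shape)) * inv_std.reshape(stats_shape)
    y = scale.reshape(stats_shape) * x_hat + bias.reshape(stats_shape)

    # exponential-moving-average update of the running statistics
    run_mean = momentum * run_mean + (1 - momentum) * batch_mean
    run_var = momentum * run_var + (1 - momentum) * batch_var
    return y, batch_mean, inv_std, run_mean, run_var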