未验证 提交 f47a5f7f 编写于 作者: C cyber-pioneer 提交者: GitHub

[prim] simplify batch_norm composite rule (#51827)

* simplify batch_norm composite rule

* polish code
上级 bef4e9f7
......@@ -1192,7 +1192,7 @@ void batch_norm_grad(const Tensor& x,
auto eps =
full<T>(phi::vectorize(run_var.dims()), epsilon, run_var.dtype());
mean_data = run_mean;
rsqrt_var = 1 / (run_var + eps).pow(0.5);
rsqrt_var = (run_var + eps).pow(-0.5);
} else {
mean_data = saved_mean;
rsqrt_var = saved_variance;
......
......@@ -71,7 +71,11 @@ def composite_batchnorm(
use_global_stats,
trainable_statistics,
):
"""define composite rule of op batch_norm"""
"""
define composite rule of op batch_norm
Consistent with the op kernel, the saved_variance output position actually returns the inverse standard deviation (inv_std).
"""
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
......@@ -94,38 +98,42 @@ def composite_batchnorm(
1 if i in reduce_axes else s for i, s in enumerate(x.shape)
)
batch_mean = zeros(run_mean.shape, run_mean.dtype)
batch_var = zeros(run_var.shape, run_var.dtype)
if not use_run_stat:
half = -0.5
batch_mean = mean(x, reduce_axes, keepdim=True)
temp = mean(x * x, reduce_axes, keepdim=True)
if not use_run_stat:
batch_mean = mean(x, reduce_axes)
temp = mean(x * x, reduce_axes)
batch_var = temp - batch_mean * batch_mean
x_hat = (x - reshape(batch_mean, stats_shape)) / sqrt(
reshape(batch_var, stats_shape) + epsilon
inv_std = pow((batch_var + epsilon), half)
if data_layout == "NHWC":
x_hat = (x - batch_mean) * inv_std
else:
x_hat = (x - reshape(batch_mean, stats_shape)) * reshape(
inv_std, stats_shape
)
run_mean = momentum * run_mean + (1 - momentum) * reshape(
batch_mean, run_mean.shape
)
run_var = momentum * run_var + (1 - momentum) * reshape(
batch_var, run_var.shape
)
run_mean = momentum * run_mean + (1 - momentum) * batch_mean
run_var = momentum * run_var + (1 - momentum) * batch_var
else:
batch_mean = zeros(run_mean.shape, run_mean.dtype)
batch_var = zeros(run_var.shape, run_var.dtype)
inv_std = pow((batch_var + epsilon), half)
if data_layout == "NHWC":
x_hat = (x - run_mean) * pow((run_var + epsilon), half)
else:
x_hat = (x - reshape(run_mean, stats_shape)) / sqrt(
reshape(run_var, stats_shape) + epsilon
x_hat = (x - reshape(run_mean, stats_shape)) * pow(
(reshape(run_var, stats_shape) + epsilon), half
)
if data_layout == "NHWC":
y = scale * x_hat + bias
else:
y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)
if is_amp:
y = cast(y, "float16")
# Consistent with the op kernel, this actually returns the inverse standard deviation.
inv_std = 1.0 / sqrt(batch_var + epsilon)
# Add an assign op to detach each tensor and avoid unsafe changes outside the rule.
batch_mean_ = assign(reshape(batch_mean, run_mean.shape))
inv_std_ = assign(reshape(inv_std, run_var.shape))
batch_mean_ = assign(batch_mean)
inv_std_ = assign(inv_std)
run_mean_ = assign(run_mean)
run_var_ = assign(run_var)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册