未验证 提交 e389f2fc 编写于 作者: J Jiabin Yang 提交者: GitHub

fix bn composite error shape (#50338)

上级 f11c913e
......@@ -80,8 +80,12 @@ def composite_batchnorm(
reshape(batch_var, stats_shape) + epsilon
)
run_mean = momentum * run_mean + (1 - momentum) * batch_mean
run_var = momentum * run_var + (1 - momentum) * batch_var
run_mean = momentum * run_mean + (1 - momentum) * reshape(
batch_mean, run_mean.shape
)
run_var = momentum * run_var + (1 - momentum) * reshape(
batch_var, run_var.shape
)
else:
x_hat = (x - reshape(run_mean, stats_shape)) / sqrt(
reshape(run_var, stats_shape) + epsilon
......@@ -89,8 +93,8 @@ def composite_batchnorm(
y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)
# add op assign to detach tensor in void unsafe change outside the rule.
batch_mean_ = assign(batch_mean)
batch_var_ = assign(batch_var)
batch_mean_ = assign(reshape(batch_mean, run_mean.shape))
batch_var_ = assign(reshape(batch_var, run_var.shape))
run_mean_ = assign(run_mean)
run_var_ = assign(run_var)
if trainable_statistics or not is_test:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册