From e389f2fcad738ecb58db7841188d0cd6ea535d22 Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Thu, 9 Feb 2023 11:09:32 +0800 Subject: [PATCH] fix bn composite error shape (#50338) --- python/paddle/incubate/autograd/composite_rules.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index ddcf0a9b00..d7eb54ba57 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -80,8 +80,12 @@ def composite_batchnorm( reshape(batch_var, stats_shape) + epsilon ) - run_mean = momentum * run_mean + (1 - momentum) * batch_mean - run_var = momentum * run_var + (1 - momentum) * batch_var + run_mean = momentum * run_mean + (1 - momentum) * reshape( + batch_mean, run_mean.shape + ) + run_var = momentum * run_var + (1 - momentum) * reshape( + batch_var, run_var.shape + ) else: x_hat = (x - reshape(run_mean, stats_shape)) / sqrt( reshape(run_var, stats_shape) + epsilon @@ -89,8 +93,8 @@ def composite_batchnorm( y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape) # add op assign to detach tensor in void unsafe change outside the rule. - batch_mean_ = assign(batch_mean) - batch_var_ = assign(batch_var) + batch_mean_ = assign(reshape(batch_mean, run_mean.shape)) + batch_var_ = assign(reshape(batch_var, run_var.shape)) run_mean_ = assign(run_mean) run_var_ = assign(run_var) if trainable_statistics or not is_test: -- GitLab