diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index ddcf0a9b004e790b616c9a9bd8dfe6fb0e3c9cbf..d7eb54ba571974b7a127a674546bd347264ef2b0 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -80,8 +80,12 @@ def composite_batchnorm( reshape(batch_var, stats_shape) + epsilon ) - run_mean = momentum * run_mean + (1 - momentum) * batch_mean - run_var = momentum * run_var + (1 - momentum) * batch_var + run_mean = momentum * run_mean + (1 - momentum) * reshape( + batch_mean, run_mean.shape + ) + run_var = momentum * run_var + (1 - momentum) * reshape( + batch_var, run_var.shape + ) else: x_hat = (x - reshape(run_mean, stats_shape)) / sqrt( reshape(run_var, stats_shape) + epsilon @@ -89,8 +93,8 @@ def composite_batchnorm( y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape) # add op assign to detach tensor in void unsafe change outside the rule. - batch_mean_ = assign(batch_mean) - batch_var_ = assign(batch_var) + batch_mean_ = assign(reshape(batch_mean, run_mean.shape)) + batch_var_ = assign(reshape(batch_var, run_var.shape)) run_mean_ = assign(run_mean) run_var_ = assign(run_var) if trainable_statistics or not is_test: