From 6c9a54afa12ecc722bd29d3a728a3923205f0c03 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Wed, 29 Apr 2020 03:34:58 -0400 Subject: [PATCH] fix globalbatchnorm bug --- mindspore/nn/layer/normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 6e9236955..2a1ca28ed 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -119,7 +119,7 @@ class _BatchNorm(Cell): def _shape_infer(self, x): """global batch normalization shape and axes infer""" if len(self.shape(x)) == 4: - axes = (0,2,3) + axes = (0, 2, 3) re_shape = (1, self.num_features, 1, 1) else: axes = (0,) -- GitLab