提交 6c9a54af 编写于 作者: Z zhaojichen

fix globalbatchnorm bug

上级 8261cfd0
......@@ -119,7 +119,7 @@ class _BatchNorm(Cell):
def _shape_infer(self, x):
"""global batch normalization shape and axes infer"""
if len(self.shape(x)) == 4:
axes = (0,2,3)
axes = (0, 2, 3)
re_shape = (1, self.num_features, 1, 1)
else:
axes = (0,)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册