From 7995189c72b1656763260dbf636069df2b1e1f7a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E4=B8=9C=E6=97=AD?=
Date: Tue, 16 Jun 2020 20:47:08 +0800
Subject: [PATCH] fix FakeQuantPerLayer/FakeQuantPerLayerGrad symmetric bug and
 remove BNTrainingReduceGrad/BNTrainingUpdateGrad

---
 mindspore/nn/layer/quant.py                   | 102 ++++++++----------
 .../ops/_op_impl/_custom_op/batchnorm_fold.py |   8 +-
 .../_custom_op/fake_quant_perlayer.py         |  17 ++-
 .../_custom_op/fake_quant_perlayer_grad.py    |  18 ++--
 mindspore/ops/operations/__init__.py          |   2 -
 mindspore/ops/operations/_grad_ops.py         |  27 -----
 mindspore/ops/operations/_quant_ops.py        |  31 +-----
 mindspore/ops/operations/nn_ops.py            |  25 +++--
 tests/st/ops/gpu/test_batchnorm_fold_op.py    |   6 +-
 9 files changed, 83 insertions(+), 153 deletions(-)

diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index af30af215..f573fc562 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -47,7 +47,7 @@ class BatchNormFoldCell(Cell):
     Batch normalization folded.
 
     Args:
-        momentum (float): Momentum value should be [0, 1]. Default: 0.1.
+        momentum (float): Momentum value should be in [0, 1]. Default: 0.9.
         epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in float32 else 1e-3.
             Default: 1e-5.
         freeze_bn (int): Delay in steps at which computation switches from regular batch
@@ -69,12 +69,11 @@ class BatchNormFoldCell(Cell):
     """
 
-    def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0, freeze_bn_ascend=True):
+    def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0):
         """init batch norm fold layer"""
         super(BatchNormFoldCell, self).__init__()
         self.epsilon = epsilon
         self.is_gpu = context.get_context('device_target') == "GPU"
-        self.freeze_bn_ascend = freeze_bn_ascend
         if self.is_gpu:
             self.bn_train = P.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
             self.bn_infer = P.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn)
@@ -89,7 +88,7 @@ class BatchNormFoldCell(Cell):
             else:
                 batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step)
         else:
-            if self.training and not self.freeze_bn_ascend:
+            if self.training:
                 x_sum, x_square_sum = self.bn_reduce(x)
                 _, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \
                     self.bn_update(x, x_sum, x_square_sum, mean, variance)
@@ -226,17 +225,17 @@ class Conv2dBatchNormQuant(Cell):
         pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
         padding: (int): Implicit paddings on both sides of the input. Default: 0.
         eps (int): Parameters for BatchNormal. Default: 1e-5.
-        momentum (int): Parameters for BatchNormal op. Default: 0.9.
+        momentum (float): Parameters for BatchNormal op. Default: 0.997.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            convolution kernel. Default: 'None'.
+            convolution kernel. Default: 'normal'.
         beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            beta vector. Default: 'None'.
+            beta vector. Default: 'zeros'.
         gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            gamma vector. Default: 'None'.
+            gamma vector. Default: 'ones'.
         mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            mean vector. Default: 'None'.
+            mean vector. Default: 'zeros'.
         var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            variance vector. Default: 'None'.
+            variance vector. Default: 'ones'.
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
         fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
@@ -269,19 +268,18 @@ class Conv2dBatchNormQuant(Cell):
                  group=1,
                  eps=1e-5,
                  momentum=0.997,
-                 weight_init=None,
-                 beta_init=None,
-                 gamma_init=None,
-                 mean_init=None,
-                 var_init=None,
+                 weight_init='normal',
+                 beta_init='zeros',
+                 gamma_init='ones',
+                 mean_init='zeros',
+                 var_init='ones',
                  quant_delay=0,
                  freeze_bn=100000,
                  fake=True,
                  num_bits=8,
                  per_channel=False,
                  symmetric=False,
-                 narrow_range=False,
-                 freeze_bn_ascend=True):
+                 narrow_range=False):
         """init Conv2dBatchNormQuant layer"""
         super(Conv2dBatchNormQuant, self).__init__()
         self.in_channels = in_channels
@@ -302,7 +300,6 @@ class Conv2dBatchNormQuant(Cell):
         self.symmetric = symmetric
         self.narrow_range = narrow_range
         self.is_gpu = context.get_context('device_target') == "GPU"
-        self.freeze_bn_ascend = freeze_bn_ascend
 
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
@@ -314,8 +311,7 @@ class Conv2dBatchNormQuant(Cell):
                                             pad=padding,
                                             stride=self.stride,
                                             dilation=self.dilation)
-            if weight_init is None:
-                weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
+            weight_shape = [1, in_channels, *self.kernel_size]
             channel_axis = 1
         else:
             self.conv = P.Conv2D(out_channel=out_channels,
@@ -325,24 +321,16 @@ class Conv2dBatchNormQuant(Cell):
                                  stride=self.stride,
                                  dilation=self.dilation,
                                  group=group)
-            if weight_init is None:
-                weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
+            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
             channel_axis = 0
-        self.weight = Parameter(weight_init, name='weight')
+        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
 
         # initialize batchnorm Parameter
-        if gamma_init is None:
-            gamma_init = initializer('ones', [out_channels])
-        self.gamma = Parameter(gamma_init, name='gamma')
-        if beta_init is None:
-            beta_init = initializer('zeros', [out_channels])
-        self.beta = Parameter(beta_init, name='beta')
-        if mean_init is None:
-            mean_init = initializer('zeros', [out_channels])
-        self.moving_mean = Parameter(mean_init, name='moving_mean', requires_grad=False)
-        if var_init is None:
-            var_init = initializer('ones', [out_channels])
-        self.moving_variance = Parameter(var_init, name='moving_variance', requires_grad=False)
+        self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
+        self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
+        self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
+        self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
+                                         requires_grad=False)
 
         # initialize fake ops
         self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
@@ -371,12 +359,10 @@ class Conv2dBatchNormQuant(Cell):
     def extend_repr(self):
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
             'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
-                                                                        self.kernel_size, self.stride,
-                                                                        self.pad_mode, self.padding, self.dilation,
-                                                                        self.group,
-                                                                        self.fake, self.freeze_bn, self.momentum,
-                                                                        self.quant_delay)
+            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
+                self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                self.pad_mode, self.padding, self.dilation, self.group,
+                self.fake, self.freeze_bn, self.momentum, self.quant_delay)
         return s
 
     def construct(self, x):
@@ -401,7 +387,7 @@ class Conv2dBatchNormQuant(Cell):
                 out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean,
                                                  running_std, running_mean, self.step)
         else:
-            if self.training and not self.freeze_bn_ascend:
+            if self.training:
                 out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
                 F.control_depend(out, self.assignadd(self.step, self.one))
             else:
@@ -427,8 +413,8 @@ class Conv2dQuant(Cell):
             divisible by the number of groups. Default: 1.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
-            Default: None.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: None.
+            Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
@@ -458,8 +444,8 @@ class Conv2dQuant(Cell):
                  dilation=1,
                  group=1,
                  has_bias=False,
-                 weight_init=None,
-                 bias_init=None,
+                 weight_init='normal',
+                 bias_init='zeros',
                  quant_delay=0,
                  num_bits=8,
                  per_channel=False,
@@ -480,15 +466,14 @@ class Conv2dQuant(Cell):
         self.group = group
         self.quant_delay = quant_delay
 
-        if weight_init is None:
-            weight_init = initializer(
-                'normal', [out_channels, in_channels // group, *self.kernel_size])
-        self.weight = Parameter(weight_init, name='weight')
-        if bias_init is None:
-            bias_init = initializer('zeros', [out_channels])
-        if has_bias:
-            self.bias = Parameter(bias_init, name='bias')
-            self.bias_add = P.BiasAdd()
+        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
+        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
+
+        self.bias_add = P.BiasAdd()
+        if check_bool(has_bias):
+            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
+        else:
+            self.bias = None
 
         self.conv = P.Conv2D(out_channel=self.out_channels,
                              kernel_size=self.kernel_size,
@@ -518,9 +503,10 @@ class Conv2dQuant(Cell):
     def extend_repr(self):
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
             'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                                                 self.pad_mode, self.padding, self.dilation, self.group,
-                                                 self.has_bias, self.quant_delay)
+            'has_bias={}, quant_delay={}'.format(
+                self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                self.pad_mode, self.padding, self.dilation, self.group,
+                self.has_bias, self.quant_delay)
         return s
 
diff --git a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py
index 39549ccfc..11434223d 100644
--- a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py
+++ b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py
@@ -65,7 +65,6 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
                    momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0,
                    data_format="NCHW", kernel_name="batchnorm_fold"):
     """batchnorm_fold TBE op"""
-    momentum = 1.0 - momentum
     util.check_kernel_name(kernel_name)
     data_format = data_format.upper()
     if data_format != "NCHW":
@@ -120,13 +119,12 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
     variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
     mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
     batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
-
+    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
     if num == 1:
         batch_var_scaler = 0.0
     else:
         batch_var_scaler = float(num) / (num - 1)
-    batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
-    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon))
+    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
 
     factor = 1.0 - momentum
     factor_reverse = momentum
@@ -134,7 +132,7 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
     mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
     mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
 
-    var_mul = te.lang.cce.vmuls(batch_variance, factor)
+    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
     var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
     variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
 
diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py
index 20b39dc25..3e75e9e0a 100644
--- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py
+++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py
@@ -50,15 +50,16 @@ def _fake_quant_per_layer_tbe():
 
 
 @fusion_manager.register("fake_quant_per_layer")
-def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max,
+def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max, symmetric,
                                  kernel_name="fake_quant_per_layer"):
     """FakeQuantPerLayer"""
     shape = te.lang.cce.util.shape_to_list(x.shape)
     shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
     quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
     quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype)
-    min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
-    max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
+    if symmetric:
+        max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
+        min_val = te.lang.cce.vmuls(max_val, -1.)
     # CalNudge(NudgeMinMax)
     scale = te.lang.cce.vdiv(te.lang.cce.vsub(
         max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
@@ -119,12 +120,8 @@ def fake_quant_per_layer(x, min_val, max_val, y,
         input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
     shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
 
-    if symmetric:
-        quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1) - 1
-    else:
-        quant_min = 0
-        quant_max = 2 ** num_bits - 1
+    quant_min = 0
+    quant_max = 2 ** num_bits - 1
     if narrow_range:
         quant_min = quant_min + 1
 
@@ -132,7 +129,7 @@ def fake_quant_per_layer(x, min_val, max_val, y,
     min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
     max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
     res = fake_quant_per_layer_compute(input_data, min_data, max_data, y,
-                                       quant_min, quant_max, kernel_name)
+                                       quant_min, quant_max, symmetric, kernel_name)
 
     with tvm.target.cce():
         sch = generic.auto_schedule(res)
diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py
index 9a5b8bc7d..a78effcc4 100644
--- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py
+++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py
@@ -78,7 +78,7 @@ def _fake_quant_per_layer_grad_tbe():
 
 
 @fusion_manager.register("fake_quant_per_layer_grad")
-def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
+def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, symmetric,
                                       kernel_name="fake_quant_per_layer_grad"):
     """FakeQuantPerLayerGrad"""
     shape = te.lang.cce.util.shape_to_list(x.shape)
@@ -88,6 +88,10 @@ def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quan
     quant_min = te.lang.cce.broadcast(quant_min, shape_min)
     quant_max = te.lang.cce.broadcast(quant_max, shape_min)
 
+    if symmetric:
+        max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
+        min_val = te.lang.cce.vmuls(max_val, -1.)
+
     # CalNudge(NudgeMinMax)
     scale = te.lang.cce.vdiv(te.lang.cce.vsub(
         max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
@@ -142,12 +146,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
         input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
     shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
 
-    if symmetric:
-        quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1) - 1
-    else:
-        quant_min = 0
-        quant_max = 2 ** num_bits - 1
+    quant_min = 0
+    quant_max = 2 ** num_bits - 1
     if narrow_range:
         quant_min = quant_min + 1
 
@@ -155,8 +155,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
     input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
     max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
-    res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
-                                            quant_max, kernel_name)
+    res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data,
+                                            quant_min, quant_max, symmetric, kernel_name)
 
     with tvm.target.cce():
         sch = generic.auto_schedule(res)
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 8b4729f59..0cf41c173 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -68,7 +68,6 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
                      ApplyRMSProp, ApplyCenteredRMSProp)
 from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop
-from . import _quant_ops
 from ._quant_ops import *
 from .thor_ops import *
 
@@ -265,5 +264,4 @@
     "SquareSumAll"
 ]
 
-__all__.extend(_quant_ops.__all__)
 __all__.sort()
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 6e5d2f9d9..2ce50260c 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -365,33 +365,6 @@ class FusedBatchNormGrad(Primitive):
     def __call__(self, dy, x, scale, save_mean, save_inv_variance):
         raise NotImplementedError
 
-class BNTrainingReduceGrad(PrimitiveWithInfer):
-    """Gradients of FusedBatchNorm operation."""
-
-    @prim_attr_register
-    def __init__(self, epsilon=0.0001):
-        _inputs = ['grads', 'x', 'diff_scale', 'diff_offset', 'scale', 'batch_mean', 'batch_variance']
-        self.init_prim_io_names(inputs=_inputs, outputs=['y'])
-
-    def infer_shape(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
-        return grads
-
-    def infer_dtype(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
-        return grads
-
-class BNTrainingUpdateGrad(PrimitiveWithInfer):
-    """Gradients of FusedBatchNorm operation."""
-
-    @prim_attr_register
-    def __init__(self, epsilon=0.0001):
-        self.init_prim_io_names(inputs=['grads', 'x', 'batch_mean', 'batch_variance'],
-                                outputs=['diff_scale', 'diff_offset'])
-
-    def infer_shape(self, grads, x, batch_mean, batch_variance):
-        return (batch_mean, batch_variance)
-
-    def infer_dtype(self, grads, x, batch_mean, batch_variance):
-        return (batch_mean, batch_variance)
 
 class GeluGrad(PrimitiveWithInfer):
     """Gradients of Gelu operation."""
diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py
index b228c51b1..a6abb45b7 100644
--- a/mindspore/ops/operations/_quant_ops.py
+++ b/mindspore/ops/operations/_quant_ops.py
@@ -35,7 +35,6 @@ __all__ = ["FakeQuantPerLayer",
            "BatchNormFold2Grad",
            "BatchNormFoldD",
            "BatchNormFoldGradD",
-           "BNTrainingReduce",
            "BatchNormFold2_D",
            "BatchNormFold2GradD",
            "BatchNormFold2GradReduce",
@@ -333,7 +332,7 @@ class BatchNormFold(PrimitiveWithInfer):
     Batch normalization folded.
 
     Args:
-        momentum (float): Momentum value should be [0, 1]. Default: 0.1.
+        momentum (float): Momentum value should be in [0, 1]. Default: 0.9.
         epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in float32 else 1e-3.
             Default: 1e-5.
         is_training (bool): In training mode set True, else set False. Default: True.
@@ -365,7 +364,7 @@ class BatchNormFold(PrimitiveWithInfer):
     channel_axis = 1
 
     @prim_attr_register
-    def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
+    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
         """init batch norm fold layer"""
         self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
         self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
@@ -697,32 +696,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
         return x_type
 
 
-class BNTrainingReduce(PrimitiveWithInfer):
-    """
-    reduce sum at axis [0, 2, 3].
-
-    Inputs:
-        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
-
-    Outputs:
-        - **x_sum** (Tensor) - Tensor has the same shape as x.
-        - **x_square_sum** (Tensor) - Tensor has the same shape as x.
-
-    """
-
-    @prim_attr_register
-    def __init__(self):
-        """init _BNTrainingReduce layer"""
-        self.init_prim_io_names(inputs=['x'],
-                                outputs=['x_sum', 'x_square_sum'])
-
-    def infer_shape(self, x_shape):
-        return [x_shape[1]], [x_shape[1]]
-
-    def infer_dtype(self, x_type):
-        return x_type, x_type
-
-
 class BatchNormFold2_D(PrimitiveWithInfer):
     """
     Scale the bias with a correction factor to the long term statistics
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 93c265479..da9000471 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -586,22 +586,27 @@ class FusedBatchNorm(Primitive):
 
 class BNTrainingReduce(PrimitiveWithInfer):
     """
-    primitive operator of bn_training_reduce's register and info descriptor
+    Reduce sum at axis [0, 2, 3].
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
+
+    Outputs:
+        - **sum** (Tensor) - Tensor of shape :math:`(C,)`.
+        - **square_sum** (Tensor) - Tensor of shape :math:`(C,)`.
+ """ @prim_attr_register - def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): + def __init__(self): self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum']) - self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) - self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) - self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) - def infer_shape(self, x): - input_shp = _infer_shape_reduce(x, (0, 2, 3), False, self.name) - return (input_shp, input_shp) + def infer_shape(self, x_shape): + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + return ([x_shape[1]], [x_shape[1]]) - def infer_dtype(self, x): - return (x, x) + def infer_dtype(self, x_type): + return (x_type, x_type) class BNTrainingUpdate(PrimitiveWithInfer): diff --git a/tests/st/ops/gpu/test_batchnorm_fold_op.py b/tests/st/ops/gpu/test_batchnorm_fold_op.py index f7e1a2deb..09a8dcf28 100644 --- a/tests/st/ops/gpu/test_batchnorm_fold_op.py +++ b/tests/st/ops/gpu/test_batchnorm_fold_op.py @@ -28,7 +28,7 @@ context.set_context(device_target='GPU') class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.op = P.BatchNormFold(freeze_bn=10) + self.op = P.BatchNormFold(momentum=0.9, freeze_bn=10) @ms_function def construct(self, x, mean, variance, current_step): @@ -40,8 +40,8 @@ def np_result(x, mean, var, momentum, epsilon): np_mean = x.mean(axis=(0, 2, 3)) np_var = x.var(axis=(0, 2, 3)) n = x.shape[0] * x.shape[2] * x.shape[3] - mean_update = momentum * np_mean + (1 - momentum) * mean - var_update = momentum * np_var * n / (n - 1) + (1 - momentum) * var + mean_update = (1 - momentum) * np_mean + momentum * mean + var_update = (1 - momentum) * np_var * n / (n - 1) + momentum * var np_var = np.sqrt(np_var + epsilon) delay_mean = mean.copy() delay_std = np.sqrt(var + epsilon) -- GitLab