提交 7995189c 编写于 作者: 王东旭

fix FakeQuantPerLayer/FakeQuantPerLayerGrad symmetric bug and remove...

fix FakeQuantPerLayer/FakeQuantPerLayerGrad symmetric bug and remove BNTrainingReduceGrad/BNTrainingUpdateGrad
上级 1e90e7be
...@@ -47,7 +47,7 @@ class BatchNormFoldCell(Cell): ...@@ -47,7 +47,7 @@ class BatchNormFoldCell(Cell):
Batch normalization folded. Batch normalization folded.
Args: Args:
momentum (float): Momentum value should be [0, 1]. Default: 0.1. momentum (float): Momentum value should be [0, 1]. Default: 0.9.
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
float32 else 1e-3. Default: 1e-5. float32 else 1e-3. Default: 1e-5.
freeze_bn (int): Delay in steps at which computation switches from regular batch freeze_bn (int): Delay in steps at which computation switches from regular batch
...@@ -69,12 +69,11 @@ class BatchNormFoldCell(Cell): ...@@ -69,12 +69,11 @@ class BatchNormFoldCell(Cell):
""" """
def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0, freeze_bn_ascend=True): def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0):
"""init batch norm fold layer""" """init batch norm fold layer"""
super(BatchNormFoldCell, self).__init__() super(BatchNormFoldCell, self).__init__()
self.epsilon = epsilon self.epsilon = epsilon
self.is_gpu = context.get_context('device_target') == "GPU" self.is_gpu = context.get_context('device_target') == "GPU"
self.freeze_bn_ascend = freeze_bn_ascend
if self.is_gpu: if self.is_gpu:
self.bn_train = P.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn) self.bn_train = P.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
self.bn_infer = P.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn) self.bn_infer = P.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn)
...@@ -89,7 +88,7 @@ class BatchNormFoldCell(Cell): ...@@ -89,7 +88,7 @@ class BatchNormFoldCell(Cell):
else: else:
batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step) batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step)
else: else:
if self.training and not self.freeze_bn_ascend: if self.training:
x_sum, x_square_sum = self.bn_reduce(x) x_sum, x_square_sum = self.bn_reduce(x)
_, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \ _, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \
self.bn_update(x, x_sum, x_square_sum, mean, variance) self.bn_update(x, x_sum, x_square_sum, mean, variance)
...@@ -226,17 +225,17 @@ class Conv2dBatchNormQuant(Cell): ...@@ -226,17 +225,17 @@ class Conv2dBatchNormQuant(Cell):
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding: (int): Implicit paddings on both sides of the input. Default: 0. padding: (int): Implicit paddings on both sides of the input. Default: 0.
eps (int): Parameters for BatchNormal. Default: 1e-5. eps (int): Parameters for BatchNormal. Default: 1e-5.
momentum (int): Parameters for BatchNormal op. Default: 0.9. momentum (int): Parameters for BatchNormal op. Default: 0.997.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
convolution kernel. Default: 'None'. convolution kernel. Default: 'normal'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
beta vector. Default: 'None'. beta vector. Default: 'zeros'.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
gamma vector. Default: 'None'. gamma vector. Default: 'ones'.
mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
mean vector. Default: 'None'. mean vector. Default: 'zeros'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'None'. variance vector. Default: 'ones'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0. quant_delay (int): Quantization delay parameters according by global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
...@@ -269,19 +268,18 @@ class Conv2dBatchNormQuant(Cell): ...@@ -269,19 +268,18 @@ class Conv2dBatchNormQuant(Cell):
group=1, group=1,
eps=1e-5, eps=1e-5,
momentum=0.997, momentum=0.997,
weight_init=None, weight_init='normal',
beta_init=None, beta_init='zeros',
gamma_init=None, gamma_init='ones',
mean_init=None, mean_init='zeros',
var_init=None, var_init='ones',
quant_delay=0, quant_delay=0,
freeze_bn=100000, freeze_bn=100000,
fake=True, fake=True,
num_bits=8, num_bits=8,
per_channel=False, per_channel=False,
symmetric=False, symmetric=False,
narrow_range=False, narrow_range=False):
freeze_bn_ascend=True):
"""init Conv2dBatchNormQuant layer""" """init Conv2dBatchNormQuant layer"""
super(Conv2dBatchNormQuant, self).__init__() super(Conv2dBatchNormQuant, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
...@@ -302,7 +300,6 @@ class Conv2dBatchNormQuant(Cell): ...@@ -302,7 +300,6 @@ class Conv2dBatchNormQuant(Cell):
self.symmetric = symmetric self.symmetric = symmetric
self.narrow_range = narrow_range self.narrow_range = narrow_range
self.is_gpu = context.get_context('device_target') == "GPU" self.is_gpu = context.get_context('device_target') == "GPU"
self.freeze_bn_ascend = freeze_bn_ascend
# initialize convolution op and Parameter # initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1: if context.get_context('device_target') == "Ascend" and group > 1:
...@@ -314,8 +311,7 @@ class Conv2dBatchNormQuant(Cell): ...@@ -314,8 +311,7 @@ class Conv2dBatchNormQuant(Cell):
pad=padding, pad=padding,
stride=self.stride, stride=self.stride,
dilation=self.dilation) dilation=self.dilation)
if weight_init is None: weight_shape = [1, in_channels, *self.kernel_size]
weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
channel_axis = 1 channel_axis = 1
else: else:
self.conv = P.Conv2D(out_channel=out_channels, self.conv = P.Conv2D(out_channel=out_channels,
...@@ -325,24 +321,16 @@ class Conv2dBatchNormQuant(Cell): ...@@ -325,24 +321,16 @@ class Conv2dBatchNormQuant(Cell):
stride=self.stride, stride=self.stride,
dilation=self.dilation, dilation=self.dilation,
group=group) group=group)
if weight_init is None: weight_shape = [out_channels, in_channels // group, *self.kernel_size]
weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
channel_axis = 0 channel_axis = 0
self.weight = Parameter(weight_init, name='weight') self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
# initialize batchnorm Parameter # initialize batchnorm Parameter
if gamma_init is None: self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
gamma_init = initializer('ones', [out_channels]) self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
self.gamma = Parameter(gamma_init, name='gamma') self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
if beta_init is None: self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
beta_init = initializer('zeros', [out_channels]) requires_grad=False)
self.beta = Parameter(beta_init, name='beta')
if mean_init is None:
mean_init = initializer('zeros', [out_channels])
self.moving_mean = Parameter(mean_init, name='moving_mean', requires_grad=False)
if var_init is None:
var_init = initializer('ones', [out_channels])
self.moving_variance = Parameter(var_init, name='moving_variance', requires_grad=False)
# initialize fake ops # initialize fake ops
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
...@@ -371,12 +359,10 @@ class Conv2dBatchNormQuant(Cell): ...@@ -371,12 +359,10 @@ class Conv2dBatchNormQuant(Cell):
def extend_repr(self): def extend_repr(self):
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \ 'pad_mode={}, padding={}, dilation={}, group={}, ' \
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels, 'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
self.kernel_size, self.stride, self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.pad_mode, self.padding, self.dilation, self.group,
self.group, self.fake, self.freeze_bn, self.momentum, self.quant_delay)
self.fake, self.freeze_bn, self.momentum,
self.quant_delay)
return s return s
def construct(self, x): def construct(self, x):
...@@ -401,7 +387,7 @@ class Conv2dBatchNormQuant(Cell): ...@@ -401,7 +387,7 @@ class Conv2dBatchNormQuant(Cell):
out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
batch_std, batch_mean, running_std, running_mean, self.step) batch_std, batch_mean, running_std, running_mean, self.step)
else: else:
if self.training and not self.freeze_bn_ascend: if self.training:
out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std) out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
F.control_depend(out, self.assignadd(self.step, self.one)) F.control_depend(out, self.assignadd(self.step, self.one))
else: else:
...@@ -427,8 +413,8 @@ class Conv2dQuant(Cell): ...@@ -427,8 +413,8 @@ class Conv2dQuant(Cell):
divisible by the number of groups. Default: 1. divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: None. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: None. bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0. quant_delay (int): Quantization delay parameters according by global step. Default: 0.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
...@@ -458,8 +444,8 @@ class Conv2dQuant(Cell): ...@@ -458,8 +444,8 @@ class Conv2dQuant(Cell):
dilation=1, dilation=1,
group=1, group=1,
has_bias=False, has_bias=False,
weight_init=None, weight_init='normal',
bias_init=None, bias_init='zeros',
quant_delay=0, quant_delay=0,
num_bits=8, num_bits=8,
per_channel=False, per_channel=False,
...@@ -480,15 +466,14 @@ class Conv2dQuant(Cell): ...@@ -480,15 +466,14 @@ class Conv2dQuant(Cell):
self.group = group self.group = group
self.quant_delay = quant_delay self.quant_delay = quant_delay
if weight_init is None: weight_shape = [out_channels, in_channels // group, *self.kernel_size]
weight_init = initializer( self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
'normal', [out_channels, in_channels // group, *self.kernel_size])
self.weight = Parameter(weight_init, name='weight') self.bias_add = P.BiasAdd()
if bias_init is None: if check_bool(has_bias):
bias_init = initializer('zeros', [out_channels]) self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
if has_bias: else:
self.bias = Parameter(bias_init, name='bias') self.bias = None
self.bias_add = P.BiasAdd()
self.conv = P.Conv2D(out_channel=self.out_channels, self.conv = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size, kernel_size=self.kernel_size,
...@@ -518,9 +503,10 @@ class Conv2dQuant(Cell): ...@@ -518,9 +503,10 @@ class Conv2dQuant(Cell):
def extend_repr(self): def extend_repr(self):
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \ 'pad_mode={}, padding={}, dilation={}, group={}, ' \
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, 'has_bias={}, quant_delay={}'.format(
self.pad_mode, self.padding, self.dilation, self.group, self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.has_bias, self.quant_delay) self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.quant_delay)
return s return s
......
...@@ -65,7 +65,6 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance, ...@@ -65,7 +65,6 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW", momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW",
kernel_name="batchnorm_fold"): kernel_name="batchnorm_fold"):
"""batchnorm_fold TBE op""" """batchnorm_fold TBE op"""
momentum = 1.0 - momentum
util.check_kernel_name(kernel_name) util.check_kernel_name(kernel_name)
data_format = data_format.upper() data_format = data_format.upper()
if data_format != "NCHW": if data_format != "NCHW":
...@@ -120,13 +119,12 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance, ...@@ -120,13 +119,12 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
variance_div = te.lang.cce.vmuls(x_square_sum, num_rec) variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
mean_square = te.lang.cce.vmul(batch_mean, batch_mean) mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
batch_var_biased = te.lang.cce.vsub(variance_div, mean_square) batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
if num == 1: if num == 1:
batch_var_scaler = 0.0 batch_var_scaler = 0.0
else: else:
batch_var_scaler = float(num) / (num - 1) batch_var_scaler = float(num) / (num - 1)
batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler) batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon))
factor = 1.0 - momentum factor = 1.0 - momentum
factor_reverse = momentum factor_reverse = momentum
...@@ -134,7 +132,7 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance, ...@@ -134,7 +132,7 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse) mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev) mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
var_mul = te.lang.cce.vmuls(batch_variance, factor) var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse) var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev) variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
......
...@@ -50,15 +50,16 @@ def _fake_quant_per_layer_tbe(): ...@@ -50,15 +50,16 @@ def _fake_quant_per_layer_tbe():
@fusion_manager.register("fake_quant_per_layer") @fusion_manager.register("fake_quant_per_layer")
def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max, def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max, symmetric,
kernel_name="fake_quant_per_layer"): kernel_name="fake_quant_per_layer"):
"""FakeQuantPerLayer""" """FakeQuantPerLayer"""
shape = te.lang.cce.util.shape_to_list(x.shape) shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype) quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype) quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) if symmetric:
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
min_val = te.lang.cce.vmuls(max_val, -1.)
# CalNudge(NudgeMinMax) # CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub( scale = te.lang.cce.vdiv(te.lang.cce.vsub(
...@@ -119,12 +120,8 @@ def fake_quant_per_layer(x, min_val, max_val, y, ...@@ -119,12 +120,8 @@ def fake_quant_per_layer(x, min_val, max_val, y,
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape) shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric: quant_min = 0
quant_min = 0 - 2 ** (num_bits - 1) quant_max = 2 ** num_bits - 1
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range: if narrow_range:
quant_min = quant_min + 1 quant_min = quant_min + 1
...@@ -132,7 +129,7 @@ def fake_quant_per_layer(x, min_val, max_val, y, ...@@ -132,7 +129,7 @@ def fake_quant_per_layer(x, min_val, max_val, y,
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_per_layer_compute(input_data, min_data, max_data, y, res = fake_quant_per_layer_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name) quant_min, quant_max, symmetric, kernel_name)
with tvm.target.cce(): with tvm.target.cce():
sch = generic.auto_schedule(res) sch = generic.auto_schedule(res)
......
...@@ -78,7 +78,7 @@ def _fake_quant_per_layer_grad_tbe(): ...@@ -78,7 +78,7 @@ def _fake_quant_per_layer_grad_tbe():
@fusion_manager.register("fake_quant_per_layer_grad") @fusion_manager.register("fake_quant_per_layer_grad")
def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, symmetric,
kernel_name="fake_quant_per_layer_grad"): kernel_name="fake_quant_per_layer_grad"):
"""FakeQuantPerLayerGrad""" """FakeQuantPerLayerGrad"""
shape = te.lang.cce.util.shape_to_list(x.shape) shape = te.lang.cce.util.shape_to_list(x.shape)
...@@ -88,6 +88,10 @@ def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quan ...@@ -88,6 +88,10 @@ def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quan
quant_min = te.lang.cce.broadcast(quant_min, shape_min) quant_min = te.lang.cce.broadcast(quant_min, shape_min)
quant_max = te.lang.cce.broadcast(quant_max, shape_min) quant_max = te.lang.cce.broadcast(quant_max, shape_min)
if symmetric:
max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
min_val = te.lang.cce.vmuls(max_val, -1.)
# CalNudge(NudgeMinMax) # CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub( scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
...@@ -142,12 +146,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx, ...@@ -142,12 +146,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape) shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric: quant_min = 0
quant_min = 0 - 2 ** (num_bits - 1) quant_max = 2 ** num_bits - 1
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range: if narrow_range:
quant_min = quant_min + 1 quant_min = quant_min + 1
...@@ -155,8 +155,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx, ...@@ -155,8 +155,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min, res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data,
quant_max, kernel_name) quant_min, quant_max, symmetric, kernel_name)
with tvm.target.cce(): with tvm.target.cce():
sch = generic.auto_schedule(res) sch = generic.auto_schedule(res)
......
...@@ -68,7 +68,6 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, ...@@ -68,7 +68,6 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
ApplyRMSProp, ApplyCenteredRMSProp) ApplyRMSProp, ApplyCenteredRMSProp)
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop
from . import _quant_ops
from ._quant_ops import * from ._quant_ops import *
from .thor_ops import * from .thor_ops import *
...@@ -265,5 +264,4 @@ __all__ = [ ...@@ -265,5 +264,4 @@ __all__ = [
"SquareSumAll" "SquareSumAll"
] ]
__all__.extend(_quant_ops.__all__)
__all__.sort() __all__.sort()
...@@ -365,33 +365,6 @@ class FusedBatchNormGrad(Primitive): ...@@ -365,33 +365,6 @@ class FusedBatchNormGrad(Primitive):
def __call__(self, dy, x, scale, save_mean, save_inv_variance): def __call__(self, dy, x, scale, save_mean, save_inv_variance):
raise NotImplementedError raise NotImplementedError
class BNTrainingReduceGrad(PrimitiveWithInfer):
"""Gradients of FusedBatchNorm operation."""
@prim_attr_register
def __init__(self, epsilon=0.0001):
_inputs = ['grads', 'x', 'diff_scale', 'diff_offset', 'scale', 'batch_mean', 'batch_variance']
self.init_prim_io_names(inputs=_inputs, outputs=['y'])
def infer_shape(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
return grads
def infer_dtype(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
return grads
class BNTrainingUpdateGrad(PrimitiveWithInfer):
"""Gradients of FusedBatchNorm operation."""
@prim_attr_register
def __init__(self, epsilon=0.0001):
self.init_prim_io_names(inputs=['grads', 'x', 'batch_mean', 'batch_variance'],
outputs=['diff_scale', 'diff_offset'])
def infer_shape(self, grads, x, batch_mean, batch_variance):
return (batch_mean, batch_variance)
def infer_dtype(self, grads, x, batch_mean, batch_variance):
return (batch_mean, batch_variance)
class GeluGrad(PrimitiveWithInfer): class GeluGrad(PrimitiveWithInfer):
"""Gradients of Gelu operation.""" """Gradients of Gelu operation."""
......
...@@ -35,7 +35,6 @@ __all__ = ["FakeQuantPerLayer", ...@@ -35,7 +35,6 @@ __all__ = ["FakeQuantPerLayer",
"BatchNormFold2Grad", "BatchNormFold2Grad",
"BatchNormFoldD", "BatchNormFoldD",
"BatchNormFoldGradD", "BatchNormFoldGradD",
"BNTrainingReduce",
"BatchNormFold2_D", "BatchNormFold2_D",
"BatchNormFold2GradD", "BatchNormFold2GradD",
"BatchNormFold2GradReduce", "BatchNormFold2GradReduce",
...@@ -333,7 +332,7 @@ class BatchNormFold(PrimitiveWithInfer): ...@@ -333,7 +332,7 @@ class BatchNormFold(PrimitiveWithInfer):
Batch normalization folded. Batch normalization folded.
Args: Args:
momentum (float): Momentum value should be [0, 1]. Default: 0.1. momentum (float): Momentum value should be [0, 1]. Default: 0.9.
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
float32 else 1e-3. Default: 1e-5. float32 else 1e-3. Default: 1e-5.
is_training (bool): In training mode set True, else set False. Default: True. is_training (bool): In training mode set True, else set False. Default: True.
...@@ -365,7 +364,7 @@ class BatchNormFold(PrimitiveWithInfer): ...@@ -365,7 +364,7 @@ class BatchNormFold(PrimitiveWithInfer):
channel_axis = 1 channel_axis = 1
@prim_attr_register @prim_attr_register
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0): def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
"""init batch norm fold layer""" """init batch norm fold layer"""
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
...@@ -697,32 +696,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer): ...@@ -697,32 +696,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
return x_type return x_type
class BNTrainingReduce(PrimitiveWithInfer):
"""
reduce sum at axis [0, 2, 3].
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
Outputs:
- **x_sum** (Tensor) - Tensor has the same shape as x.
- **x_square_sum** (Tensor) - Tensor has the same shape as x.
"""
@prim_attr_register
def __init__(self):
"""init _BNTrainingReduce layer"""
self.init_prim_io_names(inputs=['x'],
outputs=['x_sum', 'x_square_sum'])
def infer_shape(self, x_shape):
return [x_shape[1]], [x_shape[1]]
def infer_dtype(self, x_type):
return x_type, x_type
class BatchNormFold2_D(PrimitiveWithInfer): class BatchNormFold2_D(PrimitiveWithInfer):
""" """
Scale the bias with a correction factor to the long term statistics Scale the bias with a correction factor to the long term statistics
......
...@@ -586,22 +586,27 @@ class FusedBatchNorm(Primitive): ...@@ -586,22 +586,27 @@ class FusedBatchNorm(Primitive):
class BNTrainingReduce(PrimitiveWithInfer): class BNTrainingReduce(PrimitiveWithInfer):
""" """
primitive operator of bn_training_reduce's register and info descriptor reduce sum at axis [0, 2, 3].
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
Outputs:
- **sum** (Tensor) - Tensor of shape :math:`(C,)`.
- **square_sum** (Tensor) - Tensor of shape :math:`(C,)`.
""" """
@prim_attr_register @prim_attr_register
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum']) self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
def infer_shape(self, x): def infer_shape(self, x_shape):
input_shp = _infer_shape_reduce(x, (0, 2, 3), False, self.name) validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
return (input_shp, input_shp) return ([x_shape[1]], [x_shape[1]])
def infer_dtype(self, x): def infer_dtype(self, x_type):
return (x, x) return (x_type, x_type)
class BNTrainingUpdate(PrimitiveWithInfer): class BNTrainingUpdate(PrimitiveWithInfer):
......
...@@ -28,7 +28,7 @@ context.set_context(device_target='GPU') ...@@ -28,7 +28,7 @@ context.set_context(device_target='GPU')
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.op = P.BatchNormFold(freeze_bn=10) self.op = P.BatchNormFold(momentum=0.9, freeze_bn=10)
@ms_function @ms_function
def construct(self, x, mean, variance, current_step): def construct(self, x, mean, variance, current_step):
...@@ -40,8 +40,8 @@ def np_result(x, mean, var, momentum, epsilon): ...@@ -40,8 +40,8 @@ def np_result(x, mean, var, momentum, epsilon):
np_mean = x.mean(axis=(0, 2, 3)) np_mean = x.mean(axis=(0, 2, 3))
np_var = x.var(axis=(0, 2, 3)) np_var = x.var(axis=(0, 2, 3))
n = x.shape[0] * x.shape[2] * x.shape[3] n = x.shape[0] * x.shape[2] * x.shape[3]
mean_update = momentum * np_mean + (1 - momentum) * mean mean_update = (1 - momentum) * np_mean + momentum * mean
var_update = momentum * np_var * n / (n - 1) + (1 - momentum) * var var_update = (1 - momentum) * np_var * n / (n - 1) + momentum * var
np_var = np.sqrt(np_var + epsilon) np_var = np.sqrt(np_var + epsilon)
delay_mean = mean.copy() delay_mean = mean.copy()
delay_std = np.sqrt(var + epsilon) delay_std = np.sqrt(var + epsilon)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册