diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index c046c2e1bf4aadbfae9660d6032d2271f94f90b7..d2c25030634d055d1ac27c6445c2802b390fac1c 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -81,6 +81,7 @@ class Cell: self.enable_hook = False self._bprop_debug = False self._is_run = False + self.cell_type = None @property def is_run(self): @@ -140,6 +141,14 @@ class Cell: for cell_name, cell in cells_name: cell._param_prefix = cell_name + def update_cell_type(self, cell_type): + """ + Update current cell type mainly identify if quantization aware training network. + + After invoked, can set the cell type to 'cell_type'. + """ + self.cell_type = cell_type + @cell_init_args.setter def cell_init_args(self, value): if not isinstance(value, str): diff --git a/mindspore/nn/layer/conv.py b/mindspore/nn/layer/conv.py index b2a0de9cbe2d4eb99798b9a2312f612579ddeaca..fb77160cca6999f23d9333eefa4d15f47ca70c87 100644 --- a/mindspore/nn/layer/conv.py +++ b/mindspore/nn/layer/conv.py @@ -17,6 +17,7 @@ from mindspore import log as logger from mindspore.ops import operations as P from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer +from mindspore._checkparam import ParamValidator as validator, Rel from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative from mindspore._extends import cell_attr_register from ..cell import Cell @@ -397,3 +398,150 @@ class Conv2dTranspose(_Conv): self.weight, self.bias) return s + + +class DepthwiseConv2d(Cell): + r""" + 2D depthwise convolution layer. + + Applies a 2D depthwise convolution over an input tensor which is typically of shape: + math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number. + For each batch of shape:math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as: + + .. math:: + + out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j, + + where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges + from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th + filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice + of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and + :math:`\text{ks_w}` are height and width of the convolution kernel. The full kernel has shape + :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number + to split the input in the channel dimension. + + If the 'pad_mode' is set to be "valid", the output height and width will be + :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} - + (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and + :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} - + (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively. + + The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition + `_. + + Args: + in_channels (int): The number of input channel :math:`C_{in}`. + out_channels (int): The number of output channel :math:`C_{out}`. + kernel_size (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the height + and width of the 2D convolution window. Single int means the value if for both height and width of + the kernel. A tuple of 2 ints means the first value is for the height and the other is for the + width of the kernel. + stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents + the height and width of movement are both strides, or a tuple of two int numbers that + represent height and width of movement respectively. Default: 1. + pad_mode (str): Specifies padding mode. The optional values are + "same", "valid", "pad". Default: "same". + + - same: Adopts the way of completion. Output height and width will be the same as the input. + Total number of padding will be calculated for horizontal and vertical + direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the + last extra padding will be done from the bottom and the right side. If this mode is set, `padding` + must be 0. + + - valid: Adopts the way of discarding. The possibly largest height and width of output will be return + without padding. Extra pixels will be discarded. If this mode is set, `padding` + must be 0. + + - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input + Tensor borders. `padding` should be greater than or equal to 0. + + padding (int): Implicit paddings on both sides of the input. Default: 0. + dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate + to use for dilated convolution. If set to be :math:`k > 1`, there will + be :math:`k - 1` pixels skipped for each sampling location. Its value should + be greater or equal to 1 and bounded by the height and width of the + input. Default: 1. + group (int): Split filter into groups, `in_ channels` and `out_channels` should be + divisible by the number of groups. Default: 1. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. + It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified, + values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well + as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' + and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of + Initializer for more details. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible + Initializer and string are the same as 'weight_init'. Refer to the values of + Initializer for more details. Default: 'zeros'. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> net = nn.DepthwiseConv2d(120, 240, 4, has_bias=False, weight_init='normal') + >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) + >>> net(input).shape + (1, 240, 1024, 640) + """ + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + weight_init='normal', + bias_init='zeros'): + super(DepthwiseConv2d, self).__init__() + self.kernel_size = twice(kernel_size) + self.stride = twice(stride) + self.dilation = twice(dilation) + self.in_channels = check_int_positive(in_channels) + self.out_channels = check_int_positive(out_channels) + validator.check_integer('group', group, in_channels, Rel.EQ) + validator.check_integer('group', group, out_channels, Rel.EQ) + validator.check_integer('group', group, 1, Rel.GE) + self.pad_mode = pad_mode + self.padding = padding + self.dilation = dilation + self.group = group + self.has_bias = has_bias + self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, + kernel_size=self.kernel_size, + pad_mode=self.pad_mode, + pad=self.padding, + stride=self.stride, + dilation=self.dilation) + self.bias_add = P.BiasAdd() + weight_shape = [1, in_channels, *self.kernel_size] + self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') + if check_bool(has_bias): + self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') + else: + if bias_init != 'zeros': + logger.warning("value of `has_bias` is False, value of `bias_init` will be ignore.") + self.bias = None + + def construct(self, x): + out = self.conv(x, self.weight) + if self.has_bias: + out = self.bias_add(out, self.bias) + return out + + def extend_repr(self): + s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \ + 'pad_mode={}, padding={}, dilation={}, group={},' \ + 'has_bias={}, weight_init={}, bias_init={}'.format( + self.in_channels, self.out_channels, self.kernel_size, self.stride, + self.pad_mode, self.padding, self.dilation, self.group, + self.has_bias, self.weight_init, self.bias_init) + + if self.has_bias: + s += ', bias={}'.format(self.bias) + return s diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 6c01aa54044d6dcc8a156dd6e6c47905078ee6f6..f34b5520a9a18d168e2852d69230d46c40ae7904 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Aware quantization.""" +"""Quantization aware.""" from functools import partial import numpy as np + import mindspore.common.dtype as mstype from mindspore.ops import operations as P from mindspore.ops import functional as F @@ -23,10 +24,9 @@ from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from mindspore.common.tensor import Tensor from mindspore._checkparam import check_int_positive, check_bool, twice -from mindspore._checkparam import Validator as validator, Rel -from mindspore.nn.cell import Cell -from mindspore.nn.layer.activation import get_activation +from mindspore._checkparam import Rel import mindspore.context as context + from .normalization import BatchNorm2d from .activation import get_activation from ..cell import Cell @@ -82,7 +82,7 @@ class Conv2dBnAct(Cell): bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible Initializer and string are the same as 'weight_init'. Refer to the values of Initializer for more details. Default: 'zeros'. - batchnorm (bool): Specifies to used batchnorm or not. Default: None. + has_bn (bool): Specifies to used batchnorm or not. Default: False. activation (string): Specifies activation type. The optional values are as following: 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. @@ -94,7 +94,7 @@ class Conv2dBnAct(Cell): Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Examples: - >>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU') + >>> net = Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU') >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) >>> net(input).shape (1, 240, 1024, 640) @@ -112,28 +112,39 @@ class Conv2dBnAct(Cell): has_bias=False, weight_init='normal', bias_init='zeros', - batchnorm=None, + has_bn=False, activation=None): super(Conv2dBnAct, self).__init__() - self.conv = conv.Conv2d( - in_channels, - out_channels, - kernel_size, - stride, - pad_mode, - padding, - dilation, - group, - has_bias, - weight_init, - bias_init) - self.has_bn = batchnorm is not None + + if context.get_context('device_target') == "Ascend" and group > 1: + self.conv = conv.DepthwiseConv2d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + pad_mode=pad_mode, + padding=padding, + dilation=dilation, + group=group, + has_bias=has_bias, + weight_init=weight_init, + bias_init=bias_init) + else: + self.conv = conv.Conv2d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + pad_mode=pad_mode, + padding=padding, + dilation=dilation, + group=group, + has_bias=has_bias, + weight_init=weight_init, + bias_init=bias_init) + + self.has_bn = has_bn self.has_act = activation is not None - self.batchnorm = batchnorm - if batchnorm is True: + if has_bn: self.batchnorm = BatchNorm2d(out_channels) - elif batchnorm is not None: - validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) self.activation = get_activation(activation) def construct(self, x): @@ -160,7 +171,7 @@ class DenseBnAct(Cell): same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. - batchnorm (bool): Specifies to used batchnorm or not. Default: None. + has_bn (bool): Specifies to used batchnorm or not. Default: False. activation (string): Specifies activation type. The optional values are as following: 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. @@ -172,7 +183,7 @@ class DenseBnAct(Cell): Tensor of shape :math:`(N, out\_channels)`. Examples: - >>> net = nn.Dense(3, 4) + >>> net = nn.DenseBnAct(3, 4) >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) >>> net(input) """ @@ -183,7 +194,7 @@ class DenseBnAct(Cell): weight_init='normal', bias_init='zeros', has_bias=True, - batchnorm=None, + has_bn=False, activation=None): super(DenseBnAct, self).__init__() self.dense = basic.Dense( @@ -192,12 +203,10 @@ class DenseBnAct(Cell): weight_init, bias_init, has_bias) - self.has_bn = batchnorm is not None + self.has_bn = has_bn self.has_act = activation is not None - if batchnorm is True: + if has_bn: self.batchnorm = BatchNorm2d(out_channels) - elif batchnorm is not None: - validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,)) self.activation = get_activation(activation) def construct(self, x): @@ -271,20 +280,20 @@ class BatchNormFoldCell(Cell): class FakeQuantWithMinMax(Cell): r""" - Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. + Quantization aware op. This OP provide Fake quantization observer function on data with min and max. Args: min_init (int, float): The dimension of channel or 1(layer). Default: -6. max_init (int, float): The dimension of channel or 1(layer). Default: 6. - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. ema (bool): Exponential Moving Average algorithm update min and max. Default: False. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. channel_axis (int): Quantization by channel axis. Default: 1. - out_channels (int): declarate the min and max channel size, Default: 1. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. + num_channels (int): declarate the min and max channel size, Default: 1. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of FakeQuantWithMinMax. @@ -301,24 +310,27 @@ class FakeQuantWithMinMax(Cell): def __init__(self, min_init=-6, max_init=6, - num_bits=8, ema=False, ema_decay=0.999, per_channel=False, channel_axis=1, - out_channels=1, - quant_delay=0, + num_channels=1, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): """init FakeQuantWithMinMax layer""" super(FakeQuantWithMinMax, self).__init__() + validator.check_type("min_init", min_init, [int, float]) + validator.check_type("max_init", max_init, [int, float]) + validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) self.min_init = min_init self.max_init = max_init self.num_bits = num_bits self.ema = ema self.ema_decay = ema_decay self.per_channel = per_channel - self.out_channels = out_channels + self.num_channels = num_channels self.channel_axis = channel_axis self.quant_delay = quant_delay self.symmetric = symmetric @@ -327,54 +339,54 @@ class FakeQuantWithMinMax(Cell): # init tensor min and max for fake quant op if self.per_channel: - min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) - max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32) + min_array = np.array([self.min_init] * self.num_channels).astype(np.float32) + max_array = np.array([self.max_init] * self.num_channels).astype(np.float32) else: - min_array = np.array([self.min_init]).reshape(1).astype(np.float32) - max_array = np.array([self.max_init]).reshape(1).astype(np.float32) + min_array = np.array([self.min_init]).astype(np.float32) + max_array = np.array([self.max_init]).astype(np.float32) self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) # init fake quant relative op if per_channel: quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) - ema_fun = partial(Q.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis) + ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis) else: quant_fun = Q.FakeQuantPerLayer - ema_fun = Q.FakeQuantMinMaxPerLayerUpdate + ema_fun = Q.MinMaxUpdatePerLayer + self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay) if self.is_ascend: - self.fake_quant = quant_fun(num_bits=self.num_bits, - symmetric=self.symmetric, - narrow_range=self.narrow_range) + self.fake_quant_train = quant_fun(num_bits=self.num_bits, + symmetric=self.symmetric, + narrow_range=self.narrow_range) + self.fake_quant_infer = self.fake_quant_train else: - self.fake_quant = quant_fun(num_bits=self.num_bits, - ema=self.ema, - ema_decay=ema_decay, - quant_delay=quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range) - self.ema_update = ema_fun(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - symmetric=self.symmetric, - narrow_range=self.narrow_range) + quant_fun = partial(quant_fun, + ema=self.ema, + ema_decay=ema_decay, + num_bits=self.num_bits, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + quant_delay=quant_delay) + self.fake_quant_train = quant_fun(training=True) + self.fake_quant_infer = quant_fun(training=False) def extend_repr(self): s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ 'quant_delay={}, min_init={}, max_init={}'.format( self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel, - self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init) + self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init) return s def construct(self, x): - if self.is_ascend and self.training: + if self.training: min_up, max_up = self.ema_update(x, self.minq, self.maxq) - out = self.fake_quant(x, min_up, max_up) P.Assign()(self.minq, min_up) P.Assign()(self.maxq, max_up) + out = self.fake_quant_train(x, self.minq, self.maxq) else: - out = self.fake_quant(x, self.minq, self.maxq) + out = self.fake_quant_infer(x, self.minq, self.maxq) return out @@ -391,8 +403,8 @@ class Conv2dBatchNormQuant(Cell): stride (int): Specifies stride for all spatial dimensions with the same value. pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". padding: (int): Implicit paddings on both sides of the input. Default: 0. - eps (int): Parameters for BatchNormal. Default: 1e-5. - momentum (int): Parameters for BatchNormal op. Default: 0.997. + eps (float): Parameters for BatchNormal. Default: 1e-5. + momentum (float): Parameters for BatchNormal op. Default: 0.997. weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. Default: 'normal'. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the @@ -403,13 +415,13 @@ class Conv2dBatchNormQuant(Cell): mean vector. Default: 'zeros'. var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the variance vector. Default: 'ones'. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. - freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -440,13 +452,13 @@ class Conv2dBatchNormQuant(Cell): gamma_init='ones', mean_init='zeros', var_init='ones', - quant_delay=0, - freeze_bn=100000, fake=True, - num_bits=8, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0, + freeze_bn=100000): """init Conv2dBatchNormQuant layer""" super(Conv2dBatchNormQuant, self).__init__() self.in_channels = in_channels @@ -503,12 +515,13 @@ class Conv2dBatchNormQuant(Cell): self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, max_init=6, ema=False, - num_bits=num_bits, - quant_delay=quant_delay, per_channel=per_channel, - out_channels=out_channels, + channel_axis=channel_axis, + num_channels=out_channels, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) self.correct_mul = Q.CorrectionMul(channel_axis) if context.get_context('device_target') == "Ascend": @@ -582,11 +595,11 @@ class Conv2dQuant(Cell): weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. Default: 'normal'. bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -613,11 +626,11 @@ class Conv2dQuant(Cell): has_bias=False, weight_init='normal', bias_init='zeros', - quant_delay=0, - num_bits=8, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(Conv2dQuant, self).__init__() if isinstance(kernel_size, int): self.kernel_size = (kernel_size, kernel_size) @@ -653,12 +666,13 @@ class Conv2dQuant(Cell): self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, max_init=6, ema=False, - num_bits=num_bits, - quant_delay=quant_delay, per_channel=per_channel, - out_channels=out_channels, + channel_axis=0, + num_channels=out_channels, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) def construct(self, x): weight = self.fake_quant_weight(self.weight) @@ -692,11 +706,11 @@ class DenseQuant(Cell): same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -718,19 +732,19 @@ class DenseQuant(Cell): bias_init='zeros', has_bias=True, activation=None, - num_bits=8, - quant_delay=0, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(DenseQuant, self).__init__() self.in_channels = check_int_positive(in_channels) self.out_channels = check_int_positive(out_channels) self.has_bias = check_bool(has_bias) if isinstance(weight_init, Tensor): - if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ - weight_init.shape[1] != in_channels: + if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ + weight_init.shape()[1] != in_channels: raise ValueError("weight_init shape error") self.weight = Parameter(initializer( @@ -738,7 +752,7 @@ class DenseQuant(Cell): if self.has_bias: if isinstance(bias_init, Tensor): - if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: + if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: raise ValueError("bias_init shape error") self.bias = Parameter(initializer( @@ -752,12 +766,13 @@ class DenseQuant(Cell): self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, max_init=6, ema=False, - num_bits=num_bits, - quant_delay=quant_delay, per_channel=per_channel, - out_channels=out_channels, + channel_axis=0, + num_channels=out_channels, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) def construct(self, x): """Use operators to construct to Dense layer.""" @@ -780,13 +795,16 @@ class DenseQuant(Cell): return str_info + class _QuantActivation(Cell): r""" Base class for Quant activation function. Add Fake Quant OP after activation OP. """ + def get_origin(self): raise NotImplementedError + class ReLUQuant(_QuantActivation): r""" ReLUQuant activation function. Add Fake Quant OP after Relu OP. @@ -794,12 +812,12 @@ class ReLUQuant(_QuantActivation): For a more Detailed overview of ReLU op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of ReLUQuant. @@ -814,22 +832,22 @@ class ReLUQuant(_QuantActivation): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(ReLUQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=0, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.relu = P.ReLU() def construct(self, x): @@ -850,12 +868,12 @@ class ReLU6Quant(_QuantActivation): For a more Detailed overview of ReLU6 op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of ReLU6Quant. @@ -870,22 +888,22 @@ class ReLU6Quant(_QuantActivation): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(ReLU6Quant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=0, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.relu6 = P.ReLU6() def construct(self, x): @@ -896,6 +914,7 @@ class ReLU6Quant(_QuantActivation): def get_origin(self): return self.relu6 + class HSwishQuant(_QuantActivation): r""" HSwishQuant activation function. Add Fake Quant OP after HSwish OP. @@ -903,12 +922,12 @@ class HSwishQuant(_QuantActivation): For a more Detailed overview of HSwish op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of HSwishQuant. @@ -923,31 +942,31 @@ class HSwishQuant(_QuantActivation): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(HSwishQuant, self).__init__() self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.act = P.HSwish() def construct(self, x): @@ -959,6 +978,7 @@ class HSwishQuant(_QuantActivation): def get_origin(self): return self.act + class HSigmoidQuant(_QuantActivation): r""" HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP. @@ -966,12 +986,12 @@ class HSigmoidQuant(_QuantActivation): For a more Detailed overview of HSigmoid op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of HSigmoidQuant. @@ -986,30 +1006,31 @@ class HSigmoidQuant(_QuantActivation): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(HSigmoidQuant, self).__init__() self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, + ema_decay=ema_decay, per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.act = P.HSigmoid() def construct(self, x): @@ -1021,6 +1042,7 @@ class HSigmoidQuant(_QuantActivation): def get_origin(self): return self.act + class TensorAddQuant(Cell): r""" Add Fake Quant OP after TensorAdd OP. @@ -1028,12 +1050,12 @@ class TensorAddQuant(Cell): For a more Detailed overview of TensorAdd op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of TensorAddQuant. @@ -1049,22 +1071,22 @@ class TensorAddQuant(Cell): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(TensorAddQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.add = P.TensorAdd() def construct(self, x1, x2): @@ -1080,12 +1102,12 @@ class MulQuant(Cell): For a more Detailed overview of Mul op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of MulQuant. @@ -1096,22 +1118,22 @@ class MulQuant(Cell): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(MulQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.mul = P.Mul() def construct(self, x1, x2): @@ -1173,12 +1195,13 @@ class QuantBlock(Cell): self.has_bias = bias is None self.activation = activation self.has_act = activation is None + self.bias_add = P.BiasAdd() def construct(self, x): x = self.quant(x) x = self.core_op(x, self.weight) if self.has_bias: - output = self.bias_add(output, self.bias) + x = self.bias_add(x, self.bias) if self.has_act: x = self.activation(x) x = self.dequant(x, self.dequant_scale) diff --git a/mindspore/ops/_grad/grad_quant_ops.py b/mindspore/ops/_grad/grad_quant_ops.py index da19662e979901b65c2b67573067cda30654f34e..a2b0ba8d97e0d5f28018e18570fe40d3bb94f660 100644 --- a/mindspore/ops/_grad/grad_quant_ops.py +++ b/mindspore/ops/_grad/grad_quant_ops.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""Generate bprop for aware quantization ops""" +"""Generate bprop for quantization aware ops""" from .. import operations as P from ..operations import _quant_ops as Q @@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self): return bprop -@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate) +@bprop_getters.register(Q.MinMaxUpdatePerLayer) def get_bprop_fakequant_with_minmax_per_layer_update(self): - """Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend""" + """Generate bprop for MinMaxUpdatePerLayer for Ascend""" def bprop(x, x_min, x_max, out, dout): return zeros_like(x), zeros_like(x_min), zeros_like(x_max) @@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self): return bprop -@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate) +@bprop_getters.register(Q.MinMaxUpdatePerChannel) def get_bprop_fakequant_with_minmax_per_channel_update(self): - """Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend""" + """Generate bprop for MinMaxUpdatePerChannel for Ascend""" def bprop(x, x_min, x_max, out, dout): return zeros_like(x), zeros_like(x_min), zeros_like(x_max) diff --git a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py index 7e98517057d862f9c139ce68c0b38d321c4754c0..9daab5a75f4c963a84c04c0e065c5996bbc97a33 100644 --- a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +++ b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py @@ -30,7 +30,6 @@ batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \ .compute_cost(10) \ .kernel_name("batchnorm_fold2") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .input(0, "x", None, "required", None) \ .input(1, "beta", None, "required", None) \ .input(2, "gamma", None, "required", None) \ diff --git a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py index 824da62d19ba5cd378317f5f210a4a25b2c17817..9994a88f3008e86ae76f62176f1696af05ccf49b 100644 --- a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +++ b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py @@ -30,7 +30,6 @@ batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \ .compute_cost(10) \ .kernel_name("batchnorm_fold2_grad") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .input(0, "dout", None, "required", None) \ .input(1, "dout_reduce", None, "required", None) \ .input(2, "dout_x_reduce", None, "required", None) \ diff --git a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py index 7806c6834ee41e9795a9f4e88058cd757e9e70ad..92b91ff71290ee362c1da4ffed504f3c22cc3332 100644 --- a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +++ b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py @@ -31,7 +31,6 @@ batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \ .compute_cost(10) \ .kernel_name("batchnorm_fold2_grad_reduce") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .input(0, "dout", None, "required", None) \ .input(1, "x", None, "required", None) \ .output(0, "dout_reduce", True, "required", "all") \ diff --git a/mindspore/ops/_op_impl/_custom_op/correction_mul.py b/mindspore/ops/_op_impl/_custom_op/correction_mul.py index ce92d2bbc5c110ae686f5899b6af721fdfb1fab5..49cd35cc1119433ce7946ad9c69811703e6fc91b 100644 --- a/mindspore/ops/_op_impl/_custom_op/correction_mul.py +++ b/mindspore/ops/_op_impl/_custom_op/correction_mul.py @@ -30,7 +30,6 @@ correction_mul_op_info = TBERegOp("CorrectionMul") \ .compute_cost(10) \ .kernel_name("correction_mul") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .attr("channel_axis", "optional", "int", "all") \ .input(0, "x", None, "required", None) \ .input(1, "batch_std", None, "required", None) \ diff --git a/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py b/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py index da3a634454add6c0505406bc85722f3d5be64d0e..6c11ce685541d2a17fb1459b40583d9970e77047 100644 --- a/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +++ b/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py @@ -30,7 +30,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \ .compute_cost(10) \ .kernel_name("correction_mul_grad") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .attr("channel_axis", "optional", "int", "all") \ .input(0, "dout", None, "required", None) \ .input(1, "x", None, "required", None) \ @@ -128,7 +127,6 @@ correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \ .compute_cost(10) \ .kernel_name("correction_mul_grad_reduce") \ .partial_flag(True) \ - .op_pattern("formatAgnostic") \ .attr("channel_axis", "optional", "int", "all") \ .input(0, "dout", None, "required", None) \ .output(0, "d_batch_std", True, "required", "all") \ diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py index f6c133c8086947aee90e4f5350fce549eb388137..dae2d7058dddee086222abfaa498fb1ca4e612cd 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py @@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y, min_dtype = min_val.get("dtype") max_shape = max_val.get("ori_shape") max_dtype = max_val.get("dtype") - + # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1. + if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]: + channel_axis_ = 1 + else: + channel_axis_ = channel_axis util.check_kernel_name(kernel_name) util.check_shape_rule(x_shape) - util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis]) - util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis]) + util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_]) + util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_]) util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(max_shape) @@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y, quant_min = quant_min + 1 shape_c = [1] * len(x_shape) - shape_c[channel_axis] = min_val.get("ori_shape")[0] - if x_format == "NC1HWC0" and channel_axis == 1: + shape_c[channel_axis_] = min_val.get("ori_shape")[0] + if x_format == "NC1HWC0" and channel_axis_ == 1: shape_c = min_val.get("shape") input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py index 4e9053fcb145c805c6c115cccc7f944314e850f1..795aab52a3db8a7f3584a0057803cac3eb701128 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py @@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx, min_dtype = min_val.get("dtype") max_shape = max_val.get("ori_shape") max_dtype = max_val.get("dtype") - + # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1. + if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]: + channel_axis_ = 1 + else: + channel_axis_ = channel_axis util.check_kernel_name(kernel_name) util.check_shape_rule(x_shape) - util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis]) - util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis]) + util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_]) + util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_]) util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(max_shape) @@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx, quant_min = quant_min + 1 shape_c = [1] * len(x_shape) - shape_c[channel_axis] = min_val.get("ori_shape")[0] - if x_format == "NC1HWC0" and channel_axis == 1: + shape_c[channel_axis_] = min_val.get("ori_shape")[0] + if x_format == "NC1HWC0" and channel_axis_ == 1: shape_c = min_val.get("shape") dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype) input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py b/mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py similarity index 53% rename from mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py rename to mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py index 7694753d8f574c572a3a8f3be3782767c7449fc1..f29fc5332551d5905bcde8e7d880bfefea62ad0f 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py +++ b/mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py @@ -1,4 +1,3 @@ - # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""FakeQuantMinMaxPerChannelUpdate op""" +"""MinMaxUpdatePerChannel op""" import te.lang.cce from te import tvm from te.platform.fusion_manager import fusion_manager @@ -22,20 +21,15 @@ from topi import generic from topi.cce import util from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType - -fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \ +minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("fake_quant_min_max_per_channel_update.so") \ + .binfile_name("minmax_update_perchannel.so") \ .compute_cost(10) \ - .kernel_name("fake_quant_min_max_per_channel_update") \ + .kernel_name("minmax_update_perchannel") \ .partial_flag(True) \ .attr("ema", "optional", "bool", "all") \ .attr("ema_decay", "optional", "float", "all") \ - .attr("symmetric", "optional", "bool", "all") \ - .attr("narrow_range", "optional", "bool", "all") \ - .attr("training", "optional", "bool", "all") \ - .attr("num_bits", "optional", "int", "all") \ .attr("channel_axis", "optional", "int", "all") \ .input(0, "x", None, "required", None) \ .input(1, "min", None, "required", None) \ @@ -47,43 +41,46 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChan .get_op_info() -@op_info_register(fake_quant_min_max_per_channel_update_op_info) -def _fake_quant_min_max_per_channel_update_tbe(): - """FakeQuantPerChannelUpdate TBE register""" +@op_info_register(minmax_update_perchannel_op_info) +def _minmax_update_perchannel_tbe(): + """MinMaxUpdatePerChannel TBE register""" return -@fusion_manager.register("fake_quant_min_max_per_channel_update") -def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val, - ema, ema_decay, quant_min, quant_max, training, channel_axis, - kernel_name="fake_quant_min_max_per_channel_update"): - """FakeQuantPerChannelUpdate compute""" +@fusion_manager.register("minmax_update_perchannel") +def minmax_update_perchannel_compute(x, min_val, max_val, + ema, ema_decay, channel_axis): + """MinMaxUpdatePerChannel compute""" shape_min = te.lang.cce.util.shape_to_list(min_val.shape) if not ema: ema_decay = 0.0 - if training: - # CalMinMax + + # CalMinMax + if channel_axis == 0: + axis = [1, 2, 3, 4] + else: axis = [0, 2, 3] - x_min = te.lang.cce.reduce_min(x, axis=axis) - x_max = te.lang.cce.reduce_max(x, axis=axis) - x_min = te.lang.cce.broadcast(x_min, shape_min) - x_max = te.lang.cce.broadcast(x_max, shape_min) - min_val = te.lang.cce.vadd(te.lang.cce.vmuls( - min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) - max_val = te.lang.cce.vadd(te.lang.cce.vmuls( - max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) - min_val = te.lang.cce.vmins(min_val, 0) - max_val = te.lang.cce.vmaxs(max_val, 0) + + x_min = te.lang.cce.reduce_min(x, axis=axis) + x_max = te.lang.cce.reduce_max(x, axis=axis) + x_min = te.lang.cce.broadcast(x_min, shape_min) + x_max = te.lang.cce.broadcast(x_max, shape_min) + min_val = te.lang.cce.vadd(te.lang.cce.vmuls( + min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) + max_val = te.lang.cce.vadd(te.lang.cce.vmuls( + max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) + min_val = te.lang.cce.vmins(min_val, 0) + max_val = te.lang.cce.vmaxs(max_val, 0) return [min_val, max_val] -@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) -def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, - ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis, - kernel_name="fake_quant_min_max_per_channel_update"): - """FakeQuantPerLayer op""" +@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str) +def minmax_update_perchannel(x, min_val, max_val, min_up, max_up, + ema, ema_decay, channel_axis, + kernel_name="minmax_update_perchannel"): + """MinMaxUpdatePerChannel op""" x_shape = x.get("ori_shape") x_format = x.get("format") x_dtype = x.get("dtype") @@ -91,11 +88,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, min_dtype = min_val.get("dtype") max_shape = max_val.get("ori_shape") max_dtype = max_val.get("dtype") - + # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1. + if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]: + channel_axis_ = 1 + else: + channel_axis_ = channel_axis util.check_kernel_name(kernel_name) util.check_shape_rule(x_shape) - util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis]) - util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis]) + util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_]) + util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_]) util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(max_shape) @@ -108,21 +109,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, util.check_dtype_rule(min_dtype, check_list) util.check_dtype_rule(max_dtype, check_list) - if symmetric: - quant_min = 0 - 2 ** (num_bits - 1) - quant_max = 2 ** (num_bits - 1) - 1 + if channel_axis_ == 0: + shape_c = min_val.get("ori_shape") else: - quant_min = 0 - quant_max = 2 ** num_bits - 1 - if narrow_range: - quant_min = quant_min + 1 - - shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] + shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) - res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data, - ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name) + res_list = minmax_update_perchannel_compute(input_data, min_data, max_data, + ema, ema_decay, channel_axis_) with tvm.target.cce(): sch = generic.auto_schedule(res_list) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py b/mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py similarity index 61% rename from mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py rename to mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py index 0ad2315bb3f65e61ec21b8c8340ce5fd798aaf35..4d2096d55ba6fd4d7afedf422e563260cb0975cc 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py +++ b/mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""FakeQuantMinMaxPerLayerUpdate op""" +"""MinMaxUpdatePerLayer op""" from functools import reduce as functools_reduce import te.lang.cce from te import tvm @@ -22,20 +22,15 @@ from topi import generic from topi.cce import util from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType - -fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \ +minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("fake_quant_minmax_update.so") \ + .binfile_name("minmax_update_perlayer.so") \ .compute_cost(10) \ - .kernel_name("fake_quant_minmax_update") \ + .kernel_name("minmax_update_perlayer") \ .partial_flag(True) \ .attr("ema", "optional", "bool", "all") \ .attr("ema_decay", "optional", "float", "all") \ - .attr("symmetric", "optional", "bool", "all") \ - .attr("narrow_range", "optional", "bool", "all") \ - .attr("training", "optional", "bool", "all") \ - .attr("num_bits", "optional", "int", "all") \ .input(0, "x", None, "required", None) \ .input(1, "min", None, "required", None) \ .input(2, "max", None, "required", None) \ @@ -46,44 +41,42 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \ .get_op_info() -@op_info_register(fake_quant_minmax_update_op_info) -def _fake_quant_minmax_update_tbe(): - """FakeQuantMinMaxPerLayerUpdate TBE register""" +@op_info_register(minmax_update_perlayer_op_info) +def _minmax_update_perlayer_tbe(): + """MinMaxUpdatePerLayer TBE register""" return -@fusion_manager.register("fake_quant_minmax_update") -def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, - kernel_name="fake_quant_minmax_update"): - """FakeQuantMinMaxPerLayerUpdate compute""" +@fusion_manager.register("minmax_update_perlayer") +def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay): + """MinMaxUpdatePerLayer compute""" shape = te.lang.cce.util.shape_to_list(x.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape) min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) if not ema: ema_decay = 0.0 - if training: - # CalMinMax - axis = tuple(range(len(shape))) - x_min = te.lang.cce.reduce_min(x, axis=axis) - x_max = te.lang.cce.reduce_max(x, axis=axis) - x_min = te.lang.cce.broadcast(x_min, shape_min) - x_max = te.lang.cce.broadcast(x_max, shape_min) - min_val = te.lang.cce.vadd(te.lang.cce.vmuls( - min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) - max_val = te.lang.cce.vadd(te.lang.cce.vmuls( - max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) - min_val = te.lang.cce.vmins(min_val, 0) - max_val = te.lang.cce.vmaxs(max_val, 0) + + # CalMinMax + axis = tuple(range(len(shape))) + x_min = te.lang.cce.reduce_min(x, axis=axis) + x_max = te.lang.cce.reduce_max(x, axis=axis) + x_min = te.lang.cce.broadcast(x_min, shape_min) + x_max = te.lang.cce.broadcast(x_max, shape_min) + min_val = te.lang.cce.vadd(te.lang.cce.vmuls( + min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) + max_val = te.lang.cce.vadd(te.lang.cce.vmuls( + max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) + min_val = te.lang.cce.vmins(min_val, 0) + max_val = te.lang.cce.vmaxs(max_val, 0) return [min_val, max_val] -@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str) -def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, - ema, ema_decay, symmetric, narrow_range, training, num_bits, - kernel_name="fake_quant_minmax_update"): - """FakeQuantPerLayer op""" +@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str) +def minmax_update_perlayer(x, min_val, max_val, min_up, max_up, + ema, ema_decay, kernel_name="minmax_update_perlayer"): + """MinMaxUpdatePerLayer op""" input_shape = x.get("shape") input_dtype = x.get("dtype") min_shape = min_val.get("ori_shape") @@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) shape_min, _, _ = util.produce_shapes(min_shape, input_shape) - if symmetric: - quant_min = 0 - 2 ** (num_bits - 1) - quant_max = 2 ** (num_bits - 1) - 1 - else: - quant_min = 0 - quant_max = 2 ** num_bits - 1 - if narrow_range: - quant_min = quant_min + 1 - input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) - res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data, - ema, ema_decay, quant_min, quant_max, training, kernel_name) + res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay) with tvm.target.cce(): sch = generic.auto_schedule(res_list) diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index dd46fa491a35a8b2004175546246627f83e3b29a..3dd1ea4c9bafe50f98dfde43fc61bfa9bfdeb6ff 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -21,12 +21,12 @@ from ..._checkparam import Rel from ..primitive import PrimitiveWithInfer, prim_attr_register from ...common import dtype as mstype -__all__ = ["FakeQuantPerLayer", +__all__ = ["MinMaxUpdatePerLayer", + "MinMaxUpdatePerChannel", + "FakeQuantPerLayer", "FakeQuantPerLayerGrad", "FakeQuantPerChannel", "FakeQuantPerChannelGrad", - "FakeQuantMinMaxPerLayerUpdate", - "FakeQuantMinMaxPerChannelUpdate", "BatchNormFold", "BatchNormFoldGrad", "CorrectionMul", @@ -38,20 +38,141 @@ __all__ = ["FakeQuantPerLayer", "BatchNormFoldGradD", "BatchNormFold2_D", "BatchNormFold2GradD", - "BatchNormFold2GradReduce", + "BatchNormFold2GradReduce" ] +class MinMaxUpdatePerLayer(PrimitiveWithInfer): + r""" + Update min and max per layer. + + Args: + ema (bool): Use EMA algorithm update value min and max. Default: False. + ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. + + Inputs: + - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. + - **min** (Tensor) : Value of the min range of the input data x. + - **max** (Tensor) : Value of the max range of the input data x. + + Outputs: + - Tensor: Simulate quantize tensor of x. + + Examples: + >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) + >>> min_tensor = Tensor(np.array([-6]), mstype.float32) + >>> max_tensor = Tensor(np.array([6]), mstype.float32) + >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor) + """ + support_quant_bit = [4, 7, 8] + + @prim_attr_register + def __init__(self, ema=False, ema_decay=0.999): + """init FakeQuantMinMaxPerLayerUpdate OP""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import minmax_update_perlayer + if ema and not ema_decay: + raise ValueError( + f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = validator.check_value_type('ema', ema, (bool,), self.name) + self.ema_decay = validator.check_number_range( + 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.init_prim_io_names(inputs=['x', 'min', 'max'], + outputs=['min_up', 'max_up']) + + def infer_shape(self, x_shape, min_shape, max_shape): + validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + validator.check("min shape", min_shape, "max shape", + max_shape, Rel.EQ, self.name) + validator.check_integer("min shape", len( + min_shape), 1, Rel.EQ, self.name) + return min_shape, max_shape + + def infer_dtype(self, x_type, min_type, max_type): + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"max": max_type}, valid_types, self.name) + return min_type, max_type + + +class MinMaxUpdatePerChannel(PrimitiveWithInfer): + r""" + Update min and max per channel. + + Args: + ema (bool): Use EMA algorithm update value min and max. Default: False. + ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. + channel_axis (int): Quantization by channel axis, support 0 and 1. Default: 1. + + Inputs: + - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. + - **min** (Tensor) : Value of the min range of the input data x. + - **max** (Tensor) : Value of the max range of the input data x. + + Outputs: + - Tensor: Simulate quantize tensor of x. + + Examples: + >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) + >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) + >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) + >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max) + """ + support_quant_bit = [4, 7, 8] + support_x_rank = [2, 4] + + @prim_attr_register + def __init__(self, ema=False, ema_decay=0.999, channel_axis=1): + """init FakeQuantPerChannelUpdate OP for Ascend""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import minmax_update_perchannel + if ema and not ema_decay: + raise ValueError( + f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = validator.check_value_type('ema', ema, (bool,), self.name) + self.ema_decay = validator.check_number_range( + 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.channel_axis = validator.check_int_range( + 'channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) + self.init_prim_io_names( + inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) + + def infer_shape(self, x_shape, min_shape, max_shape): + if len(x_shape) not in self.support_x_rank: + raise ValueError(f"For '{self.name}' x rank should be in '{self.support_x_rank}'") + validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) + validator.check("min shape", min_shape, "max shape", + max_shape, Rel.EQ, self.name) + validator.check_integer("min shape", len( + min_shape), 1, Rel.EQ, self.name) + return min_shape, max_shape + + def infer_dtype(self, x_type, min_type, max_type): + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same( + {"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"max": max_type}, valid_types, self.name) + return min_type, max_type + + class FakeQuantPerLayer(PrimitiveWithInfer): r""" Simulate the quantize and dequantize operations in training time. Args: - num_bits (int) : Number bits for aware quantilization. Default: 8. + num_bits (int) : Number bits for quantization aware. Default: 8. ema (bool): Use EMA algorithm update value min and max. Default: False. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. quant_delay (int): Quantilization delay parameter. Before delay step in training time not update - simulate aware quantize funcion. After delay step in training time begin simulate the aware + simulate quantization aware funcion. After delay step in training time begin simulate the aware quantize funcion. Default: 0. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -103,8 +224,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer): 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) self.num_bits = validator.check_integer( 'num_bits', num_bits, 0, Rel.GT, self.name) - self.quant_delay = validator.check_value_type( - 'quant_delay', quant_delay, (int,), self.name) + self.quant_delay = validator.check_integer( + 'quant_delay', quant_delay, 0, Rel.GE, self.name) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) @@ -196,6 +317,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer): symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. training (bool): Training the network or not. Default: True. + channel_axis (int): Quantization by channel axis, support 0 and 1. Default: 1. Inputs: - **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor. @@ -213,6 +335,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer): >>> result = fake_quant(input_x, _min, _max) """ support_quant_bit = [4, 7, 8] + support_x_rank = [2, 4] @prim_attr_register def __init__(self, @@ -245,14 +368,15 @@ class FakeQuantPerChannel(PrimitiveWithInfer): 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) self.num_bits = validator.check_integer( 'num_bits', num_bits, 0, Rel.GT, self.name) - self.quant_delay = validator.check_value_type( - 'quant_delay', quant_delay, (int,), self.name) - self.channel_axis = validator.check_integer( - 'channel_axis', channel_axis, 0, Rel.GE, self.name) + self.quant_delay = validator.check_integer( + 'quant_delay', quant_delay, 0, Rel.GE, self.name) + self.channel_axis = validator.check_int_range( + 'channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + if len(x_shape) not in self.support_x_rank: + raise ValueError(f"For '{self.name}' x rank should be in '{self.support_x_rank}'") validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) validator.check_integer( "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) @@ -832,153 +956,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer): def infer_dtype(self, dout_type, x_type): validator.check("dout type", dout_type, "x type", x_type) return dout_type, dout_type - - -class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): - r""" - Update min and max value for fake quant per layer op. - - Args: - num_bits (int) : Number bits for aware quantilization. Default: 8. - ema (bool): Use EMA algorithm update value min and max. Default: False. - ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. - symmetric (bool): Quantization algorithm use symmetric or not. Default: False. - narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. - training (bool): Training the network or not. Default: True. - - Inputs: - - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. - - **min** (Tensor) : Value of the min range of the input data x. - - **max** (Tensor) : Value of the max range of the input data x. - - Outputs: - - Tensor: Simulate quantize tensor of x. - - Examples: - >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) - >>> min_tensor = Tensor(np.array([-6]), mstype.float32) - >>> max_tensor = Tensor(np.array([6]), mstype.float32) - >>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) - """ - support_quant_bit = [4, 7, 8] - - @prim_attr_register - def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, - training=True): - """init FakeQuantMinMaxPerLayerUpdate OP""" - if context.get_context('device_target') == "Ascend": - from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update - if num_bits not in self.support_quant_bit: - raise ValueError( - f"For '{self.name}' attr \'num_bits\' is not support.") - if ema and not ema_decay: - raise ValueError( - f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") - - self.ema = validator.check_value_type('ema', ema, (bool,), self.name) - self.symmetric = validator.check_value_type( - 'symmetric', symmetric, (bool,), self.name) - self.narrow_range = validator.check_value_type( - 'narrow_range', narrow_range, (bool,), self.name) - self.training = validator.check_value_type( - 'training', training, (bool,), self.name) - self.ema_decay = validator.check_number_range( - 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) - self.num_bits = validator.check_integer( - 'num_bits', num_bits, 0, Rel.GT, self.name) - self.init_prim_io_names(inputs=['x', 'min', 'max'], - outputs=['min_up', 'max_up']) - - def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) - validator.check("min shape", min_shape, "max shape", - max_shape, Rel.EQ, self.name) - validator.check_integer("min shape", len( - min_shape), 1, Rel.EQ, self.name) - return min_shape, max_shape - - def infer_dtype(self, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) - return min_type, max_type - - -class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): - r""" - Update min and max value for fake quant per layer op. - - Args: - num_bits (int) : Number bits for aware quantilization. Default: 8. - ema (bool): Use EMA algorithm update value min and max. Default: False. - ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. - symmetric (bool): Quantization algorithm use symmetric or not. Default: False. - narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. - training (bool): Training the network or not. Default: True. - channel_axis (int): Channel asis for per channel compute. Default: 1. - - Inputs: - - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. - - **min** (Tensor) : Value of the min range of the input data x. - - **max** (Tensor) : Value of the max range of the input data x. - - Outputs: - - Tensor: Simulate quantize tensor of x. - - Examples: - >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) - >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) - >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) - >>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max) - """ - support_quant_bit = [4, 7, 8] - - @prim_attr_register - def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, - training=True, channel_axis=1): - """init FakeQuantPerChannelUpdate OP for Ascend""" - if context.get_context('device_target') == "Ascend": - from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update - if num_bits not in self.support_quant_bit: - raise ValueError( - f"For '{self.name}' attr \'num_bits\' is not support.") - if ema and not ema_decay: - raise ValueError( - f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") - - self.ema = validator.check_value_type('ema', ema, (bool,), self.name) - self.symmetric = validator.check_value_type( - 'symmetric', symmetric, (bool,), self.name) - self.narrow_range = validator.check_value_type( - 'narrow_range', narrow_range, (bool,), self.name) - self.training = validator.check_value_type( - 'training', training, (bool,), self.name) - self.ema_decay = validator.check_number_range( - 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) - self.num_bits = validator.check_integer( - 'num_bits', num_bits, 0, Rel.GT, self.name) - self.channel_axis = validator.check_integer( - 'channel axis', channel_axis, 0, Rel.GE, self.name) - self.init_prim_io_names( - inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) - - def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) - validator.check("min shape", min_shape, "max shape", - max_shape, Rel.EQ, self.name) - validator.check_integer("min shape", len( - min_shape), 1, Rel.EQ, self.name) - return min_shape, max_shape - - def infer_dtype(self, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same( - {"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) - return min_type, max_type diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index 93064ab7047969d10db6d67c1c91dd1332076849..7108bbe0509f1c2754e1385af0413a9a1c96d74a 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -32,7 +32,6 @@ from ...ops.operations import _inner_ops as inner from ...train import serialization from . import quant_utils - _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, nn.ReLU6: quant.ReLU6Quant, nn.HSigmoid: quant.HSigmoidQuant, @@ -41,15 +40,14 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, class _AddFakeQuantInput(nn.Cell): """ - Add FakeQuant at input and output of the Network. Only support one input and one output case. + Add FakeQuant OP at input of the network. Only support one input case. """ def __init__(self, network, quant_delay=0): super(_AddFakeQuantInput, self).__init__(auto_prefix=False) - self.network = network - self.fake_quant_input = quant.FakeQuantWithMinMax( - min_init=-6, max_init=6, quant_delay=quant_delay, ema=True) + self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True) self.fake_quant_input.update_parameters_name('fake_quant_input') + self.network = network def construct(self, data): data = self.fake_quant_input(data) @@ -59,7 +57,7 @@ class _AddFakeQuantInput(nn.Cell): class _AddFakeQuantAfterSubCell(nn.Cell): """ - Add FakeQuant after of the sub Cell. + Add FakeQuant OP after of the sub Cell. """ def __init__(self, subcell, **kwargs): @@ -114,11 +112,12 @@ class ConvertToQuantNetwork: self.network.update_cell_prefix() network = self._convert_subcells2quant(self.network) network = _AddFakeQuantInput(network) + self.network.update_cell_type("quant") return network def _convert_subcells2quant(self, network): """ - convet sub cell to quant cell + convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell """ cells = network.name_cells() change = False @@ -137,19 +136,19 @@ class ConvertToQuantNetwork: if isinstance(network, nn.SequentialCell) and change: network.cell_list = list(network.cells()) - # tensoradd to tensoradd quant + # add FakeQuant OP after OP in while list add_list = [] for name in network.__dict__: if name[0] == '_': continue attr = network.__dict__[name] - if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__: + if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__: add_list.append((name, attr)) for name, prim_op in add_list: prefix = name add_quant = _AddFakeQuantAfterSubCell(prim_op, num_bits=self.act_bits, - quant_delay=self.act_delay, + quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, narrow_range=self.act_range) @@ -163,11 +162,11 @@ class ConvertToQuantNetwork: def _convert_conv(self, subcell): """ - convet conv cell to quant cell + convert Conv2d cell to quant cell """ conv_inner = subcell.conv bn_inner = subcell.batchnorm - if subcell.batchnorm is not None and self.bn_fold: + if subcell.has_bn and self.bn_fold: conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels, conv_inner.out_channels, kernel_size=conv_inner.kernel_size, @@ -177,7 +176,7 @@ class ConvertToQuantNetwork: dilation=conv_inner.dilation, group=conv_inner.group, eps=bn_inner.eps, - momentum=bn_inner.momentum, + momentum=1 - bn_inner.momentum, quant_delay=self.weight_qdelay, freeze_bn=self.freeze_bn, per_channel=self.weight_channel, @@ -185,6 +184,11 @@ class ConvertToQuantNetwork: fake=True, symmetric=self.weight_symmetric, narrow_range=self.weight_range) + # change original network BatchNormal OP parameters to quant network + conv_inner.gamma = subcell.batchnorm.gamma + conv_inner.beta = subcell.batchnorm.beta + conv_inner.moving_mean = subcell.batchnorm.moving_mean + conv_inner.moving_variance = subcell.batchnorm.moving_variance del subcell.batchnorm subcell.batchnorm = None subcell.has_bn = False @@ -203,6 +207,10 @@ class ConvertToQuantNetwork: num_bits=self.weight_bits, symmetric=self.weight_symmetric, narrow_range=self.weight_range) + # change original network Conv2D OP parameters to quant network + conv_inner.weight = subcell.conv.weight + if subcell.conv.has_bias: + conv_inner.bias = subcell.conv.bias subcell.conv = conv_inner if subcell.has_act and subcell.activation is not None: subcell.activation = self._convert_activation(subcell.activation) @@ -229,6 +237,10 @@ class ConvertToQuantNetwork: per_channel=self.weight_channel, symmetric=self.weight_symmetric, narrow_range=self.weight_range) + # change original network Dense OP parameters to quant network + dense_inner.weight = subcell.dense.weight + if subcell.dense.has_bias: + dense_inner.bias = subcell.dense.bias subcell.dense = dense_inner if subcell.has_act and subcell.activation is not None: subcell.activation = self._convert_activation(subcell.activation) @@ -236,7 +248,7 @@ class ConvertToQuantNetwork: subcell.has_act = True subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits, - quant_delay=self.act_delay, + quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, narrow_range=self.act_range) @@ -246,12 +258,12 @@ class ConvertToQuantNetwork: act_class = activation.__class__ if act_class not in _ACTIVATION_MAP: raise ValueError( - "Unsupported activation in auto Quant: ", act_class) + "Unsupported activation in auto quant: ", act_class) return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.act_qdelay, per_channel=self.act_channel, - symmetric=self.weight_symmetric, - narrow_range=self.weight_range) + symmetric=self.act_symmetric, + narrow_range=self.act_range) class ExportQuantNetworkDeploy: @@ -405,7 +417,7 @@ def convert_quant_network(network, narrow_range=(False, False) ): r""" - Create aware quantizaiton training network. + Create quantization aware training network. Args: network (Cell): Obtain a pipeline through network for saving graph summary. @@ -419,7 +431,7 @@ def convert_quant_network(network, then base on per channel otherwise base on per layer. The first element represent weights and second element represent data flow. Default: [False, False] symmetric (list of bool): Quantization algorithm use symmetric or not. If `True` then base on - symmetric otherwise base on assymmetric. The first element represent weights and second + symmetric otherwise base on asymmetric. The first element represent weights and second element represent data flow. Default: [False, False] narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True` then base on narrow range otherwise base on off narrow range. The first element represent weights and @@ -428,6 +440,7 @@ def convert_quant_network(network, Returns: Cell, Network which has change to aware quantization training network cell. """ + def convert2list(name, value): if not isinstance(value, list) and not isinstance(value, tuple): value = [value] diff --git a/model_zoo/lenet_quant/README.md b/model_zoo/lenet_quant/README.md index c895f68be860e9ada4690cbaa0aea6f844ca12a1..2fd3e129a223bf93afad11f0016fe6dce22c432a 100644 --- a/model_zoo/lenet_quant/README.md +++ b/model_zoo/lenet_quant/README.md @@ -2,13 +2,13 @@ ## Description -Training LeNet with MNIST dataset in MindSpore with quantization aware trainging. +Training LeNet with MNIST dataset in MindSpore with quantization aware training. This is the simple and basic tutorial for constructing a network in MindSpore with quantization aware. In this tutorial, you will: -1. Train a Mindspore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`. +1. Train a MindSpore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`. 2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file. 3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend. 4. See the persistence of accuracy in inference backend and a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples. @@ -24,16 +24,16 @@ Install MindSpore base on the ascend device and GPU device from [MindSpore](http ```python pip uninstall -y mindspore-ascend pip uninstall -y mindspore-gpu -pip install mindspore-ascend-0.4.0.whl +pip install mindspore-ascend.whl ``` -then you will get the following display +Then you will get the following display ```bash >>> Found existing installation: mindspore-ascend >>> Uninstalling mindspore-ascend: ->>> Successfully uninstalled mindspore-ascend. +>>> Successfully uninstalled mindspore-ascend. ``` ### Prepare Dataset @@ -87,7 +87,7 @@ class LeNet5(nn.Cell): return x ``` -get the MNIST from scratch dataset. +Get the MNIST from scratch dataset. ```Python ds_train = create_dataset(os.path.join(args.data_path, "train"), @@ -97,7 +97,7 @@ step_size = ds_train.get_dataset_size() ### Train model -Load teh Lenet fusion network, traing network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`. +Load the Lenet fusion network, training network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`. ```Python # Define the network @@ -133,7 +133,7 @@ After all the following we will get the loss value of each step as following: >>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] ``` -To save your time, just run this command. +Also, you can just run this command instead. ```python python train.py --data_path MNIST_Data --device_target Ascend @@ -165,17 +165,17 @@ Note that the resulting model is quantization aware but not quantized (e.g. the # define funsion network network = LeNet5Fusion(cfg.num_classes) -# load aware quantizaiton network checkpoint +# load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) -# convert funsion netwrok to aware quantizaiton network +# convert funsion netwrok to quantization aware network network = quant.convert_quant_network(network) ``` ### load checkpoint -after convert to quantization aware network, we can load the checkpoint file. +After convert to quantization aware network, we can load the checkpoint file. ```python config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, @@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) ### train quantization aware model -To save your time, just run this command. +Also, you can just run this command instead. ```python python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt @@ -210,7 +210,7 @@ Procedure of quantization aware model evaluation is different from normal. Becau # define funsion network network = LeNet5Fusion(cfg.num_classes) -# load aware quantizaiton network checkpoint +# load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) @@ -218,10 +218,10 @@ load_param_into_net(network, param_dict) network = quant.convert_quant_network(network) ``` -To save your time, just run this command. +Also, you can just run this command insread. ```python -python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt +python eval_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt ``` The top1 accuracy would display on shell. @@ -235,7 +235,7 @@ The top1 accuracy would display on shell. Here are some optional parameters: ```bash ---device_target {Ascend,GPU,CPU} +--device_target {Ascend,GPU} device where the code will be implemented (default: Ascend) --data_path DATA_PATH path where the dataset is saved diff --git a/model_zoo/lenet_quant/eval.py b/model_zoo/lenet_quant/eval.py index d94e77279faf5e0a12d9d0e38d30f64db13556fb..c0293ae1f78561438473b8f5acbf07b826f51731 100644 --- a/model_zoo/lenet_quant/eval.py +++ b/model_zoo/lenet_quant/eval.py @@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') diff --git a/model_zoo/lenet_quant/eval_quant.py b/model_zoo/lenet_quant/eval_quant.py index 2c2477123fedf714a99051ff53acaafe30ff6b01..bc9b62121d90718e80e4ab1e1cc04bfe039d313f 100644 --- a/model_zoo/lenet_quant/eval_quant.py +++ b/model_zoo/lenet_quant/eval_quant.py @@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') @@ -61,7 +61,7 @@ if __name__ == "__main__": model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) # load quantization aware network checkpoint - param_dict = load_checkpoint(args.ckpt_path, model_type="quant") + param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) print("============== Starting Testing ==============") diff --git a/model_zoo/lenet_quant/train.py b/model_zoo/lenet_quant/train.py index b6040776ef40c7c822080f00647a6204b1733442..a34b6d5ed6695436d85ebc5c700edbace26bdf3d 100644 --- a/model_zoo/lenet_quant/train.py +++ b/model_zoo/lenet_quant/train.py @@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') @@ -56,8 +56,7 @@ if __name__ == "__main__": # call back and monitor time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max, - model_type=network.type) + keep_checkpoint_max=cfg.keep_checkpoint_max) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) # define model diff --git a/model_zoo/lenet_quant/train_quant.py b/model_zoo/lenet_quant/train_quant.py index eb1f783a7c3f27aa725252c231136458599d2e76..ba54e63d8017659a5f4bc9cca3541b402abadb1d 100644 --- a/model_zoo/lenet_quant/train_quant.py +++ b/model_zoo/lenet_quant/train_quant.py @@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU', 'CPU'], + choices=['Ascend', 'GPU'], help='device where the code will be implemented (default: Ascend)') parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') @@ -50,11 +50,13 @@ if __name__ == "__main__": # define fusion network network = LeNet5Fusion(cfg.num_classes) + + # convert fusion network to quantization aware network + network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) + # load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path, network.type) load_param_into_net(network, param_dict) - # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) # define network loss net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") @@ -64,8 +66,7 @@ if __name__ == "__main__": # call back and monitor time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max, - model_type="quant") + keep_checkpoint_max=cfg.keep_checkpoint_max) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) # define model diff --git a/model_zoo/mobilenetv2/scripts/run_infer.sh b/model_zoo/mobilenetv2/scripts/run_infer.sh index e200e600bfec0f8e725cdeaad4a8de3f7fe95f76..7385a221d4f062875b71dcaf38fc933152de4bd1 100644 --- a/model_zoo/mobilenetv2/scripts/run_infer.sh +++ b/model_zoo/mobilenetv2/scripts/run_infer.sh @@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH export DEVICE_ID=0 export RANK_ID=0 export RANK_SIZE=1 -if [ -d "eval" ]; +if [ -d "../eval" ]; then rm -rf ../eval fi diff --git a/model_zoo/mobilenetv2/scripts/run_train.sh b/model_zoo/mobilenetv2/scripts/run_train.sh index fc013d474cb4730ced8f98e4fc2ff9ccbd8b25d7..3414aa7528ed80b6e26d73b7992bc130a81aadcd 100644 --- a/model_zoo/mobilenetv2/scripts/run_train.sh +++ b/model_zoo/mobilenetv2/scripts/run_train.sh @@ -62,7 +62,7 @@ run_gpu() BASEPATH=$(cd "`dirname $0`" || exit; pwd) export PYTHONPATH=${BASEPATH}:$PYTHONPATH - if [ -d "train" ]; + if [ -d "../train" ]; then rm -rf ../train fi diff --git a/model_zoo/mobilenetv3/scripts/run_infer.sh b/model_zoo/mobilenetv3/scripts/run_infer.sh index e200e600bfec0f8e725cdeaad4a8de3f7fe95f76..7385a221d4f062875b71dcaf38fc933152de4bd1 100644 --- a/model_zoo/mobilenetv3/scripts/run_infer.sh +++ b/model_zoo/mobilenetv3/scripts/run_infer.sh @@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH export DEVICE_ID=0 export RANK_ID=0 export RANK_SIZE=1 -if [ -d "eval" ]; +if [ -d "../eval" ]; then rm -rf ../eval fi diff --git a/model_zoo/mobilenetv3/scripts/run_train.sh b/model_zoo/mobilenetv3/scripts/run_train.sh index 78b79b235fd459e9938bf57d9926c4c56b2a06d0..47dabffe01ea195109d39fc8c677daa0f84583ca 100644 --- a/model_zoo/mobilenetv3/scripts/run_train.sh +++ b/model_zoo/mobilenetv3/scripts/run_train.sh @@ -60,7 +60,7 @@ run_gpu() BASEPATH=$(cd "`dirname $0`" || exit; pwd) export PYTHONPATH=${BASEPATH}:$PYTHONPATH - if [ -d "train" ]; + if [ -d "../train" ]; then rm -rf ../train fi diff --git a/tests/ut/python/train/quant/mobilenetv2_combined.py b/tests/ut/python/train/quant/mobilenetv2_combined.py index 7ed1498fb61350516a40041dbe26be1c680e4228..a87aaf827f3d7d384113041ea17f6ecf690f8960 100644 --- a/tests/ut/python/train/quant/mobilenetv2_combined.py +++ b/tests/ut/python/train/quant/mobilenetv2_combined.py @@ -17,7 +17,7 @@ def _conv_bn(in_channel, out_channel, kernel_size=ksize, stride=stride, - batchnorm=True)]) + has_bn=True)]) class InvertedResidual(nn.Cell): @@ -35,25 +35,25 @@ class InvertedResidual(nn.Cell): 3, stride, group=hidden_dim, - batchnorm=True, + has_bn=True, activation='relu6'), nn.Conv2dBnAct(hidden_dim, oup, 1, 1, - batchnorm=True) + has_bn=True) ]) else: self.conv = nn.SequentialCell([ nn.Conv2dBnAct(inp, hidden_dim, 1, 1, - batchnorm=True, + has_bn=True, activation='relu6'), nn.Conv2dBnAct(hidden_dim, hidden_dim, 3, stride, group=hidden_dim, - batchnorm=True, + has_bn=True, activation='relu6'), nn.Conv2dBnAct(hidden_dim, oup, 1, 1, - batchnorm=True) + has_bn=True) ]) self.add = P.TensorAdd() diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py index 54563d86eb7d826a647b35f7ff4e8a5b1fb03f31..1a21bc2c02342e0d4daabe86b87b92de724018a2 100644 --- a/tests/ut/python/train/quant/test_quant.py +++ b/tests/ut/python/train/quant/test_quant.py @@ -42,7 +42,7 @@ class LeNet5(nn.Cell): def __init__(self, num_class=10): super(LeNet5, self).__init__() self.num_class = num_class - self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid") + self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid") self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid") self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') self.fc2 = nn.DenseBnAct(120, 84, activation='relu')