diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu index 4d313fa9313d39d7ba32208a7bc5048c776fc328..27b2cb0232c9e14d17ff3ef3dfa4184e331d63dd 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu @@ -23,35 +23,24 @@ #include "device/gpu/cuda_common.h" __global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min, - float *output_max, const float min, const float max, const float decay, - const float symmetric) { + float *output_max, const float min, const float max, + const float decay) { output_min[0] = decay * (min) + (1 - decay) * (input_min[0]); output_min[0] = input_min[0] > 0 ? 0 : input_min[0]; output_max[0] = decay * (max) + (1 - decay) * (input_max[0]); output_max[0] = input_max[0] < 0 ? 0 : input_max[0]; - - if (symmetric) { - output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0]; - output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0]; - } return; } -__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max, - const float symmetric) { +__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max) { output_min[0] = min > 0 ? 0 : min; output_max[0] = max < 0 ? 0 : max; - - if (symmetric) { - output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0]; - output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0]; - } return; } __global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, int channels, int per_channel_nums, bool ema, - float ema_decay, bool symmetric) { + float ema_decay) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) { thrust::pair sum = thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1)); @@ -64,27 +53,21 @@ __global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, floa } output_min[i] = input_min[i] > 0 ? 0 : input_min[i]; output_max[i] = input_max[i] < 0 ? 0 : input_max[i]; - - if (symmetric) { - output_max[i] = abs(output_min[i]) < output_max[i] ? output_max[i] : -output_min[i]; - output_min[i] = abs(output_min[i]) < output_max[i] ? -output_max[i] : output_min[i]; - } } return; } void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, const int total_num, const int channel_num, const float ema_decay, const bool ema, - const bool symmetric, cudaStream_t cuda_stream) { + cudaStream_t cuda_stream) { int per_channel_num = total_num / channel_num; UpdateInputMinMaxPerChannel<<>>( - input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay, symmetric); + input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay); return; } void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, - const int total_num, const float ema_decay, const bool ema, const bool symmetric, - cudaStream_t cuda_stream) { + const int total_num, const float ema_decay, const bool ema, cudaStream_t cuda_stream) { float minel = 0.f; float maxel = 0.f; auto policy = thrust::cuda::par.on(cuda_stream); @@ -96,9 +79,9 @@ void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float * if (ema) { UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel, - maxel, ema_decay, symmetric); + maxel, ema_decay); } else { - UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel, symmetric); + UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel); } return; } diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh index e0e5a731d3fd673f96a82a0a683f9fe11eddabcc..5e9becab38084eb1a7b7243d08949ed87fb60751 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh @@ -21,10 +21,9 @@ void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, const int total_num, const int channel_num, const float ema_decay, const bool ema, - const bool symmetric, cudaStream_t cuda_stream); + cudaStream_t cuda_stream); void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, - const int size, const float ema_decay, const bool ema, const bool symmetric, - cudaStream_t cuda_stream); + const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc index ae9df3355bb1be3cf1104efc96fc4807f1795bfe..a8ce72148b2ce80d454347768beb96ea1c9638fe 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc @@ -24,16 +24,7 @@ namespace mindspore { namespace kernel { MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel() - : input_size_(0), - num_bits_(0), - quant_min_(0), - quant_max_(0), - quant_num_(1), - ema_(false), - ema_decay_(0), - num_channels_(0), - narrow_range_(false), - symmetric_(false) {} + : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0), num_channels_(0) {} const std::vector &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; } @@ -54,22 +45,8 @@ bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) { MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; } - num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); - - if (num_bits_ <= 2 || num_bits_ >= 16) { - MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; - } - - // quant min and max - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } // init size auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); @@ -110,7 +87,7 @@ bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector &inpu // calculate the input min and max according by the parameter ema and ema_decay. CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_, - ema_decay_, ema_, symmetric_, reinterpret_cast(stream_ptr)); + ema_decay_, ema_, reinterpret_cast(stream_ptr)); return true; } diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h index 5a35d4da32042824b6f1650fe1d49b6deff2752f..563a583ca1d00dc8c7bd954dc46fa364d418d89b 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h @@ -44,15 +44,10 @@ class MinMaxUpdatePerChannelGpuKernel : public GpuKernel { std::vector output_size_list_; std::vector workspace_size_list_; - int num_bits_; - float quant_min_; - float quant_max_; int quant_num_; bool ema_; float ema_decay_; int num_channels_; - bool narrow_range_; - bool symmetric_; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc index 8ba1d363dd6371066081ec052078b2afcb471bca..3659665b23ceec6667e672a3f4ca75884a851d98 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc @@ -24,15 +24,7 @@ namespace mindspore { namespace kernel { MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel() - : input_size_(0), - num_bits_(0), - quant_min_(0), - quant_max_(0), - quant_num_(1), - ema_(false), - ema_decay_(0), - narrow_range_(false), - symmetric_(false) {} + : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0) {} const std::vector &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } @@ -51,22 +43,8 @@ bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) { MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; } - num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); - - if (num_bits_ <= 2 || num_bits_ >= 16) { - MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; - } - - // quant min and max - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } // init size auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); @@ -104,7 +82,7 @@ bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector &inputs MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null."; } - CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, symmetric_, + CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, reinterpret_cast(stream_ptr)); return true; diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h index 527de20427fdf1aa7fa4658bbc5b1fbfca76d09a..a237b6dc266948a52fc810433584964e4fc80bde 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h @@ -44,14 +44,9 @@ class MinMaxUpdatePerLayerGpuKernel : public GpuKernel { std::vector output_size_list_; std::vector workspace_size_list_; - int num_bits_; - float quant_min_; - float quant_max_; int quant_num_; bool ema_; float ema_decay_; - bool narrow_range_; - bool symmetric_; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 8c444362b256dd567660fbccfbf8482c3ef04555..14731c62623939fa8ffd101da6268938b7c6da2f 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -276,15 +276,15 @@ class FakeQuantWithMinMax(Cell): Args: min_init (int, float): The dimension of channel or 1(layer). Default: -6. max_init (int, float): The dimension of channel or 1(layer). Default: 6. - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. ema (bool): Exponential Moving Average algorithm update min and max. Default: False. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. channel_axis (int): Quantization by channel axis. Default: 1. - out_channels (int): declarate the min and max channel size, Default: 1. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. + num_channels (int): declarate the min and max channel size, Default: 1. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of FakeQuantWithMinMax. @@ -301,15 +301,15 @@ class FakeQuantWithMinMax(Cell): def __init__(self, min_init=-6, max_init=6, - num_bits=8, ema=False, ema_decay=0.999, per_channel=False, channel_axis=1, - out_channels=1, - quant_delay=0, + num_channels=1, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): """init FakeQuantWithMinMax layer""" super(FakeQuantWithMinMax, self).__init__() self.min_init = min_init @@ -318,7 +318,7 @@ class FakeQuantWithMinMax(Cell): self.ema = ema self.ema_decay = ema_decay self.per_channel = per_channel - self.out_channels = out_channels + self.num_channels = num_channels self.channel_axis = channel_axis self.quant_delay = quant_delay self.symmetric = symmetric @@ -327,11 +327,11 @@ class FakeQuantWithMinMax(Cell): # init tensor min and max for fake quant op if self.per_channel: - min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) - max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32) + min_array = np.array([self.min_init] * self.num_channels).astype(np.float32) + max_array = np.array([self.max_init] * self.num_channels).astype(np.float32) else: - min_array = np.array([self.min_init]).reshape(1).astype(np.float32) - max_array = np.array([self.max_init]).reshape(1).astype(np.float32) + min_array = np.array([self.min_init]).astype(np.float32) + max_array = np.array([self.max_init]).astype(np.float32) self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) @@ -343,57 +343,41 @@ class FakeQuantWithMinMax(Cell): quant_fun = Q.FakeQuantPerLayer ema_fun = Q.MinMaxUpdatePerLayer + self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay) if self.is_ascend: - self.fake_quant = quant_fun(num_bits=self.num_bits, - symmetric=self.symmetric, - narrow_range=self.narrow_range) - else: self.fake_quant_train = quant_fun(num_bits=self.num_bits, - ema=self.ema, - ema_decay=ema_decay, - quant_delay=quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=True) - self.fake_quant_infer = quant_fun(num_bits=self.num_bits, - ema=self.ema, - ema_decay=ema_decay, - quant_delay=quant_delay, symmetric=self.symmetric, - narrow_range=self.narrow_range, - training=False) - self.ema_update = ema_fun(num_bits=self.num_bits, - ema=self.ema, - ema_decay=self.ema_decay, - symmetric=self.symmetric, - narrow_range=self.narrow_range) + narrow_range=self.narrow_range) + self.fake_quant_infer = self.fake_quant_train + else: + quant_fun = partial(quant_fun, + ema=self.ema, + ema_decay=ema_decay, + num_bits=self.num_bits, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + quant_delay=quant_delay) + self.fake_quant_train = quant_fun(training=True) + self.fake_quant_infer = quant_fun(training=False) def extend_repr(self): s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ 'quant_delay={}, min_init={}, max_init={}'.format( self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel, - self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init) + self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init) return s def construct(self, x): - if self.is_ascend: - if self.training: - min_up, max_up = self.ema_update(x, self.minq, self.maxq) - out = self.fake_quant(x, min_up, max_up) - P.Assign()(self.minq, min_up) - P.Assign()(self.maxq, max_up) - else: - out = self.fake_quant(x, self.minq, self.maxq) + if self.training: + min_up, max_up = self.ema_update(x, self.minq, self.maxq) + P.Assign()(self.minq, min_up) + P.Assign()(self.maxq, max_up) + out = self.fake_quant_train(x, self.minq, self.maxq) else: - if self.training: - min_up, max_up = self.ema_update(x, self.minq, self.maxq) - out = self.fake_quant_train(x, min_up, max_up) - P.Assign()(self.minq, min_up) - P.Assign()(self.maxq, max_up) - else: - out = self.fake_quant_infer(x, self.minq, self.maxq) + out = self.fake_quant_infer(x, self.minq, self.maxq) return out + class Conv2dBatchNormQuant(Cell): r""" 2D convolution with BatchNormal op folded layer. @@ -407,8 +391,8 @@ class Conv2dBatchNormQuant(Cell): stride (int): Specifies stride for all spatial dimensions with the same value. pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". padding: (int): Implicit paddings on both sides of the input. Default: 0. - eps (int): Parameters for BatchNormal. Default: 1e-5. - momentum (int): Parameters for BatchNormal op. Default: 0.997. + eps (float): Parameters for BatchNormal. Default: 1e-5. + momentum (float): Parameters for BatchNormal op. Default: 0.997. weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. Default: 'normal'. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the @@ -419,13 +403,13 @@ class Conv2dBatchNormQuant(Cell): mean vector. Default: 'zeros'. var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the variance vector. Default: 'ones'. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. - freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -456,13 +440,13 @@ class Conv2dBatchNormQuant(Cell): gamma_init='ones', mean_init='zeros', var_init='ones', - quant_delay=0, - freeze_bn=100000, fake=True, - num_bits=8, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0, + freeze_bn=100000): """init Conv2dBatchNormQuant layer""" super(Conv2dBatchNormQuant, self).__init__() self.in_channels = in_channels @@ -519,12 +503,13 @@ class Conv2dBatchNormQuant(Cell): self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, max_init=6, ema=False, - num_bits=num_bits, - quant_delay=quant_delay, per_channel=per_channel, - out_channels=out_channels, + channel_axis=channel_axis, + num_channels=out_channels, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) self.correct_mul = Q.CorrectionMul(channel_axis) if context.get_context('device_target') == "Ascend": @@ -598,11 +583,11 @@ class Conv2dQuant(Cell): weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. Default: 'normal'. bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -629,11 +614,11 @@ class Conv2dQuant(Cell): has_bias=False, weight_init='normal', bias_init='zeros', - quant_delay=0, - num_bits=8, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(Conv2dQuant, self).__init__() if isinstance(kernel_size, int): self.kernel_size = (kernel_size, kernel_size) @@ -669,12 +654,13 @@ class Conv2dQuant(Cell): self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, max_init=6, ema=False, - num_bits=num_bits, - quant_delay=quant_delay, per_channel=per_channel, - out_channels=out_channels, + channel_axis=0, + num_channels=out_channels, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) def construct(self, x): weight = self.fake_quant_weight(self.weight) @@ -708,11 +694,11 @@ class DenseQuant(Cell): same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -734,19 +720,19 @@ class DenseQuant(Cell): bias_init='zeros', has_bias=True, activation=None, - num_bits=8, - quant_delay=0, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(DenseQuant, self).__init__() self.in_channels = check_int_positive(in_channels) self.out_channels = check_int_positive(out_channels) self.has_bias = check_bool(has_bias) if isinstance(weight_init, Tensor): - if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ - weight_init.shape[1] != in_channels: + if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ + weight_init.shape()[1] != in_channels: raise ValueError("weight_init shape error") self.weight = Parameter(initializer( @@ -754,7 +740,7 @@ class DenseQuant(Cell): if self.has_bias: if isinstance(bias_init, Tensor): - if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: + if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: raise ValueError("bias_init shape error") self.bias = Parameter(initializer( @@ -768,12 +754,13 @@ class DenseQuant(Cell): self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, max_init=6, ema=False, - num_bits=num_bits, - quant_delay=quant_delay, per_channel=per_channel, - out_channels=out_channels, + channel_axis=0, + num_channels=out_channels, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) def construct(self, x): """Use operators to construct to Dense layer.""" @@ -796,13 +783,16 @@ class DenseQuant(Cell): return str_info + class _QuantActivation(Cell): r""" Base class for Quant activation function. Add Fake Quant OP after activation OP. """ + def get_origin(self): raise NotImplementedError + class ReLUQuant(_QuantActivation): r""" ReLUQuant activation function. Add Fake Quant OP after Relu OP. @@ -810,12 +800,12 @@ class ReLUQuant(_QuantActivation): For a more Detailed overview of ReLU op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of ReLUQuant. @@ -830,22 +820,22 @@ class ReLUQuant(_QuantActivation): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(ReLUQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=0, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.relu = P.ReLU() def construct(self, x): @@ -866,12 +856,12 @@ class ReLU6Quant(_QuantActivation): For a more Detailed overview of ReLU6 op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of ReLU6Quant. @@ -886,22 +876,22 @@ class ReLU6Quant(_QuantActivation): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(ReLU6Quant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=0, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.relu6 = P.ReLU6() def construct(self, x): @@ -912,6 +902,7 @@ class ReLU6Quant(_QuantActivation): def get_origin(self): return self.relu6 + class HSwishQuant(_QuantActivation): r""" HSwishQuant activation function. Add Fake Quant OP after HSwish OP. @@ -919,12 +910,12 @@ class HSwishQuant(_QuantActivation): For a more Detailed overview of HSwish op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of HSwishQuant. @@ -939,31 +930,31 @@ class HSwishQuant(_QuantActivation): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(HSwishQuant, self).__init__() self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.act = P.HSwish() def construct(self, x): @@ -975,6 +966,7 @@ class HSwishQuant(_QuantActivation): def get_origin(self): return self.act + class HSigmoidQuant(_QuantActivation): r""" HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP. @@ -982,12 +974,12 @@ class HSigmoidQuant(_QuantActivation): For a more Detailed overview of HSigmoid op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of HSigmoidQuant. @@ -1002,30 +994,31 @@ class HSigmoidQuant(_QuantActivation): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(HSigmoidQuant, self).__init__() self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, + ema_decay=ema_decay, per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.act = P.HSigmoid() def construct(self, x): @@ -1037,6 +1030,7 @@ class HSigmoidQuant(_QuantActivation): def get_origin(self): return self.act + class TensorAddQuant(Cell): r""" Add Fake Quant OP after TensorAdd OP. @@ -1044,12 +1038,12 @@ class TensorAddQuant(Cell): For a more Detailed overview of TensorAdd op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of TensorAddQuant. @@ -1065,22 +1059,22 @@ class TensorAddQuant(Cell): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(TensorAddQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.add = P.TensorAdd() def construct(self, x1, x2): @@ -1096,12 +1090,12 @@ class MulQuant(Cell): For a more Detailed overview of Mul op. Args: - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. Inputs: - **x** (Tensor) - The input of MulQuant. @@ -1112,22 +1106,22 @@ class MulQuant(Cell): """ def __init__(self, - num_bits=8, - quant_delay=0, ema_decay=0.999, per_channel=False, + num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + quant_delay=0): super(MulQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, max_init=6, - num_bits=num_bits, - quant_delay=quant_delay, ema=True, - per_channel=per_channel, ema_decay=ema_decay, + per_channel=per_channel, + num_bits=num_bits, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + quant_delay=quant_delay) self.mul = P.Mul() def construct(self, x1, x2): diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py b/mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py similarity index 56% rename from mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py rename to mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py index 3560c20d7b0ab3cb9e9a1aec9007aba8d2689e3b..1ff63464c3735c5be9ddf07b4d7c909374d0e9fb 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py +++ b/mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py @@ -1,4 +1,3 @@ - # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,20 +21,15 @@ from topi import generic from topi.cce import util from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType - -fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel") \ +minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("fake_quant_min_max_per_channel_update.so") \ + .binfile_name("minmax_update_perchannel.so") \ .compute_cost(10) \ - .kernel_name("fake_quant_min_max_per_channel_update") \ + .kernel_name("minmax_update_perchannel") \ .partial_flag(True) \ .attr("ema", "optional", "bool", "all") \ .attr("ema_decay", "optional", "float", "all") \ - .attr("symmetric", "optional", "bool", "all") \ - .attr("narrow_range", "optional", "bool", "all") \ - .attr("training", "optional", "bool", "all") \ - .attr("num_bits", "optional", "int", "all") \ .attr("channel_axis", "optional", "int", "all") \ .input(0, "x", None, "required", None) \ .input(1, "min", None, "required", None) \ @@ -47,43 +41,46 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel .get_op_info() -@op_info_register(fake_quant_min_max_per_channel_update_op_info) -def _fake_quant_min_max_per_channel_update_tbe(): - """FakeQuantPerChannelUpdate TBE register""" +@op_info_register(minmax_update_perchannel_op_info) +def _minmax_update_perchannel_tbe(): + """MinMaxUpdatePerChannel TBE register""" return -@fusion_manager.register("fake_quant_min_max_per_channel_update") -def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val, - ema, ema_decay, quant_min, quant_max, training, channel_axis, - kernel_name="fake_quant_min_max_per_channel_update"): - """FakeQuantPerChannelUpdate compute""" +@fusion_manager.register("minmax_update_perchannel") +def minmax_update_perchannel_compute(x, min_val, max_val, + ema, ema_decay, channel_axis): + """MinMaxUpdatePerChannel compute""" shape_min = te.lang.cce.util.shape_to_list(min_val.shape) if not ema: ema_decay = 0.0 - if training: - # CalMinMax + + # CalMinMax + if channel_axis == 0: + axis = [1, 2, 3, 4] + else: axis = [0, 2, 3] - x_min = te.lang.cce.reduce_min(x, axis=axis) - x_max = te.lang.cce.reduce_max(x, axis=axis) - x_min = te.lang.cce.broadcast(x_min, shape_min) - x_max = te.lang.cce.broadcast(x_max, shape_min) - min_val = te.lang.cce.vadd(te.lang.cce.vmuls( - min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) - max_val = te.lang.cce.vadd(te.lang.cce.vmuls( - max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) - min_val = te.lang.cce.vmins(min_val, 0) - max_val = te.lang.cce.vmaxs(max_val, 0) + + x_min = te.lang.cce.reduce_min(x, axis=axis) + x_max = te.lang.cce.reduce_max(x, axis=axis) + x_min = te.lang.cce.broadcast(x_min, shape_min) + x_max = te.lang.cce.broadcast(x_max, shape_min) + min_val = te.lang.cce.vadd(te.lang.cce.vmuls( + min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) + max_val = te.lang.cce.vadd(te.lang.cce.vmuls( + max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) + min_val = te.lang.cce.vmins(min_val, 0) + max_val = te.lang.cce.vmaxs(max_val, 0) return [min_val, max_val] -@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) -def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, - ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis, - kernel_name="fake_quant_min_max_per_channel_update"): - """FakeQuantPerLayer op""" +@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str) +def minmax_update_perchannel(x, min_val, max_val, min_up, max_up, + ema, ema_decay, channel_axis, + kernel_name="minmax_update_perchannel"): + """MinMaxUpdatePerChannel op""" x_shape = x.get("ori_shape") x_format = x.get("format") x_dtype = x.get("dtype") @@ -108,21 +105,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, util.check_dtype_rule(min_dtype, check_list) util.check_dtype_rule(max_dtype, check_list) - if symmetric: - quant_min = 0 - 2 ** (num_bits - 1) - quant_max = 2 ** (num_bits - 1) - 1 + if channel_axis == 0: + shape_c = min_val.get("ori_shape") else: - quant_min = 0 - quant_max = 2 ** num_bits - 1 - if narrow_range: - quant_min = quant_min + 1 - - shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] + shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) - res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data, - ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name) + res_list = minmax_update_perchannel_compute(input_data, min_data, max_data, + ema, ema_decay, channel_axis) with tvm.target.cce(): sch = generic.auto_schedule(res_list) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py b/mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py similarity index 63% rename from mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py rename to mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py index 554b66b95baee091e27950c46b276169258d721b..4d2096d55ba6fd4d7afedf422e563260cb0975cc 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py +++ b/mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py @@ -22,20 +22,15 @@ from topi import generic from topi.cce import util from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType - -fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \ +minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("fake_quant_minmax_update.so") \ + .binfile_name("minmax_update_perlayer.so") \ .compute_cost(10) \ - .kernel_name("fake_quant_minmax_update") \ + .kernel_name("minmax_update_perlayer") \ .partial_flag(True) \ .attr("ema", "optional", "bool", "all") \ .attr("ema_decay", "optional", "float", "all") \ - .attr("symmetric", "optional", "bool", "all") \ - .attr("narrow_range", "optional", "bool", "all") \ - .attr("training", "optional", "bool", "all") \ - .attr("num_bits", "optional", "int", "all") \ .input(0, "x", None, "required", None) \ .input(1, "min", None, "required", None) \ .input(2, "max", None, "required", None) \ @@ -46,15 +41,14 @@ fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \ .get_op_info() -@op_info_register(fake_quant_minmax_update_op_info) -def _fake_quant_minmax_update_tbe(): +@op_info_register(minmax_update_perlayer_op_info) +def _minmax_update_perlayer_tbe(): """MinMaxUpdatePerLayer TBE register""" return -@fusion_manager.register("fake_quant_minmax_update") -def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, - kernel_name="fake_quant_minmax_update"): +@fusion_manager.register("minmax_update_perlayer") +def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay): """MinMaxUpdatePerLayer compute""" shape = te.lang.cce.util.shape_to_list(x.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape) @@ -62,28 +56,27 @@ def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_ max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) if not ema: ema_decay = 0.0 - if training: - # CalMinMax - axis = tuple(range(len(shape))) - x_min = te.lang.cce.reduce_min(x, axis=axis) - x_max = te.lang.cce.reduce_max(x, axis=axis) - x_min = te.lang.cce.broadcast(x_min, shape_min) - x_max = te.lang.cce.broadcast(x_max, shape_min) - min_val = te.lang.cce.vadd(te.lang.cce.vmuls( - min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) - max_val = te.lang.cce.vadd(te.lang.cce.vmuls( - max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) - min_val = te.lang.cce.vmins(min_val, 0) - max_val = te.lang.cce.vmaxs(max_val, 0) + + # CalMinMax + axis = tuple(range(len(shape))) + x_min = te.lang.cce.reduce_min(x, axis=axis) + x_max = te.lang.cce.reduce_max(x, axis=axis) + x_min = te.lang.cce.broadcast(x_min, shape_min) + x_max = te.lang.cce.broadcast(x_max, shape_min) + min_val = te.lang.cce.vadd(te.lang.cce.vmuls( + min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) + max_val = te.lang.cce.vadd(te.lang.cce.vmuls( + max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) + min_val = te.lang.cce.vmins(min_val, 0) + max_val = te.lang.cce.vmaxs(max_val, 0) return [min_val, max_val] -@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str) -def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, - ema, ema_decay, symmetric, narrow_range, training, num_bits, - kernel_name="fake_quant_minmax_update"): - """FakeQuantPerLayer op""" +@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str) +def minmax_update_perlayer(x, min_val, max_val, min_up, max_up, + ema, ema_decay, kernel_name="minmax_update_perlayer"): + """MinMaxUpdatePerLayer op""" input_shape = x.get("shape") input_dtype = x.get("dtype") min_shape = min_val.get("ori_shape") @@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) shape_min, _, _ = util.produce_shapes(min_shape, input_shape) - if symmetric: - quant_min = 0 - 2 ** (num_bits - 1) - quant_max = 2 ** (num_bits - 1) - 1 - else: - quant_min = 0 - quant_max = 2 ** num_bits - 1 - if narrow_range: - quant_min = quant_min + 1 - input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) - res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data, - ema, ema_decay, quant_min, quant_max, training, kernel_name) + res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay) with tvm.target.cce(): sch = generic.auto_schedule(res_list) diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index b668fd0d83d8e485925042b9fc822d9f7bc57243..42c240690601a709d916b41bd6eab9de521afaf5 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -21,12 +21,12 @@ from ..._checkparam import Rel from ..primitive import PrimitiveWithInfer, prim_attr_register from ...common import dtype as mstype -__all__ = ["FakeQuantPerLayer", +__all__ = ["MinMaxUpdatePerLayer", + "MinMaxUpdatePerChannel", + "FakeQuantPerLayer", "FakeQuantPerLayerGrad", "FakeQuantPerChannel", "FakeQuantPerChannelGrad", - "MinMaxUpdatePerLayer", - "MinMaxUpdatePerChannel", "BatchNormFold", "BatchNormFoldGrad", "CorrectionMul", @@ -38,10 +38,128 @@ __all__ = ["FakeQuantPerLayer", "BatchNormFoldGradD", "BatchNormFold2_D", "BatchNormFold2GradD", - "BatchNormFold2GradReduce", + "BatchNormFold2GradReduce" ] +class MinMaxUpdatePerLayer(PrimitiveWithInfer): + r""" + Update min and max per layer. + + Args: + ema (bool): Use EMA algorithm update value min and max. Default: False. + ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. + + Inputs: + - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. + - **min** (Tensor) : Value of the min range of the input data x. + - **max** (Tensor) : Value of the max range of the input data x. + + Outputs: + - Tensor: Simulate quantize tensor of x. + + Examples: + >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) + >>> min_tensor = Tensor(np.array([-6]), mstype.float32) + >>> max_tensor = Tensor(np.array([6]), mstype.float32) + >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor) + """ + support_quant_bit = [4, 7, 8] + + @prim_attr_register + def __init__(self, ema=False, ema_decay=0.999): + """init FakeQuantMinMaxPerLayerUpdate OP""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import minmax_update_perlayer + if ema and not ema_decay: + raise ValueError( + f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = validator.check_value_type('ema', ema, (bool,), self.name) + self.ema_decay = validator.check_number_range( + 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.init_prim_io_names(inputs=['x', 'min', 'max'], + outputs=['min_up', 'max_up']) + + def infer_shape(self, x_shape, min_shape, max_shape): + validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) + validator.check("min shape", min_shape, "max shape", + max_shape, Rel.EQ, self.name) + validator.check_integer("min shape", len( + min_shape), 1, Rel.EQ, self.name) + return min_shape, max_shape + + def infer_dtype(self, x_type, min_type, max_type): + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"max": max_type}, valid_types, self.name) + return min_type, max_type + + +class MinMaxUpdatePerChannel(PrimitiveWithInfer): + r""" + Update min and max per channel. + + Args: + ema (bool): Use EMA algorithm update value min and max. Default: False. + ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. + channel_axis (int): Channel asis for per channel compute. Default: 1. + + Inputs: + - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. + - **min** (Tensor) : Value of the min range of the input data x. + - **max** (Tensor) : Value of the max range of the input data x. + + Outputs: + - Tensor: Simulate quantize tensor of x. + + Examples: + >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) + >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) + >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) + >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max) + """ + support_quant_bit = [4, 7, 8] + + @prim_attr_register + def __init__(self, ema=False, ema_decay=0.999, channel_axis=1): + """init FakeQuantPerChannelUpdate OP for Ascend""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import minmax_update_perchannel + if ema and not ema_decay: + raise ValueError( + f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = validator.check_value_type('ema', ema, (bool,), self.name) + self.ema_decay = validator.check_number_range( + 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.channel_axis = validator.check_integer( + 'channel axis', channel_axis, 0, Rel.GE, self.name) + self.init_prim_io_names( + inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) + + def infer_shape(self, x_shape, min_shape, max_shape): + validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) + validator.check("min shape", min_shape, "max shape", + max_shape, Rel.EQ, self.name) + validator.check_integer("min shape", len( + min_shape), 1, Rel.EQ, self.name) + return min_shape, max_shape + + def infer_dtype(self, x_type, min_type, max_type): + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same( + {"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same( + {"max": max_type}, valid_types, self.name) + return min_type, max_type + + class FakeQuantPerLayer(PrimitiveWithInfer): r""" Simulate the quantize and dequantize operations in training time. @@ -832,153 +950,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer): def infer_dtype(self, dout_type, x_type): validator.check("dout type", dout_type, "x type", x_type) return dout_type, dout_type - - -class MinMaxUpdatePerLayer(PrimitiveWithInfer): - r""" - Update min and max value for fake quant per layer op. - - Args: - num_bits (int) : Number bits for quantization aware. Default: 8. - ema (bool): Use EMA algorithm update value min and max. Default: False. - ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. - symmetric (bool): Quantization algorithm use symmetric or not. Default: False. - narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. - training (bool): Training the network or not. Default: True. - - Inputs: - - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. - - **min** (Tensor) : Value of the min range of the input data x. - - **max** (Tensor) : Value of the max range of the input data x. - - Outputs: - - Tensor: Simulate quantize tensor of x. - - Examples: - >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) - >>> min_tensor = Tensor(np.array([-6]), mstype.float32) - >>> max_tensor = Tensor(np.array([6]), mstype.float32) - >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor) - """ - support_quant_bit = [4, 7, 8] - - @prim_attr_register - def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, - training=True): - """init MinMaxUpdatePerLayer OP""" - if context.get_context('device_target') == "Ascend": - from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update - if num_bits not in self.support_quant_bit: - raise ValueError( - f"For '{self.name}' attr \'num_bits\' is not support.") - if ema and not ema_decay: - raise ValueError( - f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") - - self.ema = validator.check_value_type('ema', ema, (bool,), self.name) - self.symmetric = validator.check_value_type( - 'symmetric', symmetric, (bool,), self.name) - self.narrow_range = validator.check_value_type( - 'narrow_range', narrow_range, (bool,), self.name) - self.training = validator.check_value_type( - 'training', training, (bool,), self.name) - self.ema_decay = validator.check_number_range( - 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) - self.num_bits = validator.check_integer( - 'num_bits', num_bits, 0, Rel.GT, self.name) - self.init_prim_io_names(inputs=['x', 'min', 'max'], - outputs=['min_up', 'max_up']) - - def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) - validator.check("min shape", min_shape, "max shape", - max_shape, Rel.EQ, self.name) - validator.check_integer("min shape", len( - min_shape), 1, Rel.EQ, self.name) - return min_shape, max_shape - - def infer_dtype(self, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) - return min_type, max_type - - -class MinMaxUpdatePerChannel(PrimitiveWithInfer): - r""" - Update min and max value for fake quant per layer op. - - Args: - num_bits (int) : Number bits for quantization aware. Default: 8. - ema (bool): Use EMA algorithm update value min and max. Default: False. - ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. - symmetric (bool): Quantization algorithm use symmetric or not. Default: False. - narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. - training (bool): Training the network or not. Default: True. - channel_axis (int): Channel asis for per channel compute. Default: 1. - - Inputs: - - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. - - **min** (Tensor) : Value of the min range of the input data x. - - **max** (Tensor) : Value of the max range of the input data x. - - Outputs: - - Tensor: Simulate quantize tensor of x. - - Examples: - >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) - >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) - >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) - >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max) - """ - support_quant_bit = [4, 7, 8] - - @prim_attr_register - def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, - training=True, channel_axis=1): - """init MinMaxUpdatePerChannel OP for Ascend""" - if context.get_context('device_target') == "Ascend": - from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update - if num_bits not in self.support_quant_bit: - raise ValueError( - f"For '{self.name}' attr \'num_bits\' is not support.") - if ema and not ema_decay: - raise ValueError( - f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") - - self.ema = validator.check_value_type('ema', ema, (bool,), self.name) - self.symmetric = validator.check_value_type( - 'symmetric', symmetric, (bool,), self.name) - self.narrow_range = validator.check_value_type( - 'narrow_range', narrow_range, (bool,), self.name) - self.training = validator.check_value_type( - 'training', training, (bool,), self.name) - self.ema_decay = validator.check_number_range( - 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) - self.num_bits = validator.check_integer( - 'num_bits', num_bits, 0, Rel.GT, self.name) - self.channel_axis = validator.check_integer( - 'channel axis', channel_axis, 0, Rel.GE, self.name) - self.init_prim_io_names( - inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) - - def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) - validator.check("min shape", min_shape, "max shape", - max_shape, Rel.EQ, self.name) - validator.check_integer("min shape", len( - min_shape), 1, Rel.EQ, self.name) - return min_shape, max_shape - - def infer_dtype(self, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same( - {"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) - return min_type, max_type