提交 106f7980 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2451 fix perchannel num_channels not set bug and adjust quant.py params order

Merge pull request !2451 from 王东旭/master
......@@ -23,35 +23,24 @@
#include "device/gpu/cuda_common.h"
__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min,
float *output_max, const float min, const float max, const float decay,
const float symmetric) {
float *output_max, const float min, const float max,
const float decay) {
output_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
output_min[0] = input_min[0] > 0 ? 0 : input_min[0];
output_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
output_max[0] = input_max[0] < 0 ? 0 : input_max[0];
if (symmetric) {
output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
}
return;
}
__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max,
const float symmetric) {
__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max) {
output_min[0] = min > 0 ? 0 : min;
output_max[0] = max < 0 ? 0 : max;
if (symmetric) {
output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
}
return;
}
__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min,
float *output_max, int channels, int per_channel_nums, bool ema,
float ema_decay, bool symmetric) {
float ema_decay) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
thrust::pair<float *, float *> sum =
thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
......@@ -64,27 +53,21 @@ __global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, floa
}
output_min[i] = input_min[i] > 0 ? 0 : input_min[i];
output_max[i] = input_max[i] < 0 ? 0 : input_max[i];
if (symmetric) {
output_max[i] = abs(output_min[i]) < output_max[i] ? output_max[i] : -output_min[i];
output_min[i] = abs(output_min[i]) < output_max[i] ? -output_max[i] : output_min[i];
}
}
return;
}
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const int channel_num, const float ema_decay, const bool ema,
const bool symmetric, cudaStream_t cuda_stream) {
cudaStream_t cuda_stream) {
int per_channel_num = total_num / channel_num;
UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay, symmetric);
input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay);
return;
}
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const float ema_decay, const bool ema, const bool symmetric,
cudaStream_t cuda_stream) {
const int total_num, const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
float minel = 0.f;
float maxel = 0.f;
auto policy = thrust::cuda::par.on(cuda_stream);
......@@ -96,9 +79,9 @@ void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *
if (ema) {
UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel,
maxel, ema_decay, symmetric);
maxel, ema_decay);
} else {
UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel, symmetric);
UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel);
}
return;
}
......@@ -21,10 +21,9 @@
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const int channel_num, const float ema_decay, const bool ema,
const bool symmetric, cudaStream_t cuda_stream);
cudaStream_t cuda_stream);
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int size, const float ema_decay, const bool ema, const bool symmetric,
cudaStream_t cuda_stream);
const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
......@@ -24,16 +24,7 @@
namespace mindspore {
namespace kernel {
MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel()
: input_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(1),
ema_(false),
ema_decay_(0),
num_channels_(0),
narrow_range_(false),
symmetric_(false) {}
: input_size_(0), quant_num_(1), ema_(false), ema_decay_(0), num_channels_(0) {}
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
......@@ -54,22 +45,8 @@ bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
// quant min and max
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
......@@ -110,7 +87,7 @@ bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inpu
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_,
ema_decay_, ema_, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
......
......@@ -44,15 +44,10 @@ class MinMaxUpdatePerChannelGpuKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
bool ema_;
float ema_decay_;
int num_channels_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
......
......@@ -24,15 +24,7 @@
namespace mindspore {
namespace kernel {
MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel()
: input_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(1),
ema_(false),
ema_decay_(0),
narrow_range_(false),
symmetric_(false) {}
: input_size_(0), quant_num_(1), ema_(false), ema_decay_(0) {}
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
......@@ -51,22 +43,8 @@ bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
// quant min and max
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
......@@ -104,7 +82,7 @@ bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs
MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null.";
}
CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, symmetric_,
CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
......
......@@ -44,14 +44,9 @@ class MinMaxUpdatePerLayerGpuKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
bool ema_;
float ema_decay_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
......
......@@ -276,15 +276,15 @@ class FakeQuantWithMinMax(Cell):
Args:
min_init (int, float): The dimension of channel or 1(layer). Default: -6.
max_init (int, float): The dimension of channel or 1(layer). Default: 6.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
channel_axis (int): Quantization by channel axis. Default: 1.
out_channels (int): declarate the min and max channel size, Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
num_channels (int): declarate the min and max channel size, Default: 1.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax.
......@@ -301,15 +301,15 @@ class FakeQuantWithMinMax(Cell):
def __init__(self,
min_init=-6,
max_init=6,
num_bits=8,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
out_channels=1,
quant_delay=0,
num_channels=1,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
"""init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
self.min_init = min_init
......@@ -318,7 +318,7 @@ class FakeQuantWithMinMax(Cell):
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.out_channels = out_channels
self.num_channels = num_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
......@@ -327,11 +327,11 @@ class FakeQuantWithMinMax(Cell):
# init tensor min and max for fake quant op
if self.per_channel:
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32)
min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
else:
min_array = np.array([self.min_init]).reshape(1).astype(np.float32)
max_array = np.array([self.max_init]).reshape(1).astype(np.float32)
min_array = np.array([self.min_init]).astype(np.float32)
max_array = np.array([self.max_init]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
......@@ -343,57 +343,41 @@ class FakeQuantWithMinMax(Cell):
quant_fun = Q.FakeQuantPerLayer
ema_fun = Q.MinMaxUpdatePerLayer
self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
if self.is_ascend:
self.fake_quant = quant_fun(num_bits=self.num_bits,
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
self.fake_quant_infer = self.fake_quant_train
else:
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = quant_fun(num_bits=self.num_bits,
quant_fun = partial(quant_fun,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
quant_delay=quant_delay)
self.fake_quant_train = quant_fun(training=True)
self.fake_quant_infer = quant_fun(training=False)
def extend_repr(self):
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
'quant_delay={}, min_init={}, max_init={}'.format(
self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init)
self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init)
return s
def construct(self, x):
if self.is_ascend:
if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
else:
out = self.fake_quant(x, self.minq, self.maxq)
else:
if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant_train(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant_infer(x, self.minq, self.maxq)
return out
class Conv2dBatchNormQuant(Cell):
r"""
2D convolution with BatchNormal op folded layer.
......@@ -407,8 +391,8 @@ class Conv2dBatchNormQuant(Cell):
stride (int): Specifies stride for all spatial dimensions with the same value.
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding: (int): Implicit paddings on both sides of the input. Default: 0.
eps (int): Parameters for BatchNormal. Default: 1e-5.
momentum (int): Parameters for BatchNormal op. Default: 0.997.
eps (float): Parameters for BatchNormal. Default: 1e-5.
momentum (float): Parameters for BatchNormal op. Default: 0.997.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
convolution kernel. Default: 'normal'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
......@@ -419,13 +403,13 @@ class Conv2dBatchNormQuant(Cell):
mean vector. Default: 'zeros'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'ones'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
......@@ -456,13 +440,13 @@ class Conv2dBatchNormQuant(Cell):
gamma_init='ones',
mean_init='zeros',
var_init='ones',
quant_delay=0,
freeze_bn=100000,
fake=True,
num_bits=8,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0,
freeze_bn=100000):
"""init Conv2dBatchNormQuant layer"""
super(Conv2dBatchNormQuant, self).__init__()
self.in_channels = in_channels
......@@ -519,12 +503,13 @@ class Conv2dBatchNormQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
channel_axis=channel_axis,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = Q.CorrectionMul(channel_axis)
if context.get_context('device_target') == "Ascend":
......@@ -598,11 +583,11 @@ class Conv2dQuant(Cell):
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
......@@ -629,11 +614,11 @@ class Conv2dQuant(Cell):
has_bias=False,
weight_init='normal',
bias_init='zeros',
quant_delay=0,
num_bits=8,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(Conv2dQuant, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
......@@ -669,12 +654,13 @@ class Conv2dQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
def construct(self, x):
weight = self.fake_quant_weight(self.weight)
......@@ -708,11 +694,11 @@ class DenseQuant(Cell):
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
......@@ -734,19 +720,19 @@ class DenseQuant(Cell):
bias_init='zeros',
has_bias=True,
activation=None,
num_bits=8,
quant_delay=0,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(DenseQuant, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
weight_init.shape()[1] != in_channels:
raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(
......@@ -754,7 +740,7 @@ class DenseQuant(Cell):
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(
......@@ -768,12 +754,13 @@ class DenseQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6,
ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel,
out_channels=out_channels,
channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
def construct(self, x):
"""Use operators to construct to Dense layer."""
......@@ -796,13 +783,16 @@ class DenseQuant(Cell):
return str_info
class _QuantActivation(Cell):
r"""
Base class for Quant activation function. Add Fake Quant OP after activation OP.
"""
def get_origin(self):
raise NotImplementedError
class ReLUQuant(_QuantActivation):
r"""
ReLUQuant activation function. Add Fake Quant OP after Relu OP.
......@@ -810,12 +800,12 @@ class ReLUQuant(_QuantActivation):
For a more Detailed overview of ReLU op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of ReLUQuant.
......@@ -830,22 +820,22 @@ class ReLUQuant(_QuantActivation):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(ReLUQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.relu = P.ReLU()
def construct(self, x):
......@@ -866,12 +856,12 @@ class ReLU6Quant(_QuantActivation):
For a more Detailed overview of ReLU6 op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of ReLU6Quant.
......@@ -886,22 +876,22 @@ class ReLU6Quant(_QuantActivation):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(ReLU6Quant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.relu6 = P.ReLU6()
def construct(self, x):
......@@ -912,6 +902,7 @@ class ReLU6Quant(_QuantActivation):
def get_origin(self):
return self.relu6
class HSwishQuant(_QuantActivation):
r"""
HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
......@@ -919,12 +910,12 @@ class HSwishQuant(_QuantActivation):
For a more Detailed overview of HSwish op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of HSwishQuant.
......@@ -939,31 +930,31 @@ class HSwishQuant(_QuantActivation):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(HSwishQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = P.HSwish()
def construct(self, x):
......@@ -975,6 +966,7 @@ class HSwishQuant(_QuantActivation):
def get_origin(self):
return self.act
class HSigmoidQuant(_QuantActivation):
r"""
HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.
......@@ -982,12 +974,12 @@ class HSigmoidQuant(_QuantActivation):
For a more Detailed overview of HSigmoid op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of HSigmoidQuant.
......@@ -1002,30 +994,31 @@ class HSigmoidQuant(_QuantActivation):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(HSigmoidQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = P.HSigmoid()
def construct(self, x):
......@@ -1037,6 +1030,7 @@ class HSigmoidQuant(_QuantActivation):
def get_origin(self):
return self.act
class TensorAddQuant(Cell):
r"""
Add Fake Quant OP after TensorAdd OP.
......@@ -1044,12 +1038,12 @@ class TensorAddQuant(Cell):
For a more Detailed overview of TensorAdd op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of TensorAddQuant.
......@@ -1065,22 +1059,22 @@ class TensorAddQuant(Cell):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(TensorAddQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.add = P.TensorAdd()
def construct(self, x1, x2):
......@@ -1096,12 +1090,12 @@ class MulQuant(Cell):
For a more Detailed overview of Mul op.
Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of MulQuant.
......@@ -1112,22 +1106,22 @@ class MulQuant(Cell):
"""
def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False):
narrow_range=False,
quant_delay=0):
super(MulQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range)
narrow_range=narrow_range,
quant_delay=quant_delay)
self.mul = P.Mul()
def construct(self, x1, x2):
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -22,20 +21,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel") \
minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \
.binfile_name("minmax_update_perchannel.so") \
.compute_cost(10) \
.kernel_name("fake_quant_min_max_per_channel_update") \
.kernel_name("minmax_update_perchannel") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
......@@ -47,24 +41,27 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel
.get_op_info()
@op_info_register(fake_quant_min_max_per_channel_update_op_info)
def _fake_quant_min_max_per_channel_update_tbe():
"""FakeQuantPerChannelUpdate TBE register"""
@op_info_register(minmax_update_perchannel_op_info)
def _minmax_update_perchannel_tbe():
"""MinMaxUpdatePerChannel TBE register"""
return
@fusion_manager.register("fake_quant_min_max_per_channel_update")
def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
ema, ema_decay, quant_min, quant_max, training, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerChannelUpdate compute"""
@fusion_manager.register("minmax_update_perchannel")
def minmax_update_perchannel_compute(x, min_val, max_val,
ema, ema_decay, channel_axis):
"""MinMaxUpdatePerChannel compute"""
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
if channel_axis == 0:
axis = [1, 2, 3, 4]
else:
axis = [0, 2, 3]
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
......@@ -79,11 +76,11 @@ def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerLayer op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str)
def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
ema, ema_decay, channel_axis,
kernel_name="minmax_update_perchannel"):
"""MinMaxUpdatePerChannel op"""
x_shape = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
......@@ -108,21 +105,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
if channel_axis == 0:
shape_c = min_val.get("ori_shape")
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
ema, ema_decay, channel_axis)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
......
......@@ -22,20 +22,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \
minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_minmax_update.so") \
.binfile_name("minmax_update_perlayer.so") \
.compute_cost(10) \
.kernel_name("fake_quant_minmax_update") \
.kernel_name("minmax_update_perlayer") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
......@@ -46,15 +41,14 @@ fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \
.get_op_info()
@op_info_register(fake_quant_minmax_update_op_info)
def _fake_quant_minmax_update_tbe():
@op_info_register(minmax_update_perlayer_op_info)
def _minmax_update_perlayer_tbe():
"""MinMaxUpdatePerLayer TBE register"""
return
@fusion_manager.register("fake_quant_minmax_update")
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_minmax_update"):
@fusion_manager.register("minmax_update_perlayer")
def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
"""MinMaxUpdatePerLayer compute"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
......@@ -62,7 +56,7 @@ def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis)
......@@ -79,11 +73,10 @@ def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantPerLayer op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
ema, ema_decay, kernel_name="minmax_update_perlayer"):
"""MinMaxUpdatePerLayer op"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
......@@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
......
......@@ -21,12 +21,12 @@ from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype
__all__ = ["FakeQuantPerLayer",
__all__ = ["MinMaxUpdatePerLayer",
"MinMaxUpdatePerChannel",
"FakeQuantPerLayer",
"FakeQuantPerLayerGrad",
"FakeQuantPerChannel",
"FakeQuantPerChannelGrad",
"MinMaxUpdatePerLayer",
"MinMaxUpdatePerChannel",
"BatchNormFold",
"BatchNormFoldGrad",
"CorrectionMul",
......@@ -38,10 +38,128 @@ __all__ = ["FakeQuantPerLayer",
"BatchNormFoldGradD",
"BatchNormFold2_D",
"BatchNormFold2GradD",
"BatchNormFold2GradReduce",
"BatchNormFold2GradReduce"
]
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
r"""
Update min and max per layer.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
r"""
Update min and max per channel.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantPerLayer(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time.
......@@ -832,153 +950,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type):
validator.check("dout type", dout_type, "x type", x_type)
return dout_type, dout_type
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for quantization aware. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init MinMaxUpdatePerLayer OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for quantization aware. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init MinMaxUpdatePerChannel OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册