提交 02584fe2 编写于 作者: W wangdongxu

fix perchannel num_channels not set bug and adjust quant.py params order

上级 bd0c5384
......@@ -23,35 +23,24 @@
#include "device/gpu/cuda_common.h"
__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min,
float *output_max, const float min, const float max, const float decay,
const float symmetric) {
float *output_max, const float min, const float max,
const float decay) {
output_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
output_min[0] = input_min[0] > 0 ? 0 : input_min[0];
output_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
output_max[0] = input_max[0] < 0 ? 0 : input_max[0];
if (symmetric) {
output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
}
return;
}
__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max,
const float symmetric) {
__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max) {
output_min[0] = min > 0 ? 0 : min;
output_max[0] = max < 0 ? 0 : max;
if (symmetric) {
output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
}
return;
}
__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min,
float *output_max, int channels, int per_channel_nums, bool ema,
float ema_decay, bool symmetric) {
float ema_decay) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
thrust::pair<float *, float *> sum =
thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
......@@ -64,27 +53,21 @@ __global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, floa
}
output_min[i] = input_min[i] > 0 ? 0 : input_min[i];
output_max[i] = input_max[i] < 0 ? 0 : input_max[i];
if (symmetric) {
output_max[i] = abs(output_min[i]) < output_max[i] ? output_max[i] : -output_min[i];
output_min[i] = abs(output_min[i]) < output_max[i] ? -output_max[i] : output_min[i];
}
}
return;
}
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const int channel_num, const float ema_decay, const bool ema,
const bool symmetric, cudaStream_t cuda_stream) {
cudaStream_t cuda_stream) {
int per_channel_num = total_num / channel_num;
UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay, symmetric);
input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay);
return;
}
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const float ema_decay, const bool ema, const bool symmetric,
cudaStream_t cuda_stream) {
const int total_num, const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
float minel = 0.f;
float maxel = 0.f;
auto policy = thrust::cuda::par.on(cuda_stream);
......@@ -96,9 +79,9 @@ void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *
if (ema) {
UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel,
maxel, ema_decay, symmetric);
maxel, ema_decay);
} else {
UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel, symmetric);
UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel);
}
return;
}
......@@ -21,10 +21,9 @@
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const int channel_num, const float ema_decay, const bool ema,
const bool symmetric, cudaStream_t cuda_stream);
cudaStream_t cuda_stream);
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int size, const float ema_decay, const bool ema, const bool symmetric,
cudaStream_t cuda_stream);
const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
......@@ -24,16 +24,7 @@
namespace mindspore {
namespace kernel {
MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel()
: input_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(1),
ema_(false),
ema_decay_(0),
num_channels_(0),
narrow_range_(false),
symmetric_(false) {}
: input_size_(0), quant_num_(1), ema_(false), ema_decay_(0), num_channels_(0) {}
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
......@@ -54,22 +45,8 @@ bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
// quant min and max
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
......@@ -110,7 +87,7 @@ bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inpu
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_,
ema_decay_, ema_, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
......
......@@ -44,15 +44,10 @@ class MinMaxUpdatePerChannelGpuKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
bool ema_;
float ema_decay_;
int num_channels_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
......
......@@ -24,15 +24,7 @@
namespace mindspore {
namespace kernel {
MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel()
: input_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(1),
ema_(false),
ema_decay_(0),
narrow_range_(false),
symmetric_(false) {}
: input_size_(0), quant_num_(1), ema_(false), ema_decay_(0) {}
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
......@@ -51,22 +43,8 @@ bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
// quant min and max
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
......@@ -104,7 +82,7 @@ bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs
MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null.";
}
CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, symmetric_,
CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
......
......@@ -44,14 +44,9 @@ class MinMaxUpdatePerLayerGpuKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
bool ema_;
float ema_decay_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
......
此差异已折叠。
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -22,20 +21,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel") \
minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \
.binfile_name("minmax_update_perchannel.so") \
.compute_cost(10) \
.kernel_name("fake_quant_min_max_per_channel_update") \
.kernel_name("minmax_update_perchannel") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
......@@ -47,43 +41,46 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel
.get_op_info()
@op_info_register(fake_quant_min_max_per_channel_update_op_info)
def _fake_quant_min_max_per_channel_update_tbe():
"""FakeQuantPerChannelUpdate TBE register"""
@op_info_register(minmax_update_perchannel_op_info)
def _minmax_update_perchannel_tbe():
"""MinMaxUpdatePerChannel TBE register"""
return
@fusion_manager.register("fake_quant_min_max_per_channel_update")
def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
ema, ema_decay, quant_min, quant_max, training, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerChannelUpdate compute"""
@fusion_manager.register("minmax_update_perchannel")
def minmax_update_perchannel_compute(x, min_val, max_val,
ema, ema_decay, channel_axis):
"""MinMaxUpdatePerChannel compute"""
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
# CalMinMax
if channel_axis == 0:
axis = [1, 2, 3, 4]
else:
axis = [0, 2, 3]
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"):
"""FakeQuantPerLayer op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str)
def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
ema, ema_decay, channel_axis,
kernel_name="minmax_update_perchannel"):
"""MinMaxUpdatePerChannel op"""
x_shape = x.get("ori_shape")
x_format = x.get("format")
x_dtype = x.get("dtype")
......@@ -108,21 +105,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
if channel_axis == 0:
shape_c = min_val.get("ori_shape")
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
ema, ema_decay, channel_axis)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
......
......@@ -22,20 +22,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \
minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_minmax_update.so") \
.binfile_name("minmax_update_perlayer.so") \
.compute_cost(10) \
.kernel_name("fake_quant_minmax_update") \
.kernel_name("minmax_update_perlayer") \
.partial_flag(True) \
.attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \
......@@ -46,15 +41,14 @@ fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \
.get_op_info()
@op_info_register(fake_quant_minmax_update_op_info)
def _fake_quant_minmax_update_tbe():
@op_info_register(minmax_update_perlayer_op_info)
def _minmax_update_perlayer_tbe():
"""MinMaxUpdatePerLayer TBE register"""
return
@fusion_manager.register("fake_quant_minmax_update")
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_minmax_update"):
@fusion_manager.register("minmax_update_perlayer")
def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
"""MinMaxUpdatePerLayer compute"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
......@@ -62,28 +56,27 @@ def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
if not ema:
ema_decay = 0.0
if training:
# CalMinMax
axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
# CalMinMax
axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min)
x_max = te.lang.cce.broadcast(x_max, shape_min)
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(
min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(
max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
min_val = te.lang.cce.vmins(min_val, 0)
max_val = te.lang.cce.vmaxs(max_val, 0)
return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantPerLayer op"""
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
ema, ema_decay, kernel_name="minmax_update_perlayer"):
"""MinMaxUpdatePerLayer op"""
input_shape = x.get("shape")
input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape")
......@@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, kernel_name)
res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay)
with tvm.target.cce():
sch = generic.auto_schedule(res_list)
......
......@@ -21,12 +21,12 @@ from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype
__all__ = ["FakeQuantPerLayer",
__all__ = ["MinMaxUpdatePerLayer",
"MinMaxUpdatePerChannel",
"FakeQuantPerLayer",
"FakeQuantPerLayerGrad",
"FakeQuantPerChannel",
"FakeQuantPerChannelGrad",
"MinMaxUpdatePerLayer",
"MinMaxUpdatePerChannel",
"BatchNormFold",
"BatchNormFoldGrad",
"CorrectionMul",
......@@ -38,10 +38,128 @@ __all__ = ["FakeQuantPerLayer",
"BatchNormFoldGradD",
"BatchNormFold2_D",
"BatchNormFold2GradD",
"BatchNormFold2GradReduce",
"BatchNormFold2GradReduce"
]
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
r"""
Update min and max per layer.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
r"""
Update min and max per channel.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantPerLayer(PrimitiveWithInfer):
r"""
Simulate the quantize and dequantize operations in training time.
......@@ -832,153 +950,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type):
validator.check("dout type", dout_type, "x type", x_type)
return dout_type, dout_type
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for quantization aware. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init MinMaxUpdatePerLayer OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for quantization aware. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init MinMaxUpdatePerChannel OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册