Commit 3a40ac65 authored by mindspore-ci-bot, committed by Gitee

!2435 fix perchannel num_channels not set bug and adjust quant.py params order

Merge pull request !2435 from 王东旭/r0.3
This diff is collapsed.
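For orientation, this change renames the min/max update primitives and drops their quantization-range attributes. Below is a minimal sketch of the resulting Python-level constructor signatures, assuming the mindspore.ops.operations._quant_ops module path used elsewhere in this diff; it is illustrative only and not part of the commit.

from mindspore.ops.operations import _quant_ops as Q

# Old (removed) primitives carried the full fake-quant attribute set, e.g.:
#   Q.FakeQuantMinMaxPerLayerUpdate(num_bits=8, ema=False, ema_decay=0.999,
#                                   symmetric=False, narrow_range=False, training=True)
# The renamed primitives keep only the statistics-related attributes:
minmax_layer = Q.MinMaxUpdatePerLayer(ema=False, ema_decay=0.999)
minmax_channel = Q.MinMaxUpdatePerChannel(ema=False, ema_decay=0.999, channel_axis=1)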
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
-"""Generate bprop for aware quantization ops"""
+"""Generate bprop for quantization aware ops"""
from .. import operations as P
from ..operations import _quant_ops as Q
@@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self):
    return bprop
-@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate)
+@bprop_getters.register(Q.MinMaxUpdatePerLayer)
def get_bprop_fakequant_with_minmax_per_layer_update(self):
-    """Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
+    """Generate bprop for MinMaxUpdatePerLayer for Ascend"""
    def bprop(x, x_min, x_max, out, dout):
        return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
@@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self):
    return bprop
-@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate)
+@bprop_getters.register(Q.MinMaxUpdatePerChannel)
def get_bprop_fakequant_with_minmax_per_channel_update(self):
-    """Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
+    """Generate bprop for MinMaxUpdatePerChannel for Ascend"""
    def bprop(x, x_min, x_max, out, dout):
        return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
......
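Both registered bprops above return zeros for every input: the running min/max are statistics maintained by the forward op, so no gradient flows through them. A standalone NumPy sketch of that zero-gradient pattern follows; it illustrates the idea and is not MindSpore's internal code.

import numpy as np

def minmax_update_bprop(x, x_min, x_max, out, dout):
    """Statistics ops block gradients: every input receives a zero gradient."""
    return np.zeros_like(x), np.zeros_like(x_min), np.zeros_like(x_max)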
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +13,7 @@
# limitations under the License.
# ============================================================================
-"""FakeQuantMinMaxPerChannelUpdate op"""
+"""MinMaxUpdatePerChannel op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
@@ -22,20 +21,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
-fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
+minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
-    .binfile_name("fake_quant_min_max_per_channel_update.so") \
+    .binfile_name("minmax_update_perchannel.so") \
    .compute_cost(10) \
-    .kernel_name("fake_quant_min_max_per_channel_update") \
+    .kernel_name("minmax_update_perchannel") \
    .partial_flag(True) \
    .attr("ema", "optional", "bool", "all") \
    .attr("ema_decay", "optional", "float", "all") \
-    .attr("symmetric", "optional", "bool", "all") \
-    .attr("narrow_range", "optional", "bool", "all") \
-    .attr("training", "optional", "bool", "all") \
-    .attr("num_bits", "optional", "int", "all") \
    .attr("channel_axis", "optional", "int", "all") \
    .input(0, "x", None, "required", None) \
    .input(1, "min", None, "required", None) \
@@ -47,24 +41,27 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChan
    .get_op_info()
-@op_info_register(fake_quant_min_max_per_channel_update_op_info)
-def _fake_quant_min_max_per_channel_update_tbe():
-    """FakeQuantPerChannelUpdate TBE register"""
+@op_info_register(minmax_update_perchannel_op_info)
+def _minmax_update_perchannel_tbe():
+    """MinMaxUpdatePerChannel TBE register"""
    return
-@fusion_manager.register("fake_quant_min_max_per_channel_update")
-def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
-                                                  ema, ema_decay, quant_min, quant_max, training, channel_axis,
-                                                  kernel_name="fake_quant_min_max_per_channel_update"):
-    """FakeQuantPerChannelUpdate compute"""
+@fusion_manager.register("minmax_update_perchannel")
+def minmax_update_perchannel_compute(x, min_val, max_val,
+                                     ema, ema_decay, channel_axis):
+    """MinMaxUpdatePerChannel compute"""
    shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
    if not ema:
        ema_decay = 0.0
-    if training:
    # CalMinMax
+    if channel_axis == 0:
+        axis = [1, 2, 3, 4]
+    else:
        axis = [0, 2, 3]
    x_min = te.lang.cce.reduce_min(x, axis=axis)
    x_max = te.lang.cce.reduce_max(x, axis=axis)
    x_min = te.lang.cce.broadcast(x_min, shape_min)
@@ -79,11 +76,11 @@ def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
    return [min_val, max_val]
-@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
-def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
-                                          ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
-                                          kernel_name="fake_quant_min_max_per_channel_update"):
-    """FakeQuantPerLayer op"""
+@util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str)
+def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
+                             ema, ema_decay, channel_axis,
+                             kernel_name="minmax_update_perchannel"):
+    """MinMaxUpdatePerChannel op"""
    x_shape = x.get("ori_shape")
    x_format = x.get("format")
    x_dtype = x.get("dtype")
@@ -112,21 +109,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)
-    if symmetric:
-        quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1) - 1
-    else:
-        quant_min = 0
-        quant_max = 2 ** num_bits - 1
-    if narrow_range:
-        quant_min = quant_min + 1
-    shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
+    if channel_axis_ == 0:
+        shape_c = min_val.get("ori_shape")
+    else:
+        shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
    input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
-    res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
-                                                             ema, ema_decay, quant_min, quant_max, training, channel_axis_, kernel_name)
+    res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
+                                                ema, ema_decay, channel_axis_)
    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)
......
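In plain terms, the per-channel kernel now reduces min/max over every axis except the channel axis (axes [1, 2, 3, 4] when channel_axis == 0, i.e. a 5-D input, otherwise [0, 2, 3]) and then blends the result into the running values using ema_decay. The NumPy reference sketch below illustrates that computation; the EMA blend is an assumption based on the usual formulation, since the exact arithmetic is elided from this hunk, and the function name is hypothetical.

import numpy as np

def minmax_update_perchannel_ref(x, min_val, max_val, ema, ema_decay, channel_axis):
    """Reference sketch of a per-channel min/max update (not the TBE kernel itself)."""
    if not ema:
        ema_decay = 0.0
    # Reduce over every axis except the channel axis.
    axes = tuple(a for a in range(x.ndim) if a != channel_axis)
    x_min = x.min(axis=axes)
    x_max = x.max(axis=axes)
    # Assumed EMA blend of running statistics with batch statistics.
    min_up = min_val * ema_decay + x_min * (1.0 - ema_decay)
    max_up = max_val * ema_decay + x_max * (1.0 - ema_decay)
    return min_up, max_up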
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
-"""FakeQuantMinMaxPerLayerUpdate op"""
+"""MinMaxUpdatePerLayer op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
@@ -22,20 +22,15 @@ from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
-fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
+minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
-    .binfile_name("fake_quant_minmax_update.so") \
+    .binfile_name("minmax_update_perlayer.so") \
    .compute_cost(10) \
-    .kernel_name("fake_quant_minmax_update") \
+    .kernel_name("minmax_update_perlayer") \
    .partial_flag(True) \
    .attr("ema", "optional", "bool", "all") \
    .attr("ema_decay", "optional", "float", "all") \
-    .attr("symmetric", "optional", "bool", "all") \
-    .attr("narrow_range", "optional", "bool", "all") \
-    .attr("training", "optional", "bool", "all") \
-    .attr("num_bits", "optional", "int", "all") \
    .input(0, "x", None, "required", None) \
    .input(1, "min", None, "required", None) \
    .input(2, "max", None, "required", None) \
@@ -46,23 +41,22 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
    .get_op_info()
-@op_info_register(fake_quant_minmax_update_op_info)
-def _fake_quant_minmax_update_tbe():
-    """FakeQuantMinMaxPerLayerUpdate TBE register"""
+@op_info_register(minmax_update_perlayer_op_info)
+def _minmax_update_perlayer_tbe():
+    """MinMaxUpdatePerLayer TBE register"""
    return
-@fusion_manager.register("fake_quant_minmax_update")
-def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
-                                     kernel_name="fake_quant_minmax_update"):
-    """FakeQuantMinMaxPerLayerUpdate compute"""
+@fusion_manager.register("minmax_update_perlayer")
+def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
+    """MinMaxUpdatePerLayer compute"""
    shape = te.lang.cce.util.shape_to_list(x.shape)
    shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
    min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
    max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
    if not ema:
        ema_decay = 0.0
-    if training:
    # CalMinMax
    axis = tuple(range(len(shape)))
    x_min = te.lang.cce.reduce_min(x, axis=axis)
@@ -79,11 +73,10 @@ def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_
    return [min_val, max_val]
-@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str)
-def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
-                             ema, ema_decay, symmetric, narrow_range, training, num_bits,
-                             kernel_name="fake_quant_minmax_update"):
-    """FakeQuantPerLayer op"""
+@util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
+def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
+                           ema, ema_decay, kernel_name="minmax_update_perlayer"):
+    """MinMaxUpdatePerLayer op"""
    input_shape = x.get("shape")
    input_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
@@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
    input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
    shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
-    if symmetric:
-        quant_min = 0 - 2 ** (num_bits - 1)
-        quant_max = 2 ** (num_bits - 1) - 1
-    else:
-        quant_min = 0
-        quant_max = 2 ** num_bits - 1
-    if narrow_range:
-        quant_min = quant_min + 1
    input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
    max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
-    res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data,
-                                                ema, ema_decay, quant_min, quant_max, training, kernel_name)
+    res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay)
    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)
......
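The deleted symmetric/narrow_range blocks above computed the integer quantization range; after this commit the min/max-update kernels no longer need it, while the FakeQuant ops keep those attributes. For reference, here is the removed derivation as a small Python helper with the 8-bit values worked out in comments; the function name is hypothetical and the body simply mirrors the deleted lines.

def quant_range(num_bits=8, symmetric=False, narrow_range=False):
    """Integer range used by the fake-quant ops (mirrors the removed block)."""
    if symmetric:
        quant_min = -(2 ** (num_bits - 1))   # -128 for 8 bits
        quant_max = 2 ** (num_bits - 1) - 1  #  127 for 8 bits
    else:
        quant_min = 0                        #    0 for 8 bits
        quant_max = 2 ** num_bits - 1        #  255 for 8 bits
    if narrow_range:
        quant_min += 1                       # e.g. [-127, 127] or [1, 255]
    return quant_min, quant_max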
@@ -21,12 +21,12 @@ from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype
-__all__ = ["FakeQuantPerLayer",
+__all__ = ["MinMaxUpdatePerLayer",
+           "MinMaxUpdatePerChannel",
+           "FakeQuantPerLayer",
           "FakeQuantPerLayerGrad",
           "FakeQuantPerChannel",
           "FakeQuantPerChannelGrad",
-           "FakeQuantMinMaxPerLayerUpdate",
-           "FakeQuantMinMaxPerChannelUpdate",
           "BatchNormFold",
           "BatchNormFoldGrad",
           "CorrectionMul",
@@ -36,23 +36,140 @@ __all__ = ["FakeQuantPerLayer",
           "BatchNormFold2Grad",
           "BatchNormFoldD",
           "BatchNormFoldGradD",
-           "BNTrainingReduce",
           "BatchNormFold2_D",
           "BatchNormFold2GradD",
-           "BatchNormFold2GradReduce",
+           "BatchNormFold2GradReduce"
           ]
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
r"""
Update min and max per layer.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
Inputs:
- **x** (Tensor) : float32 Tensor of the input data.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
    - **min_up** (Tensor) : Updated min value, with the same shape as the input min.
    - **max_up** (Tensor) : Updated max value, with the same shape as the input max.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> min_up, max_up = MinMaxUpdatePerLayer()(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
r"""
Update min and max per channel.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
channel_axis (int): Channel axis for per-channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor of the input data.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
    - **min_up** (Tensor) : Updated per-channel min value, with the same shape as the input min.
    - **max_up** (Tensor) : Updated per-channel max value, with the same shape as the input max.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> min_up, max_up = MinMaxUpdatePerChannel(channel_axis=1)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantPerLayer(PrimitiveWithInfer):
    r"""
    Simulate the quantize and dequantize operations in training time.
    Args:
-        num_bits (int) : Number bits for aware quantilization. Default: 8.
+        num_bits (int) : Number bits for quantization aware. Default: 8.
        ema (bool): Use EMA algorithm update value min and max. Default: False.
        ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
        quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
-            simulate aware quantize funcion. After delay step in training time begin simulate the aware
+            simulate quantization aware funcion. After delay step in training time begin simulate the aware
            quantize funcion. Default: 0.
        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@@ -334,7 +451,7 @@ class BatchNormFold(PrimitiveWithInfer):
    Batch normalization folded.
    Args:
-        momentum (float): Momentum value should be [0, 1]. Default: 0.1.
+        momentum (float): Momentum value should be [0, 1]. Default: 0.9.
        epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
            float32 else 1e-3. Default: 1e-5.
        is_training (bool): In training mode set True, else set False. Default: True.
@@ -366,7 +483,7 @@ class BatchNormFold(PrimitiveWithInfer):
    channel_axis = 1
    @prim_attr_register
-    def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
+    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
        """init batch norm fold layer"""
        self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
        self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
@@ -731,32 +848,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
        return x_type
class BNTrainingReduce(PrimitiveWithInfer):
"""
reduce sum at axis [0, 2, 3].
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
Outputs:
- **x_sum** (Tensor) - Tensor has the same shape as x.
- **x_square_sum** (Tensor) - Tensor has the same shape as x.
"""
@prim_attr_register
def __init__(self):
"""init _BNTrainingReduce layer"""
self.init_prim_io_names(inputs=['x'],
outputs=['x_sum', 'x_square_sum'])
def infer_shape(self, x_shape):
return [x_shape[1]], [x_shape[1]]
def infer_dtype(self, x_type):
return x_type, x_type
class BatchNormFold2_D(PrimitiveWithInfer):
    """
    Scale the bias with a correction factor to the long term statistics
@@ -859,153 +950,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
    def infer_dtype(self, dout_type, x_type):
        validator.check("dout type", dout_type, "x type", x_type)
        return dout_type, dout_type
class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
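Finally, a usage sketch of the two new primitives, following the docstring examples above. This assumes a working MindSpore build with these ops registered and the mindspore.ops.operations._quant_ops module path; it is illustrative only and not part of the commit.

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops.operations import _quant_ops as Q

x = Tensor(np.random.rand(3, 16, 5, 5).astype(np.float32), mstype.float32)

# Per-layer: a single running min/max for the whole tensor.
min_pl = Tensor(np.array([-6.0], dtype=np.float32), mstype.float32)
max_pl = Tensor(np.array([6.0], dtype=np.float32), mstype.float32)
min_up, max_up = Q.MinMaxUpdatePerLayer(ema=True, ema_decay=0.999)(x, min_pl, max_pl)

# Per-channel: one running min/max per channel along channel_axis=1 (16 channels here).
min_pc = Tensor(np.random.uniform(-1, 1, size=16).astype(np.float32), mstype.float32)
max_pc = Tensor(np.random.uniform(-1, 1, size=16).astype(np.float32), mstype.float32)
min_up_c, max_up_c = Q.MinMaxUpdatePerChannel(channel_axis=1)(x, min_pc, max_pc)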