提交 93e27f03 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2917 bug fix in quantization aware training auto create graph

Merge pull request !2917 from chenzhongming/master
......@@ -855,7 +855,7 @@ class ActQuant(_QuantActivation):
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = activation
self.act = activation()
def construct(self, x):
x = self.act(x)
......@@ -921,7 +921,7 @@ class HSwishQuant(_QuantActivation):
narrow_range=narrow_range,
quant_delay=quant_delay)
if isinstance(activation, nn.HSwish):
self.act = activation
self.act = activation()
else:
raise ValueError("Activation should be `nn.HSwish`")
......@@ -990,7 +990,7 @@ class HSigmoidQuant(_QuantActivation):
narrow_range=narrow_range,
quant_delay=quant_delay)
if isinstance(activation, nn.HSwish):
self.act = activation
self.act = activation()
else:
raise ValueError("Activation should be `nn.HSigmoid`")
......
......@@ -114,7 +114,6 @@ class ConvertToQuantNetwork:
def run(self):
self.network.update_cell_prefix()
network = self._convert_subcells2quant(self.network)
network = _AddFakeQuantInput(network)
self.network.update_cell_type("quant")
return network
......@@ -275,16 +274,20 @@ class ExportToQuantInferNetwork:
Args:
network (Cell): MindSpore network API `convert_quant_network`.
inputs (Tensor): Input tensors of the `quantization aware training network`.
mean (int): Input data mean. Default: 127.5.
std_dev (int, float): Input data variance. Default: 127.5.
Returns:
Cell, GEIR backend Infer network.
"""
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
def __init__(self,
network,
*inputs):
def __init__(self, network, mean, std_dev, *inputs):
network = validator.check_isinstance('network', network, (nn.Cell,))
# quantize for inputs: q = f / scale + zero_point
# dequantize for outputs: f = (q - zero_point) * scale
self.input_scale = round(mean)
self.input_zero_point = 1 / std_dev
self.data_type = mstype.int8
self.network = copy.deepcopy(network)
self.all_parameters = {p.name: p for p in self.network.get_parameters()}
......@@ -395,7 +398,7 @@ class ExportToQuantInferNetwork:
return network
def export(network, *inputs, file_name, file_format='GEIR'):
def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='GEIR'):
"""
Exports MindSpore quantization predict model to deploy with GEIR.
......@@ -403,12 +406,17 @@ def export(network, *inputs, file_name, file_format='GEIR'):
network (Cell): MindSpore network produced by `convert_quant_network`.
inputs (Tensor): Inputs of the `quantization aware training network`.
file_name (str): File name of model to export.
mean (int): Input data mean. Default: 127.5.
std_dev (int, float): Input data variance. Default: 127.5.
file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.
- GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model.
"""
supported_device = ["Ascend"]
supported_formats = ['GEIR']
mean = validator.check_type("mean", mean, (int, float))
std_dev = validator.check_type("std_dev", std_dev, (int, float))
if context.get_context('device_target') not in supported_device:
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
......@@ -418,7 +426,7 @@ def export(network, *inputs, file_name, file_format='GEIR'):
network.set_train(False)
if file_format == 'GEIR':
exporter = ExportToQuantInferNetwork(network, *inputs)
exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
deploy_net = exporter.run()
serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册