提交 31ecc13b 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1900 bug fix in fake quant ops

Merge pull request !1900 from chenzhongming/master
......@@ -15,6 +15,7 @@
"""Operators for quantization."""
import mindspore.context as context
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
......@@ -82,6 +83,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
narrow_range=False,
training=True):
"""init FakeQuantPerLayer OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_perlayer
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
......@@ -143,6 +146,8 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
quant_delay=0,
symmetric=False,
narrow_range=False):
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
......@@ -222,6 +227,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
training=True,
channel_axis=1):
"""init FakeQuantPerChannel OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' Attr \'num_bits\' is not support.")
......@@ -286,6 +293,8 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
narrow_range=False,
channel_axis=1):
"""init FakeQuantPerChannelGrad Fill"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
......@@ -454,6 +463,8 @@ class CorrectionMul(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, channel_axis=0):
"""init correction mul layer"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import correction_mul
self.channel_axis = channel_axis
self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
outputs=['out'])
......@@ -486,6 +497,8 @@ class CorrectionMulGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, channel_axis=0):
"""init correction mul layer"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import correction_mul_grad
self.channel_axis = channel_axis
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
outputs=['dx', 'd_gamma'])
......@@ -847,9 +860,8 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
......@@ -922,6 +934,8 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册