提交 9958bc47 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2161 [qat]Export network from quantization aware network to deploy

Merge pull request !2161 from vlne-v1/I1IZV3-quant-infer
......@@ -487,6 +487,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
}));
(void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
.def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
.def(py::pickle(
[](const MetaTensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 2) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
return tensor;
}))
.def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
.def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.");
......
......@@ -220,6 +220,8 @@ const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
// Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
......
......@@ -228,6 +228,8 @@ extern const PrimitivePtr kPrimActivation;
extern const PrimitivePtr kPrimZerosLike;
extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut;
extern const PrimitivePtr kPrimFakeQuantPerLayer;
extern const PrimitivePtr kPrimFakeQuantPerChannel;
// Other Miscellaneous
extern const PrimitivePtr kPrimIdentity;
......
......@@ -77,6 +77,8 @@ PYBIND11_MODULE(_c_expression, m) {
"Get CNode Strategy Dictionary.")
.def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
"Get Allreduce Fusion Dictionary.")
.def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"),
"Fetch the inputs of Conv or Matmul for quant export.")
.def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
py::arg("broadcast_params") = py::dict(), "Build data graph.")
.def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.")
......
......@@ -281,6 +281,75 @@ ExecutorPy::~ExecutorPy() {
ConfigManager::GetInstance().ResetConfig();
}
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
const std::string &phase_s) {
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
auto filter = [](AnfNodePtr node) {
return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
};
std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
auto is_quant_cnode = [](AnfNodePtr node) {
return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
};
for (auto node : nodes) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr || cnode->size() != 3) {
continue;
}
auto x = cnode->input(1);
auto weight = cnode->input(2);
if (!is_quant_cnode(weight)) {
continue;
}
// get parameter weight's name
cnode = weight->cast<CNodePtr>();
auto weight_node = cnode->input(2);
if (!weight_node->isa<Parameter>()) {
continue;
}
auto weight_name = weight_node->cast<ParameterPtr>()->name();
// find the fakequant from input
int count = 0;
int max_depth = 5;
while (!is_quant_cnode(x)) {
if (count >= max_depth) {
break;
}
cnode = x->cast<CNodePtr>();
if (cnode == nullptr || cnode->size() <= 1) {
break;
}
x = cnode->input(1);
count += 1;
}
// get the fakequant parameter minq's name
if (!is_quant_cnode(x)) {
continue;
}
cnode = x->cast<CNodePtr>();
if (cnode == nullptr || cnode->size() != 4) {
continue;
}
auto fakequant_min_node = cnode->input(2);
if (!fakequant_min_node->isa<Parameter>()) {
continue;
}
auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
if (!quant_op_value->isa<PrimitivePy>()) {
continue;
}
auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
}
return fake_quant_table;
}
void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
// save the graph to ExecutorPy
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
......
......@@ -97,6 +97,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
void ReleaseResource(const py::object &phase);
static void ClearRes();
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> FetchInfoForQuantExport(const std::string &phase_s);
private:
ExecutorPy();
void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);
......
......@@ -39,6 +39,7 @@ namespace mindspore {
enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE };
using IncludeFunc = std::function<IncludeType(const AnfNodePtr &)>;
using FilterFunc = std::function<bool(const AnfNodePtr &)>;
using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>;
using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>;
......@@ -58,6 +59,9 @@ std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const Incl
std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
const FilterFunc &filter);
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
const IncludeFunc &include = AlwaysInclude);
......
......@@ -37,7 +37,8 @@ namespace mindspore {
namespace {
class DeepFirstSearcher : public AnfVisitor {
public:
explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {}
explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
: include_(include), filter_(filter) {}
~DeepFirstSearcher() override = default;
std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {
......@@ -61,8 +62,9 @@ class DeepFirstSearcher : public AnfVisitor {
if (incl == EXCLUDE) {
return;
}
res_.push_back(node);
if (filter_ == nullptr || !filter_(node)) {
res_.push_back(node);
}
if (incl == FOLLOW) {
AnfVisitor::Visit(node);
}
......@@ -71,6 +73,7 @@ class DeepFirstSearcher : public AnfVisitor {
private:
size_t seen_{0};
IncludeFunc include_;
FilterFunc filter_;
std::vector<AnfNodePtr> res_{};
};
......@@ -160,10 +163,16 @@ class DeepLinkedGraphSearcher : public DeepFirstSearcher {
};
} // namespace
// include for if expand the node the search, filter for if put the node to results.
std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
return DeepScopedGraphSearcher(include).Search(root);
}
std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
const FilterFunc &filter) {
return DeepFirstSearcher(include, filter).Search(root);
}
std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
return DeepUsedGraphSearcher(include).Search(root);
}
......
......@@ -526,6 +526,11 @@ class _Executor:
phase = 'export' + '.' + str(net.create_time)
export_graph(file_name, file_format, phase)
def fetch_info_for_quant_export(self, exec_id):
"""Get graph proto from pipeline."""
if self._executor.has_compiled(exec_id) is False:
return None
return self._executor.fetch_info_for_quant_export(exec_id)
_executor = _Executor()
_pynative_exec = _PynativeExecutor()
......
......@@ -18,8 +18,6 @@ from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore._checkparam import check_bool, check_typename
from mindspore._extends import cell_attr_register
......@@ -85,13 +83,12 @@ class _BatchNorm(Cell):
self.reshape = P.Reshape()
self.is_ascend = context.get_context("device_target") == "Ascend"
self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
self.momentum = 1.0 - momentum
if context.get_context("enable_ge"):
self.is_ge_backend = True
self.momentum = Tensor(1.0 - momentum, mstype.float32)
else:
self.is_ge_backend = False
self.momentum = 1.0 - momentum
if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
self.bn_train = P.BatchNorm(is_training=True,
epsilon=self.eps)
......
......@@ -729,8 +729,8 @@ class DenseQuant(Cell):
self.has_bias = check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
weight_init.shape()[1] != in_channels:
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(
......@@ -738,7 +738,7 @@ class DenseQuant(Cell):
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(
......@@ -780,8 +780,14 @@ class DenseQuant(Cell):
return str_info
class _QuantActivation(Cell):
r"""
Base class for Quant activation function. Add Fake Quant OP after activation OP.
"""
def get_origin(self):
raise NotImplementedError
class ReLUQuant(Cell):
class ReLUQuant(_QuantActivation):
r"""
ReLUQuant activation function. Add Fake Quant OP after Relu OP.
......@@ -828,8 +834,11 @@ class ReLUQuant(Cell):
x = self.fake_quant_act(x)
return x
def get_origin(self):
return self.relu
class ReLU6Quant(Cell):
class ReLU6Quant(_QuantActivation):
r"""
ReLU6Quant activation function.
......@@ -878,8 +887,10 @@ class ReLU6Quant(Cell):
x = self.fake_quant_act(x)
return x
def get_origin(self):
return self.relu6
class HSwishQuant(Cell):
class HSwishQuant(_QuantActivation):
r"""
HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
......@@ -935,8 +946,10 @@ class HSwishQuant(Cell):
x = self.fake_quant_act_after(x)
return x
def get_origin(self):
return self.act
class HSigmoidQuant(Cell):
class HSigmoidQuant(_QuantActivation):
r"""
HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.
......@@ -991,6 +1004,8 @@ class HSigmoidQuant(Cell):
x = self.fake_quant_act_after(x)
return x
def get_origin(self):
return self.act
class TensorAddQuant(Cell):
r"""
......@@ -1083,3 +1098,77 @@ class MulQuant(Cell):
x = self.mul(x1, x2)
x = self.fake_quant_act(x)
return x
class QuantBlock(Cell):
r"""
A quant block of Conv/Dense, activation layer for Ascend deploy.
Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant.
Notes:
This block is only for deploy, and not trainable.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
Examples:
>>> net = nn.Dense(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
"""
def __init__(self,
core_op,
weight,
quant_op,
dequant_op,
dequant_scale,
bias=None,
activation=None):
super(QuantBlock, self).__init__()
self.core_op = core_op
self.weight = weight
self.quant = quant_op
self.dequant = dequant_op
self.dequant_scale = dequant_scale
self.bias = bias
self.has_bias = bias is None
self.activation = activation
self.has_act = activation is None
def construct(self, x):
x = self.quant(x)
x = self.core_op(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
if self.has_act:
x = self.activation(x)
x = self.dequant(x, self.dequant_scale)
return x
def extend_repr(self):
str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
if self.has_bias:
str_info = str_info + f', bias={self.bias}'
if self.has_act:
str_info = str_info + f', activation={self.activation}'
str_info = str_info + f', dequant={self.dequant}'
return str_info
......@@ -584,6 +584,8 @@ class MatMul(PrimitiveWithInfer):
def infer_dtype(self, x, y):
args = {"x": x, "y": y}
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
if x.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32)
return x
......
......@@ -800,7 +800,7 @@ class Conv2D(PrimitiveWithInfer):
def infer_shape(self, x_shape, w_shape):
validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
......@@ -846,6 +846,8 @@ class Conv2D(PrimitiveWithInfer):
args = {'x': x_dtype, 'w': w_dtype}
valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
if x_dtype.element_type() == mstype.int8:
return mstype.tensor_type(mstype.int32)
return x_dtype
......
......@@ -43,11 +43,12 @@ class Primitive(Primitive_):
>>> # init a Primitive obj with attr1=1 and attr2=2
>>> add = Add(attr1=1, attr2=2)
"""
_repr_ignore_list = ['input_names', 'output_names']
def __init__(self, name):
self.name = name
self.attrs = {}
self.init_attrs = {}
self.init_attrs = {"name": name}
Primitive_.__init__(self, name, self)
if hasattr(self.__class__, '__mindspore_signature__'):
sig = self._fill_signature(self.__class__.__mindspore_signature__)
......@@ -165,6 +166,16 @@ class Primitive(Primitive_):
def __setstate__(self, d):
self.__dict__.update(d)
def __deepcopy__(self, memo):
return type(self)(**self.init_attrs)
def __repr__(self):
attr = ', '.join([f'{k}={self.attrs[k]}'for k in self.attrs if not k in Primitive._repr_ignore_list])
info_str = f'Prim[{self.name}]'
if attr:
info_str += f'<{attr}>'
return info_str
def init_prim_io_names(self, inputs, outputs):
"""
Initializes inputs and outpus name of Tensor or attributes.
......@@ -185,8 +196,8 @@ class PrimitiveWithInfer(Primitive):
There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(),
infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describle shape
and type infer logic. The infer_value() is used for constant propogation.
to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe shape
and type infer logic. The infer_value() is used for constant propagation.
Args:
name (str): Name for current Primitive.
......@@ -288,6 +299,7 @@ def prim_attr_register(fn):
bound_args.apply_defaults()
arguments = bound_args.arguments
del arguments['self']
del self.init_attrs['name']
for name in arguments:
value = arguments[name]
self.add_prim_attr(name, value)
......
......@@ -14,12 +14,23 @@
# ============================================================================
"""aware quantization."""
import copy
import re
from ... import nn
from ... import ops
import numpy as np
from ... import log as logger
from ... import nn, ops
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel
from ...common import Tensor
from ...common import dtype as mstype
from ...common.api import _executor
from ...nn.layer import quant
from ...ops import functional as F
from ...ops.operations import _inner_ops as inner
from ...train import serialization
from . import quant_utils
_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
nn.ReLU6: quant.ReLU6Quant,
......@@ -27,25 +38,21 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
nn.HSwish: quant.HSwishQuant}
class _AddFakeQuantInputOutput(nn.Cell):
class _AddFakeQuantInput(nn.Cell):
"""
Add FakeQuant at input and output of the Network. Only support one input and one output case.
"""
def __init__(self, network, quant_delay=0):
super(_AddFakeQuantInputOutput, self).__init__(auto_prefix=False)
super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
self.network = network
self.fake_quant_input = quant.FakeQuantWithMinMax(
min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input.update_parameters_name('fake_quant_input')
self.fake_quant_output = quant.FakeQuantWithMinMax(
min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_output.update_parameters_name('fake_quant_output')
def construct(self, data):
data = self.fake_quant_input(data)
output = self.network(data)
output = self.fake_quant_output(output)
return output
......@@ -99,6 +106,8 @@ class ConvertToQuantNetwork:
self.per_channel = validator.check_bool("per channel", per_channel)
self.symmetric = validator.check_bool("symmetric", symmetric)
self.narrow_range = validator.check_bool("narrow range", narrow_range)
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
quant.DenseBnAct: self._convert_dense}
def _convert_op_name(self, name):
pattern = re.compile(r'([A-Z]{1})')
......@@ -110,6 +119,7 @@ class ConvertToQuantNetwork:
def run(self):
self.network.update_cell_prefix()
network = self._convert_subcells2quant(self.network)
network = _AddFakeQuantInput(network)
return network
def _convert_subcells2quant(self, network):
......@@ -122,15 +132,9 @@ class ConvertToQuantNetwork:
subcell = cells[name]
if subcell == network:
continue
elif isinstance(subcell, quant.Conv2dBnAct):
elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
prefix = subcell.param_prefix
new_subcell = self._convert_conv(subcell)
new_subcell.update_parameters_name(prefix + '.')
network.insert_child_to_cell(name, new_subcell)
change = True
elif isinstance(subcell, quant.DenseBnAct):
prefix = subcell.param_prefix
new_subcell = self._convert_dense(subcell)
new_subcell = self._convert_method_map[type(subcell)](subcell)
new_subcell.update_parameters_name(prefix + '.')
network.insert_child_to_cell(name, new_subcell)
change = True
......@@ -199,10 +203,12 @@ class ConvertToQuantNetwork:
symmetric=self.symmetric,
narrow_range=self.narrow_range)
subcell.conv = conv_inner
if subcell.activation is not None:
if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation)
else:
subcell = _AddFakeQuantAfterSubCell(subcell)
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
quant_delay=self.quant_delay)
return subcell
def _convert_dense(self, subcell):
......@@ -217,8 +223,12 @@ class ConvertToQuantNetwork:
per_channel=self.per_channel,
num_bits=self.weight_bits)
subcell.dense = dense_inner
if subcell.activation is not None:
if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation)
else:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
quant_delay=self.quant_delay)
return subcell
def _convert_activation(self, activation):
......@@ -229,6 +239,147 @@ class ConvertToQuantNetwork:
return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay)
class ExportQuantNetworkDeploy:
"""
Convert quantization aware network to deploy network.
Args:
network (Cell): MindSpore network produced by `convert_quant_network`.
inputs (Tensor): Inputs of the `network`.
Returns:
Cell, converted network.
"""
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
def __init__(self,
network,
*inputs):
network = validator.check_isinstance('network', network, (nn.Cell,))
self.data_type = mstype.int8
self.network = copy.deepcopy(network)
self.all_paramters = {p.name: p for p in self.network.get_parameters()}
self.get_inputs_table(inputs)
def get_inputs_table(self, inputs):
"""Get the support info for quant export."""
phase_name = 'export_quant'
graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)
def run(self):
"""Start to convert."""
self.network.update_cell_prefix()
network = self.network
if isinstance(network, _AddFakeQuantInput):
network = network.network
network = self._convert_quant2deploy(network)
return network
def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
"""convet network's quant subcell to deploy subcell"""
# Calculate the scale and zero point
w_minq_name = cell_core.fake_quant_weight.minq.name
np_type = mstype.dtype_to_nptype(self.data_type)
scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
info = self.quant_info_table.get(w_minq_name, None)
if info:
fack_quant_a_in_op, minq_name = info
maxq = self.all_paramters[minq_name[:-4] + "maxq"]
minq = self.all_paramters[minq_name]
scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
else:
logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
return None
# Build the `Quant` `Dequant` op.
# AscendQuant only support perlayer version. Need check here.
quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
sqrt_mode = False
scale_deq = scale_a_out * scale_w
if scale_deq < 2 ** -14:
scale_deq = np.sqrt(scale_deq)
sqrt_mode = True
dequant_op = inner.AscendDequant(sqrt_mode)
# get op
op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv
if isinstance(activation, _AddFakeQuantAfterSubCell):
activation = activation.subcell
elif hasattr(activation, "get_origin"):
activation = activation.get_origin()
# get the `weight` and `bias`
weight = cell_core.weight.data.asnumpy()
bias = None
if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
if cell_core.has_bias:
bias = cell_core.bias.data.asnumpy()
elif isinstance(cell_core, quant.Conv2dBatchNormQuant):
weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
# apply the quant
weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type)
if bias is not None:
bias = Tensor(scale_a_in * scale_w * bias, mstype.int32)
scale_deq = Tensor(scale_deq, mstype.float16)
block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
return block
def _convert_quant2deploy(self, network):
"""Convet network's all quant subcell to deploy subcell."""
cells = network.name_cells()
change = False
for name in cells:
subcell = cells[name]
if subcell == network:
continue
cell_core = None
fake_quant_act = None
activation = None
if isinstance(subcell, quant.Conv2dBnAct):
cell_core = subcell.conv
activation = subcell.activation
fake_quant_act = activation.fake_quant_act
elif isinstance(subcell, quant.DenseBnAct):
cell_core = subcell.dense
activation = subcell.activation
fake_quant_act = activation.fake_quant_act
if cell_core is not None:
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
if new_subcell:
prefix = subcell.param_prefix
new_subcell.update_parameters_name(prefix + '.')
network.insert_child_to_cell(name, new_subcell)
change = True
elif isinstance(subcell, _AddFakeQuantAfterSubCell):
op = subcell.subcell
if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive):
network.__delattr__(name)
network.__setattr__(name, op)
change = True
else:
self._convert_quant2deploy(subcell)
if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells())
return network
def export_geir(network, *inputs, file_name):
"""
Exports MindSpore quant predict model to deploy with GEIR.
Args:
network (Cell): MindSpore network produced by `convert_quant_network`.
inputs (Tensor): Inputs of the `network`.
file_name (str): File name of model to export.
"""
exporter = ExportQuantNetworkDeploy(network, *inputs)
deploy_net = exporter.run()
serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR")
def convert_quant_network(network,
quant_delay=0,
bn_fold=False,
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""quantization utils."""
"""Quantization utils."""
import numpy as np
......@@ -24,22 +24,19 @@ def cal_quantization_params(input_min,
symmetric=False,
narrow_range=False):
r"""
calculate quantization params for scale and zero point.
Calculate quantization params for scale and zero point.
Args:
input_min (int, list): The dimension of channel or 1.
input_max (int, list): The dimension of channel or 1.
input_min (numpy.ndarray): The dimension of channel or 1.
input_max (numpy.ndarray): The dimension of channel or 1.
data_type (numpy type) : Can ben numpy int8, numpy uint8.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
Outputs:
scale (int, list): quantization param.
zero point (int, list): quantization param.
Examples:
>>> scale, zp = cal_quantization_params([1, 2, 1], [-2, 0, -1], 8, False, False)
Returns:
scale (numpy.ndarray): quantization param.
zero point (numpy.ndarray): quantization param.
"""
input_max = np.maximum(0.0, input_max)
input_min = np.minimum(0.0, input_min)
......@@ -92,27 +89,103 @@ def weight2int(data,
scale,
zero_point):
r"""
calculate int8/uint8 weight from fp32. the formula is defined as:
Calculate int8/uint8 weight from fp32. the formula is defined as:
.. math::
int8/uint8 = round(float/scale) + offset
Args:
data (int, list): The dimension of channel or 1. Should be NCHW.
scale (int, list): The dimension of channel or 1.
zero_point (int, list): The dimension of channel or 1.
data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
scale (numpy.ndarray): The dimension of channel or 1.
zero_point (numpy.ndarray): The dimension of channel or 1.
Outputs:
weight (int, list): The dimension of channel or 1.
Examples:
>>> weight = weight2int([1, 2, 1], 1, 0)
Returns:
weight (numpy.ndarray): The dimension of channel or 1.
"""
if scale.shape != zero_point.shape:
raise ValueError("scale and zero_point should have the same shape.")
if scale.shape[0] > 0:
scale = scale.reshape(1, -1, 1, 1)
zero_point = zero_point.reshape(1, -1, 1, 1)
scale = scale.reshape(1, -1)
zero_point = zero_point.reshape(1, -1)
return np.round((data/scale) + zero_point)
def scale_zp_from_fack_quant_cell(cell, data_type):
r"""
Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.
Args:
cell (Cell): `mindspore.nn.layer.FakeQuantWithMinMax`
data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
Returns:
scale (numpy.ndarray): quantization param.
zero point (numpy.ndarray): quantization param.
"""
minq = cell.minq.data.asnumpy()
maxq = cell.maxq.data.asnumpy()
op = cell.fake_quant
scale, zp = cal_quantization_params(
minq, maxq, data_type,
num_bits=op.num_bits,
symmetric=op.symmetric,
narrow_range=op.narrow_range)
return scale, zp
def scale_zp_from_data(op, minq, maxq, data_type):
r"""
Get calculate quantization params for scale and zero point.
Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
Args:
op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
`mindspore.ops.operation.FakeQuantPerChannel`
minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
Returns:
scale (numpy.ndarray): quantization param.
zero point (numpy.ndarray): quantization param.
"""
minq = minq.data.asnumpy()
maxq = maxq.data.asnumpy()
scale, zp = cal_quantization_params(
minq, maxq, data_type,
num_bits=op.num_bits,
symmetric=op.symmetric,
narrow_range=op.narrow_range)
return scale, zp
def fold_batchnorm(weight, cell_quant):
r"""
Fold the batchnorm in `Conv2dBatchNormQuant` to weight.
Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
Args:
weight (numpy.ndarray): Weight of `cell_quant`.
cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`.
Returns:
weight (numpy.ndarray): Folded weight.
bias (numpy.ndarray): Folded bias.
"""
variance = cell_quant.moving_variance.data.asnumpy()
mean = cell_quant.moving_mean.data.asnumpy()
gamma = cell_quant.gamma.data.asnumpy()
beta = cell_quant.beta.data.asnumpy()
epsilon = cell_quant.eps
sigma = np.sqrt(variance + epsilon)
gamma = gamma.reshape(-1, 1, 1, 1)
sigma = sigma.reshape(-1, 1, 1, 1)
mean = mean.reshape(-1, 1, 1, 1)
weight = weight * gamma / sigma
bias = beta - gamma * mean / sigma
return weight, bias
......@@ -55,7 +55,7 @@ def init_net_param(network, init_value='ones'):
params = network.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype))
class ModelCallback(Callback):
def __init__(self):
......
......@@ -13,9 +13,14 @@
# limitations under the License.
# ============================================================================
""" tests for quant """
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.quant import quant as qat
from mobilenetv2_combined import MobileNetV2
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
......@@ -37,23 +42,45 @@ class LeNet5(nn.Cell):
def __init__(self, num_class=10):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6')
self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu')
self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
self.fc3 = nn.DenseBnAct(84, self.num_class)
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flattern = nn.Flatten()
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.bn(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.max_pool2d(x)
x = self.flattern(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_lenet():
img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
net = LeNet5()
net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
# should load the checkpoint. mock here
for param in net.get_parameters():
param.init_data()
qat.export_geir(net, img, file_name="quant.pb")
@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_mobile():
net = MobileNetV2()
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=True, freeze_bn=10000, weight_bits=8, act_bits=8)
# should load the checkpoint. mock here
for param in net.get_parameters():
param.init_data()
qat.export_geir(net, img, file_name="quant.pb")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册