Commit 9958bc47
Authored June 19, 2020 by mindspore-ci-bot; committed via Gitee on June 19, 2020.
Repository: magicwindyyd/mindspore (fork of MindSpore/mindspore).
!2161 [qat] Export network from quantization aware network to deploy

Merge pull request !2161 from vlne-v1/I1IZV3-quant-infer

Parents: 06b511ca, 1d77bf86
Showing 18 changed files with 529 additions and 68 deletions (+529, -68).
mindspore/ccsrc/ir/tensor.cc                      +13   -0
mindspore/ccsrc/operator/ops.cc                    +2   -0
mindspore/ccsrc/operator/ops.h                     +2   -0
mindspore/ccsrc/pipeline/init.cc                   +2   -0
mindspore/ccsrc/pipeline/pipeline.cc              +69   -0
mindspore/ccsrc/pipeline/pipeline.h                +2   -0
mindspore/ccsrc/utils/graph_utils.h                +4   -0
mindspore/ccsrc/utils/graph_utils_extends.cc      +12   -3
mindspore/common/api.py                            +5   -0
mindspore/nn/layer/normalization.py                +2   -5
mindspore/nn/layer/quant.py                       +96   -7
mindspore/ops/operations/math_ops.py               +2   -0
mindspore/ops/operations/nn_ops.py                 +3   -1
mindspore/ops/primitive.py                        +15   -3
mindspore/train/quant/quant.py                   +170  -19
mindspore/train/quant/quant_utils.py              +95  -22
tests/st/model_zoo_tests/yolov3/test_yolov3.py     +1   -1
tests/ut/python/train/quant/test_quant.py         +34   -7
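Taken together, the changes add an export path that rewrites a quantization-aware-trained network into an int8 deploy network and serializes it as GEIR. A minimal sketch of the intended flow, assuming a network already built from Conv2dBnAct/DenseBnAct cells (it mirrors the unit test added at the bottom of this diff; LeNet5 is a placeholder for any such network):

    import numpy as np
    from mindspore import Tensor
    from mindspore.train.quant import quant as qat

    net = LeNet5()  # placeholder: any network built from nn.Conv2dBnAct / nn.DenseBnAct cells
    net = qat.convert_quant_network(net, quant_delay=0, bn_fold=False, freeze_bn=10000,
                                    weight_bits=8, act_bits=8)
    # ... train the fake-quant network or load a checkpoint here ...
    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
    qat.export_geir(net, img, file_name="quant.pb")  # new in this PR: fold BN, quantize weights, export GEIR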
mindspore/ccsrc/ir/tensor.cc
@@ -487,6 +487,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                         }));
                       (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
                         .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
+                        .def(py::pickle(
+                          [](const MetaTensor &t) {  // __getstate__
+                            /* Return a tuple that fully encodes the state of the object */
+                            return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
+                          },
+                          [](const py::tuple &t) {  // __setstate__
+                            if (t.size() != 2) {
+                              throw std::runtime_error("Invalid state!");
+                            }
+                            /* Create a new C++ instance */
+                            MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
+                            return tensor;
+                          }))
                         .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
                         .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
                         .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.");
mindspore/ccsrc/operator/ops.cc
@@ -220,6 +220,8 @@ const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
+const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
+const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");

// Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
mindspore/ccsrc/operator/ops.h
@@ -228,6 +228,8 @@ extern const PrimitivePtr kPrimActivation;
extern const PrimitivePtr kPrimZerosLike;
extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut;
+extern const PrimitivePtr kPrimFakeQuantPerLayer;
+extern const PrimitivePtr kPrimFakeQuantPerChannel;

// Other Miscellaneous
extern const PrimitivePtr kPrimIdentity;
mindspore/ccsrc/pipeline/init.cc
@@ -77,6 +77,8 @@ PYBIND11_MODULE(_c_expression, m) {
         "Get CNode Strategy Dictionary.")
    .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
         "Get Allreduce Fusion Dictionary.")
+   .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"),
+        "Fetch the inputs of Conv or Matmul for quant export.")
    .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
         py::arg("broadcast_params") = py::dict(), "Build data graph.")
    .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.")
mindspore/ccsrc/pipeline/pipeline.cc
@@ -281,6 +281,75 @@ ExecutorPy::~ExecutorPy() {
  ConfigManager::GetInstance().ResetConfig();
}

std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
  const std::string &phase_s) {
  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
  auto filter = [](AnfNodePtr node) {
    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
  };
  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
  auto is_quant_cnode = [](AnfNodePtr node) {
    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
  };
  for (auto node : nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr || cnode->size() != 3) {
      continue;
    }
    auto x = cnode->input(1);
    auto weight = cnode->input(2);
    if (!is_quant_cnode(weight)) {
      continue;
    }
    // get parameter weight's name
    cnode = weight->cast<CNodePtr>();
    auto weight_node = cnode->input(2);
    if (!weight_node->isa<Parameter>()) {
      continue;
    }
    auto weight_name = weight_node->cast<ParameterPtr>()->name();
    // find the fakequant from input
    int count = 0;
    int max_depth = 5;
    while (!is_quant_cnode(x)) {
      if (count >= max_depth) {
        break;
      }
      cnode = x->cast<CNodePtr>();
      if (cnode == nullptr || cnode->size() <= 1) {
        break;
      }
      x = cnode->input(1);
      count += 1;
    }
    // get the fakequant parameter minq's name
    if (!is_quant_cnode(x)) {
      continue;
    }
    cnode = x->cast<CNodePtr>();
    if (cnode == nullptr || cnode->size() != 4) {
      continue;
    }
    auto fakequant_min_node = cnode->input(2);
    if (!fakequant_min_node->isa<Parameter>()) {
      continue;
    }
    auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
    auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
    if (!quant_op_value->isa<PrimitivePy>()) {
      continue;
    }
    auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
    fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
  }

  return fake_quant_table;
}

void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
  // save the graph to ExecutorPy
  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
mindspore/ccsrc/pipeline/pipeline.h
@@ -97,6 +97,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
  void ReleaseResource(const py::object &phase);
  static void ClearRes();
+  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> FetchInfoForQuantExport(const std::string &phase_s);

 private:
  ExecutorPy();
  void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);
mindspore/ccsrc/utils/graph_utils.h
@@ -39,6 +39,7 @@ namespace mindspore {
enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE };

using IncludeFunc = std::function<IncludeType(const AnfNodePtr &)>;
+using FilterFunc = std::function<bool(const AnfNodePtr &)>;
using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>;
using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>;

@@ -58,6 +59,9 @@ std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const Incl
std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
+std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
+                                                        const FilterFunc &filter);
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
                                 const IncludeFunc &include = AlwaysInclude);
mindspore/ccsrc/utils/graph_utils_extends.cc
@@ -37,7 +37,8 @@ namespace mindspore {
namespace {
class DeepFirstSearcher : public AnfVisitor {
 public:
-  explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {}
+  explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
+      : include_(include), filter_(filter) {}
  ~DeepFirstSearcher() override = default;

  std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {

@@ -61,8 +62,9 @@ class DeepFirstSearcher : public AnfVisitor {
    if (incl == EXCLUDE) {
      return;
    }
-    res_.push_back(node);
+    if (filter_ == nullptr || !filter_(node)) {
+      res_.push_back(node);
+    }
    if (incl == FOLLOW) {
      AnfVisitor::Visit(node);
    }

@@ -71,6 +73,7 @@ class DeepFirstSearcher : public AnfVisitor {
 private:
  size_t seen_{0};
  IncludeFunc include_;
+  FilterFunc filter_;
  std::vector<AnfNodePtr> res_{};
};

@@ -160,10 +163,16 @@ class DeepLinkedGraphSearcher : public DeepFirstSearcher {
};
}  // namespace

+// include for if expand the node the search, filter for if put the node to results.
std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
  return DeepScopedGraphSearcher(include).Search(root);
}

+std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
+                                                        const FilterFunc &filter) {
+  return DeepFirstSearcher(include, filter).Search(root);
+}
+
std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
  return DeepUsedGraphSearcher(include).Search(root);
}
mindspore/common/api.py
@@ -526,6 +526,11 @@ class _Executor:
        phase = 'export' + '.' + str(net.create_time)
        export_graph(file_name, file_format, phase)

+    def fetch_info_for_quant_export(self, exec_id):
+        """Get graph proto from pipeline."""
+        if self._executor.has_compiled(exec_id) is False:
+            return None
+        return self._executor.fetch_info_for_quant_export(exec_id)


_executor = _Executor()
_pynative_exec = _PynativeExecutor()
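The Python hook wired up here is what the exporter in mindspore/train/quant/quant.py (further down in this diff) consumes. A rough sketch of the round trip, with `net` and `img` as made-up placeholders:

    from mindspore.common.api import _executor

    graph_id, _ = _executor.compile(net, img, phase='export_quant', do_convert=False)
    table = _executor.fetch_info_for_quant_export(graph_id)
    # `table` maps each Conv2D/MatMul weight parameter name to a pair:
    # (fake-quant primitive found on the activation input, name of that cell's `minq` parameter)
    for weight_name, (fake_quant_op, act_minq_name) in table.items():
        print(weight_name, act_minq_name)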
mindspore/nn/layer/normalization.py
@@ -18,8 +18,6 @@ from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore._checkparam import check_bool, check_typename
from mindspore._extends import cell_attr_register

@@ -85,13 +83,12 @@ class _BatchNorm(Cell):
        self.reshape = P.Reshape()
        self.is_ascend = context.get_context("device_target") == "Ascend"
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        self.momentum = 1.0 - momentum
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
            self.momentum = Tensor(1.0 - momentum, mstype.float32)
        else:
            self.is_ge_backend = False
            self.momentum = 1.0 - momentum
        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
            self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps)
mindspore/nn/layer/quant.py
@@ -729,8 +729,8 @@ class DenseQuant(Cell):
        self.has_bias = check_bool(has_bias)

        if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
-                    weight_init.shape()[1] != in_channels:
+            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+                    weight_init.shape[1] != in_channels:
                raise ValueError("weight_init shape error")

        self.weight = Parameter(initializer(

@@ -738,7 +738,7 @@ class DenseQuant(Cell):
        if self.has_bias:
            if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("bias_init shape error")

            self.bias = Parameter(initializer(

@@ -780,8 +780,14 @@ class DenseQuant(Cell):
        return str_info


class _QuantActivation(Cell):
    r"""
    Base class for Quant activation function. Add Fake Quant OP after activation OP.
    """

    def get_origin(self):
        raise NotImplementedError


-class ReLUQuant(Cell):
+class ReLUQuant(_QuantActivation):
    r"""
    ReLUQuant activation function. Add Fake Quant OP after Relu OP.

@@ -828,8 +834,11 @@ class ReLUQuant(Cell):
        x = self.fake_quant_act(x)
        return x

    def get_origin(self):
        return self.relu


-class ReLU6Quant(Cell):
+class ReLU6Quant(_QuantActivation):
    r"""
    ReLU6Quant activation function.

@@ -878,8 +887,10 @@ class ReLU6Quant(Cell):
        x = self.fake_quant_act(x)
        return x

    def get_origin(self):
        return self.relu6


-class HSwishQuant(Cell):
+class HSwishQuant(_QuantActivation):
    r"""
    HSwishQuant activation function. Add Fake Quant OP after HSwish OP.

@@ -935,8 +946,10 @@ class HSwishQuant(Cell):
        x = self.fake_quant_act_after(x)
        return x

    def get_origin(self):
        return self.act


-class HSigmoidQuant(Cell):
+class HSigmoidQuant(_QuantActivation):
    r"""
    HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.

@@ -991,6 +1004,8 @@ class HSigmoidQuant(Cell):
        x = self.fake_quant_act_after(x)
        return x

    def get_origin(self):
        return self.act


class TensorAddQuant(Cell):
    r"""

@@ -1083,3 +1098,77 @@ class MulQuant(Cell):
        x = self.mul(x1, x2)
        x = self.fake_quant_act(x)
        return x


class QuantBlock(Cell):
    r"""
    A quant block of Conv/Dense, activation layer for Ascend deploy.

    Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant.

    Notes:
        This block is only for deploy, and not trainable.

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
        activation (string): Specifies activation type. The optional values are as following:
            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> net = nn.Dense(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> net(input)
    """

    def __init__(self,
                 core_op,
                 weight,
                 quant_op,
                 dequant_op,
                 dequant_scale,
                 bias=None,
                 activation=None):
        super(QuantBlock, self).__init__()
        self.core_op = core_op
        self.weight = weight
        self.quant = quant_op
        self.dequant = dequant_op
        self.dequant_scale = dequant_scale
        self.bias = bias
        self.has_bias = bias is None
        self.activation = activation
        self.has_act = activation is None

    def construct(self, x):
        x = self.quant(x)
        x = self.core_op(x, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        if self.has_act:
            x = self.activation(x)
        x = self.dequant(x, self.dequant_scale)
        return x

    def extend_repr(self):
        str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
        if self.has_bias:
            str_info = str_info + f', bias={self.bias}'
        if self.has_act:
            str_info = str_info + f', activation={self.activation}'
        str_info = str_info + f', dequant={self.dequant}'
        return str_info
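For orientation, this is roughly how the exporter added in mindspore/train/quant/quant.py (further down) instantiates a QuantBlock; the scale, zero point, weight shape, and `conv_op` below are made-up placeholders:

    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype
    from mindspore.nn.layer import quant
    from mindspore.ops.operations import _inner_ops as inner

    quant_op = inner.AscendQuant(0.0157, 0.0)   # float activation -> int8, per-layer scale/zero-point
    dequant_op = inner.AscendDequant(False)     # int32 accumulator -> float, sqrt_mode off
    weight_q = Tensor(np.zeros((6, 1, 5, 5), np.int8), mstype.int8)
    scale_deq = Tensor(0.05, mstype.float16)
    block = quant.QuantBlock(conv_op, weight_q, quant_op, dequant_op, scale_deq)  # conv_op: the wrapped P.Conv2D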
mindspore/ops/operations/math_ops.py
@@ -584,6 +584,8 @@ class MatMul(PrimitiveWithInfer):
    def infer_dtype(self, x, y):
        args = {"x": x, "y": y}
        validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
+        if x.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
        return x
mindspore/ops/operations/nn_ops.py
@@ -800,7 +800,7 @@ class Conv2D(PrimitiveWithInfer):
    def infer_shape(self, x_shape, w_shape):
        validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
        validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
-        validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
+        validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
        validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)

@@ -846,6 +846,8 @@ class Conv2D(PrimitiveWithInfer):
        args = {'x': x_dtype, 'w': w_dtype}
        valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensor_type_same(args, valid_types, self.name)
+        if x_dtype.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
        return x_dtype
mindspore/ops/primitive.py
@@ -43,11 +43,12 @@ class Primitive(Primitive_):
    >>> # init a Primitive obj with attr1=1 and attr2=2
    >>> add = Add(attr1=1, attr2=2)
    """
+    _repr_ignore_list = ['input_names', 'output_names']

    def __init__(self, name):
        self.name = name
        self.attrs = {}
-        self.init_attrs = {}
+        self.init_attrs = {"name": name}
        Primitive_.__init__(self, name, self)
        if hasattr(self.__class__, '__mindspore_signature__'):
            sig = self._fill_signature(self.__class__.__mindspore_signature__)

@@ -165,6 +166,16 @@ class Primitive(Primitive_):
    def __setstate__(self, d):
        self.__dict__.update(d)

+    def __deepcopy__(self, memo):
+        return type(self)(**self.init_attrs)
+
+    def __repr__(self):
+        attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if not k in Primitive._repr_ignore_list])
+        info_str = f'Prim[{self.name}]'
+        if attr:
+            info_str += f'<{attr}>'
+        return info_str
+
    def init_prim_io_names(self, inputs, outputs):
        """
        Initializes inputs and outpus name of Tensor or attributes.

@@ -185,8 +196,8 @@ class PrimitiveWithInfer(Primitive):
    There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(),
    infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
-    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describle shape
-    and type infer logic. The infer_value() is used for constant propogation.
+    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe shape
+    and type infer logic. The infer_value() is used for constant propagation.

    Args:
        name (str): Name for current Primitive.

@@ -288,6 +299,7 @@ def prim_attr_register(fn):
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        del arguments['self']
+        del self.init_attrs['name']
        for name in arguments:
            value = arguments[name]
            self.add_prim_attr(name, value)
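A quick sketch of what the new `__repr__`/`__deepcopy__` behaviour looks like from Python (the import path and the `Prim[...]` strings are inferred from the code above, not captured output):

    import copy
    from mindspore.ops import Primitive

    p = Primitive('MyOp')
    p.add_prim_attr('axis', 0)
    print(repr(p))            # expected: Prim[MyOp]<axis=0>
    q = copy.deepcopy(p)      # rebuilt via type(p)(**p.init_attrs), i.e. Primitive(name='MyOp')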
mindspore/train/quant/quant.py
@@ -14,12 +14,23 @@
# ============================================================================
"""aware quantization."""

import copy
import re
-from ... import nn
-from ... import ops
+
+import numpy as np
+
+from ... import log as logger
+from ... import nn, ops
from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Rel
+from ...common import Tensor
+from ...common import dtype as mstype
+from ...common.api import _executor
from ...nn.layer import quant
from ...ops import functional as F
+from ...ops.operations import _inner_ops as inner
+from ...train import serialization
+from . import quant_utils

_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                   nn.ReLU6: quant.ReLU6Quant,

@@ -27,25 +38,21 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                   nn.HSwish: quant.HSwishQuant}


-class _AddFakeQuantInputOutput(nn.Cell):
+class _AddFakeQuantInput(nn.Cell):
    """
    Add FakeQuant at input and output of the Network. Only support one input and one output case.
    """

    def __init__(self, network, quant_delay=0):
-        super(_AddFakeQuantInputOutput, self).__init__(auto_prefix=False)
+        super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
        self.network = network
        self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
        self.fake_quant_input.update_parameters_name('fake_quant_input')
        self.fake_quant_output = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
        self.fake_quant_output.update_parameters_name('fake_quant_output')

    def construct(self, data):
        data = self.fake_quant_input(data)
        output = self.network(data)
        output = self.fake_quant_output(output)
        return output

@@ -99,6 +106,8 @@ class ConvertToQuantNetwork:
        self.per_channel = validator.check_bool("per channel", per_channel)
        self.symmetric = validator.check_bool("symmetric", symmetric)
        self.narrow_range = validator.check_bool("narrow range", narrow_range)
+        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
+                                    quant.DenseBnAct: self._convert_dense}

    def _convert_op_name(self, name):
        pattern = re.compile(r'([A-Z]{1})')

@@ -110,6 +119,7 @@ class ConvertToQuantNetwork:
    def run(self):
        self.network.update_cell_prefix()
        network = self._convert_subcells2quant(self.network)
+        network = _AddFakeQuantInput(network)
        return network

    def _convert_subcells2quant(self, network):

@@ -122,15 +132,9 @@ class ConvertToQuantNetwork:
            subcell = cells[name]
            if subcell == network:
                continue
-            elif isinstance(subcell, quant.Conv2dBnAct):
+            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
                prefix = subcell.param_prefix
-                new_subcell = self._convert_conv(subcell)
-                new_subcell.update_parameters_name(prefix + '.')
-                network.insert_child_to_cell(name, new_subcell)
-                change = True
-            elif isinstance(subcell, quant.DenseBnAct):
-                prefix = subcell.param_prefix
-                new_subcell = self._convert_dense(subcell)
+                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True

@@ -199,10 +203,12 @@ class ConvertToQuantNetwork:
                                                 symmetric=self.symmetric,
                                                 narrow_range=self.narrow_range)
        subcell.conv = conv_inner
-        if subcell.activation is not None:
+        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        else:
-            subcell = _AddFakeQuantAfterSubCell(subcell)
+            subcell.has_act = True
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
+                                                           num_bits=self.act_bits,
+                                                           quant_delay=self.quant_delay)
        return subcell

    def _convert_dense(self, subcell):

@@ -217,8 +223,12 @@ class ConvertToQuantNetwork:
                                        per_channel=self.per_channel,
                                        num_bits=self.weight_bits)
        subcell.dense = dense_inner
-        if subcell.activation is not None:
+        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        else:
+            subcell.has_act = True
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
+                                                           num_bits=self.act_bits,
+                                                           quant_delay=self.quant_delay)
        return subcell

    def _convert_activation(self, activation):

@@ -229,6 +239,147 @@ class ConvertToQuantNetwork:
        return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
                                          quant_delay=self.quant_delay)


class ExportQuantNetworkDeploy:
    """
    Convert quantization aware network to deploy network.

    Args:
        network (Cell): MindSpore network produced by `convert_quant_network`.
        inputs (Tensor): Inputs of the `network`.

    Returns:
        Cell, converted network.
    """
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self, network, *inputs):
        network = validator.check_isinstance('network', network, (nn.Cell,))
        self.data_type = mstype.int8
        self.network = copy.deepcopy(network)
        self.all_paramters = {p.name: p for p in self.network.get_parameters()}
        self.get_inputs_table(inputs)

    def get_inputs_table(self, inputs):
        """Get the support info for quant export."""
        phase_name = 'export_quant'
        graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
        self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)

    def run(self):
        """Start to convert."""
        self.network.update_cell_prefix()
        network = self.network
        if isinstance(network, _AddFakeQuantInput):
            network = network.network
        network = self._convert_quant2deploy(network)
        return network

    def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
        """convet network's quant subcell to deploy subcell"""
        # Calculate the scale and zero point
        w_minq_name = cell_core.fake_quant_weight.minq.name
        np_type = mstype.dtype_to_nptype(self.data_type)
        scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
        scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
        info = self.quant_info_table.get(w_minq_name, None)
        if info:
            fack_quant_a_in_op, minq_name = info
            maxq = self.all_paramters[minq_name[:-4] + "maxq"]
            minq = self.all_paramters[minq_name]
            scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
        else:
            logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
            return None

        # Build the `Quant` `Dequant` op.
        # AscendQuant only support perlayer version. Need check here.
        quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
        sqrt_mode = False
        scale_deq = scale_a_out * scale_w
        if scale_deq < 2 ** -14:
            scale_deq = np.sqrt(scale_deq)
            sqrt_mode = True
        dequant_op = inner.AscendDequant(sqrt_mode)

        # get op
        op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv
        if isinstance(activation, _AddFakeQuantAfterSubCell):
            activation = activation.subcell
        elif hasattr(activation, "get_origin"):
            activation = activation.get_origin()

        # get the `weight` and `bias`
        weight = cell_core.weight.data.asnumpy()
        bias = None
        if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
            if cell_core.has_bias:
                bias = cell_core.bias.data.asnumpy()
        elif isinstance(cell_core, quant.Conv2dBatchNormQuant):
            weight, bias = quant_utils.fold_batchnorm(weight, cell_core)

        # apply the quant
        weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type)
        if bias is not None:
            bias = Tensor(scale_a_in * scale_w * bias, mstype.int32)
        scale_deq = Tensor(scale_deq, mstype.float16)
        block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
        return block

    def _convert_quant2deploy(self, network):
        """Convet network's all quant subcell to deploy subcell."""
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            cell_core = None
            fake_quant_act = None
            activation = None
            if isinstance(subcell, quant.Conv2dBnAct):
                cell_core = subcell.conv
                activation = subcell.activation
                fake_quant_act = activation.fake_quant_act
            elif isinstance(subcell, quant.DenseBnAct):
                cell_core = subcell.dense
                activation = subcell.activation
                fake_quant_act = activation.fake_quant_act
            if cell_core is not None:
                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
                if new_subcell:
                    prefix = subcell.param_prefix
                    new_subcell.update_parameters_name(prefix + '.')
                    network.insert_child_to_cell(name, new_subcell)
                    change = True
            elif isinstance(subcell, _AddFakeQuantAfterSubCell):
                op = subcell.subcell
                if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive):
                    network.__delattr__(name)
                    network.__setattr__(name, op)
                    change = True
            else:
                self._convert_quant2deploy(subcell)
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())
        return network


def export_geir(network, *inputs, file_name):
    """
    Exports MindSpore quant predict model to deploy with GEIR.

    Args:
        network (Cell): MindSpore network produced by `convert_quant_network`.
        inputs (Tensor): Inputs of the `network`.
        file_name (str): File name of model to export.
    """
    exporter = ExportQuantNetworkDeploy(network, *inputs)
    deploy_net = exporter.run()
    serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR")


def convert_quant_network(network,
                          quant_delay=0,
                          bn_fold=False,
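One detail of `_get_quant_block` worth calling out: the dequant scale is the product of the weight scale and the output-activation scale, and when that product falls below 2**-14 (the smallest normal fp16 value) the exporter stores its square root and sets `sqrt_mode` on `AscendDequant`, presumably so the scale is applied twice on the device. A small numeric sketch with made-up scales:

    import numpy as np

    scale_w, scale_a_out = np.float32(3e-4), np.float32(2e-3)
    scale_deq = scale_a_out * scale_w     # 6e-7: below 2**-14, too small to store in fp16 directly
    sqrt_mode = False
    if scale_deq < 2 ** -14:
        scale_deq = np.sqrt(scale_deq)    # ~7.7e-4, representable in fp16
        sqrt_mode = True
    print(scale_deq, sqrt_mode)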
mindspore/train/quant/quant_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
-"""quantization utils."""
+"""Quantization utils."""

import numpy as np

@@ -24,22 +24,19 @@ def cal_quantization_params(input_min,
                            symmetric=False,
                            narrow_range=False):
    r"""
-    calculate quantization params for scale and zero point.
+    Calculate quantization params for scale and zero point.

    Args:
-        input_min (int, list): The dimension of channel or 1.
-        input_max (int, list): The dimension of channel or 1.
+        input_min (numpy.ndarray): The dimension of channel or 1.
+        input_max (numpy.ndarray): The dimension of channel or 1.
        data_type (numpy type) : Can ben numpy int8, numpy uint8.
        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.

-    Outputs:
-        scale (int, list): quantization param.
-        zero point (int, list): quantization param.
-
-    Examples:
-        >>> scale, zp = cal_quantization_params([1, 2, 1], [-2, 0, -1], 8, False, False)
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
    """
    input_max = np.maximum(0.0, input_max)
    input_min = np.minimum(0.0, input_min)

@@ -92,27 +89,103 @@ def weight2int(data,
               scale,
               zero_point):
    r"""
-    calculate int8/uint8 weight from fp32. the formula is defined as:
+    Calculate int8/uint8 weight from fp32. the formula is defined as:

    .. math::
        int8/uint8 = round(float/scale) + offset

    Args:
-        data (int, list): The dimension of channel or 1. Should be NCHW.
-        scale (int, list): The dimension of channel or 1.
-        zero_point (int, list): The dimension of channel or 1.
+        data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
+        scale (numpy.ndarray): The dimension of channel or 1.
+        zero_point (numpy.ndarray): The dimension of channel or 1.

-    Outputs:
-        weight (int, list): The dimension of channel or 1.
-
-    Examples:
-        >>> weight = weight2int([1, 2, 1], 1, 0)
+    Returns:
+        weight (numpy.ndarray): The dimension of channel or 1.
    """
    if scale.shape != zero_point.shape:
        raise ValueError("scale and zero_point should have the same shape.")

    if scale.shape[0] > 0:
        scale = scale.reshape(1, -1, 1, 1)
        zero_point = zero_point.reshape(1, -1, 1, 1)
        scale = scale.reshape(1, -1)
        zero_point = zero_point.reshape(1, -1)

    return np.round((data / scale) + zero_point)


def scale_zp_from_fack_quant_cell(cell, data_type):
    r"""
    Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.

    Args:
        cell (Cell): `mindspore.nn.layer.FakeQuantWithMinMax`
        data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.
    """
    minq = cell.minq.data.asnumpy()
    maxq = cell.maxq.data.asnumpy()
    op = cell.fake_quant

    scale, zp = cal_quantization_params(
        minq, maxq,
        data_type,
        num_bits=op.num_bits,
        symmetric=op.symmetric,
        narrow_range=op.narrow_range)
    return scale, zp


def scale_zp_from_data(op, minq, maxq, data_type):
    r"""
    Get calculate quantization params for scale and zero point.

    Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.

    Args:
        op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
            `mindspore.ops.operation.FakeQuantPerChannel`
        minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
        maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
        data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.
    """
    minq = minq.data.asnumpy()
    maxq = maxq.data.asnumpy()

    scale, zp = cal_quantization_params(
        minq, maxq,
        data_type,
        num_bits=op.num_bits,
        symmetric=op.symmetric,
        narrow_range=op.narrow_range)
    return scale, zp


def fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm in `Conv2dBatchNormQuant` to weight.

    Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`.

    Returns:
        weight (numpy.ndarray): Folded weight.
        bias (numpy.ndarray): Folded bias.
    """
    variance = cell_quant.moving_variance.data.asnumpy()
    mean = cell_quant.moving_mean.data.asnumpy()
    gamma = cell_quant.gamma.data.asnumpy()
    beta = cell_quant.beta.data.asnumpy()
    epsilon = cell_quant.eps
    sigma = np.sqrt(variance + epsilon)
    gamma = gamma.reshape(-1, 1, 1, 1)
    sigma = sigma.reshape(-1, 1, 1, 1)
    mean = mean.reshape(-1, 1, 1, 1)
    weight = weight * gamma / sigma
    bias = beta - gamma * mean / sigma
    return weight, bias
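As a sanity check on `fold_batchnorm` above, the folding follows the usual inference-time identity BN(conv(x)) = (gamma / sigma) * W * x + (beta - gamma * mean / sigma), with sigma = sqrt(variance + eps). A small numpy sketch with made-up statistics:

    import numpy as np

    weight = np.random.randn(4, 3, 3, 3).astype(np.float32)      # Conv2d weight, output channels first
    gamma, beta = np.ones(4, np.float32), np.zeros(4, np.float32)
    mean, variance, eps = np.zeros(4, np.float32), np.ones(4, np.float32), 1e-5
    sigma = np.sqrt(variance + eps)
    folded_w = weight * gamma.reshape(-1, 1, 1, 1) / sigma.reshape(-1, 1, 1, 1)
    folded_b = beta - gamma * mean / sigma
    # conv(x, folded_w) + folded_b reproduces batchnorm(conv(x, weight)) at inference time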
tests/st/model_zoo_tests/yolov3/test_yolov3.py
@@ -55,7 +55,7 @@ def init_net_param(network, init_value='ones'):
    params = network.trainable_params()
    for p in params:
        if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
-            p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
+            p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype))


class ModelCallback(Callback):
    def __init__(self):
tests/ut/python/train/quant/test_quant.py
@@ -13,9 +13,14 @@
# limitations under the License.
# ============================================================================
""" tests for quant """
import numpy as np
+import pytest

import mindspore.context as context
from mindspore import Tensor
from mindspore import nn
+from mindspore.train.quant import quant as qat
from mobilenetv2_combined import MobileNetV2

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

@@ -37,23 +42,45 @@ class LeNet5(nn.Cell):
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class

-        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6')
-        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu')
+        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
+        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        self.fc3 = nn.DenseBnAct(84, self.num_class)

        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.flattern = nn.Flatten()
+        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
-        x = self.bn(x)
-        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.max_pool2d(x)
-        x = self.flattern(x)
+        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_lenet():
    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
    net = LeNet5()
    net = qat.convert_quant_network(net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
    # should load the checkpoint. mock here
    for param in net.get_parameters():
        param.init_data()
    qat.export_geir(net, img, file_name="quant.pb")


@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_mobile():
    net = MobileNetV2()
    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
    net = qat.convert_quant_network(net, quant_delay=0, bn_fold=True, freeze_bn=10000, weight_bits=8, act_bits=8)
    # should load the checkpoint. mock here
    for param in net.get_parameters():
        param.init_data()
    qat.export_geir(net, img, file_name="quant.pb")