Commit 9958bc47
Authored June 19, 2020 by mindspore-ci-bot; committed via Gitee on June 19, 2020

!2161 [qat] Export network from quantization aware network to deploy
Merge pull request !2161 from vlne-v1/I1IZV3-quant-infer

Parents: 06b511ca, 1d77bf86
Showing 18 changed files with 529 additions and 68 deletions (+529 -68)
mindspore/ccsrc/ir/tensor.cc                        +13   -0
mindspore/ccsrc/operator/ops.cc                      +2   -0
mindspore/ccsrc/operator/ops.h                       +2   -0
mindspore/ccsrc/pipeline/init.cc                     +2   -0
mindspore/ccsrc/pipeline/pipeline.cc                +69   -0
mindspore/ccsrc/pipeline/pipeline.h                  +2   -0
mindspore/ccsrc/utils/graph_utils.h                  +4   -0
mindspore/ccsrc/utils/graph_utils_extends.cc        +12   -3
mindspore/common/api.py                              +5   -0
mindspore/nn/layer/normalization.py                  +2   -5
mindspore/nn/layer/quant.py                         +96   -7
mindspore/ops/operations/math_ops.py                 +2   -0
mindspore/ops/operations/nn_ops.py                   +3   -1
mindspore/ops/primitive.py                          +15   -3
mindspore/train/quant/quant.py                     +170  -19
mindspore/train/quant/quant_utils.py                +95  -22
tests/st/model_zoo_tests/yolov3/test_yolov3.py       +1   -1
tests/ut/python/train/quant/test_quant.py           +34   -7
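The commit adds an export path from a quantization-aware-trained network to a deploy model. As a quick orientation before the per-file diffs, the sketch below mirrors the new unit tests added in tests/ut/python/train/quant/test_quant.py; it is a minimal illustration rather than part of the commit, and assumes a backend where the quant ops are available. `LeNet5` stands for any user network built from Conv2dBnAct / DenseBnAct cells, as in the test file.

    import numpy as np
    from mindspore import Tensor
    from mindspore.train.quant import quant as qat

    net = LeNet5()                       # any Conv2dBnAct/DenseBnAct-based network (see the tests)
    net = qat.convert_quant_network(net, quant_delay=0, bn_fold=False, freeze_bn=10000,
                                    weight_bits=8, act_bits=8)
    # ... train the fake-quant network, or load a checkpoint ...
    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
    qat.export_geir(net, img, file_name="quant.pb")   # new API added in this commit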
mindspore/ccsrc/ir/tensor.cc  (+13 -0)

@@ -487,6 +487,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                          }));
   (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
     .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
+    .def(py::pickle(
+      [](const MetaTensor &t) {  // __getstate__
+        /* Return a tuple that fully encodes the state of the object */
+        return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
+      },
+      [](const py::tuple &t) {  // __setstate__
+        if (t.size() != 2) {
+          throw std::runtime_error("Invalid state!");
+        }
+        /* Create a new C++ instance */
+        MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
+        return tensor;
+      }))
     .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
     .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
     .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.");
mindspore/ccsrc/operator/ops.cc  (+2 -0)

@@ -220,6 +220,8 @@ const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
 const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
 const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
 const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
+const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
+const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");

 // Other miscellaneous
 const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
mindspore/ccsrc/operator/ops.h  (+2 -0)

@@ -228,6 +228,8 @@ extern const PrimitivePtr kPrimActivation;
 extern const PrimitivePtr kPrimZerosLike;
 extern const PrimitivePtr kPrimFakeBprop;
 extern const PrimitivePtr kPrimBpropCut;
+extern const PrimitivePtr kPrimFakeQuantPerLayer;
+extern const PrimitivePtr kPrimFakeQuantPerChannel;

 // Other Miscellaneous
 extern const PrimitivePtr kPrimIdentity;
mindspore/ccsrc/pipeline/init.cc  (+2 -0)

@@ -77,6 +77,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Get CNode Strategy Dictionary.")
     .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
          "Get Allreduce Fusion Dictionary.")
+    .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"),
+         "Fetch the inputs of Conv or Matmul for quant export.")
     .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
          py::arg("broadcast_params") = py::dict(), "Build data graph.")
     .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.")
mindspore/ccsrc/pipeline/pipeline.cc  (+69 -0)

@@ -281,6 +281,75 @@ ExecutorPy::~ExecutorPy() {
   ConfigManager::GetInstance().ResetConfig();
 }

+std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
+  const std::string &phase_s) {
+  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
+  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
+  auto filter = [](AnfNodePtr node) {
+    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
+  };
+  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
+  auto is_quant_cnode = [](AnfNodePtr node) {
+    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
+           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
+  };
+  for (auto node : nodes) {
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr || cnode->size() != 3) {
+      continue;
+    }
+    auto x = cnode->input(1);
+    auto weight = cnode->input(2);
+    if (!is_quant_cnode(weight)) {
+      continue;
+    }
+    // get parameter weight's name
+    cnode = weight->cast<CNodePtr>();
+    auto weight_node = cnode->input(2);
+    if (!weight_node->isa<Parameter>()) {
+      continue;
+    }
+    auto weight_name = weight_node->cast<ParameterPtr>()->name();
+    // find the fakequant from input
+    int count = 0;
+    int max_depth = 5;
+    while (!is_quant_cnode(x)) {
+      if (count >= max_depth) {
+        break;
+      }
+      cnode = x->cast<CNodePtr>();
+      if (cnode == nullptr || cnode->size() <= 1) {
+        break;
+      }
+      x = cnode->input(1);
+      count += 1;
+    }
+    // get the fakequant parameter minq's name
+    if (!is_quant_cnode(x)) {
+      continue;
+    }
+    cnode = x->cast<CNodePtr>();
+    if (cnode == nullptr || cnode->size() != 4) {
+      continue;
+    }
+    auto fakequant_min_node = cnode->input(2);
+    if (!fakequant_min_node->isa<Parameter>()) {
+      continue;
+    }
+    auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
+    auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
+    if (!quant_op_value->isa<PrimitivePy>()) {
+      continue;
+    }
+    auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
+    fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
+  }
+  return fake_quant_table;
+}
+
 void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
   // save the graph to ExecutorPy
   FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
mindspore/ccsrc/pipeline/pipeline.h  (+2 -0)

@@ -97,6 +97,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
   void ReleaseResource(const py::object &phase);
   static void ClearRes();
+  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> FetchInfoForQuantExport(
+    const std::string &phase_s);

  private:
   ExecutorPy();
   void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);
mindspore/ccsrc/utils/graph_utils.h  (+4 -0)

@@ -39,6 +39,7 @@ namespace mindspore {
 enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE };

 using IncludeFunc = std::function<IncludeType(const AnfNodePtr &)>;
+using FilterFunc = std::function<bool(const AnfNodePtr &)>;
 using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>;
 using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>;

@@ -58,6 +59,9 @@ std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const Incl
 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
+std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
+                                                        const FilterFunc &filter);
+
 std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
                                  const IncludeFunc &include = AlwaysInclude);
mindspore/ccsrc/utils/graph_utils_extends.cc  (+12 -3)

@@ -37,7 +37,8 @@ namespace mindspore {
 namespace {
 class DeepFirstSearcher : public AnfVisitor {
  public:
-  explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {}
+  explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
+      : include_(include), filter_(filter) {}
   ~DeepFirstSearcher() override = default;

   std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {

@@ -61,8 +62,9 @@ class DeepFirstSearcher : public AnfVisitor {
     if (incl == EXCLUDE) {
       return;
     }
-    res_.push_back(node);
+    if (filter_ == nullptr || !filter_(node)) {
+      res_.push_back(node);
+    }
     if (incl == FOLLOW) {
       AnfVisitor::Visit(node);
     }

@@ -71,6 +73,7 @@ class DeepFirstSearcher : public AnfVisitor {
  private:
   size_t seen_{0};
   IncludeFunc include_;
+  FilterFunc filter_;
   std::vector<AnfNodePtr> res_{};
 };

@@ -160,10 +163,16 @@ class DeepLinkedGraphSearcher : public DeepFirstSearcher {
 };
 }  // namespace

+// include for if expand the node the search, filter for if put the node to results.
 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
   return DeepScopedGraphSearcher(include).Search(root);
 }

+std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
+                                                        const FilterFunc &filter) {
+  return DeepFirstSearcher(include, filter).Search(root);
+}
+
 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
   return DeepUsedGraphSearcher(include).Search(root);
 }
mindspore/common/api.py  (+5 -0)

@@ -526,6 +526,11 @@ class _Executor:
         phase = 'export' + '.' + str(net.create_time)
         export_graph(file_name, file_format, phase)

+    def fetch_info_for_quant_export(self, exec_id):
+        """Get graph proto from pipeline."""
+        if self._executor.has_compiled(exec_id) is False:
+            return None
+        return self._executor.fetch_info_for_quant_export(exec_id)

 _executor = _Executor()
 _pynative_exec = _PynativeExecutor()
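For reference, the new code in mindspore/train/quant/quant.py (further below) drives this API by compiling the network first and then querying the table. A condensed sketch of that call sequence, lifted from the commit's `ExportQuantNetworkDeploy.get_inputs_table` and not intended as standalone user-facing API:

    # Sketch only: mirrors ExportQuantNetworkDeploy.get_inputs_table() in this commit.
    from mindspore.common.api import _executor

    graph_id, _ = _executor.compile(network, *inputs, phase='export_quant', do_convert=False)
    # Maps each quantized weight parameter name to (fake-quant primitive, input `minq` parameter name).
    quant_info_table = _executor.fetch_info_for_quant_export(graph_id)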
mindspore/nn/layer/normalization.py  (+2 -5)

@@ -18,8 +18,6 @@ from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.ops.primitive import constexpr
-from mindspore.common.tensor import Tensor
-import mindspore.common.dtype as mstype
 import mindspore.context as context
 from mindspore._checkparam import check_bool, check_typename
 from mindspore._extends import cell_attr_register

@@ -85,13 +83,12 @@ class _BatchNorm(Cell):
         self.reshape = P.Reshape()
         self.is_ascend = context.get_context("device_target") == "Ascend"
         self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
+        self.momentum = 1.0 - momentum
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
-            self.momentum = Tensor(1.0 - momentum, mstype.float32)
         else:
             self.is_ge_backend = False
-            self.momentum = 1.0 - momentum
         if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps)
mindspore/nn/layer/quant.py  (+96 -7)

@@ -729,8 +729,8 @@ class DenseQuant(Cell):
         self.has_bias = check_bool(has_bias)

         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
-                    weight_init.shape()[1] != in_channels:
+            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+                    weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")

         self.weight = Parameter(initializer(

@@ -738,7 +738,7 @@ class DenseQuant(Cell):
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(

@@ -780,8 +780,14 @@ class DenseQuant(Cell):
         return str_info

+class _QuantActivation(Cell):
+    r"""
+    Base class for Quant activation function. Add Fake Quant OP after activation OP.
+    """
+
+    def get_origin(self):
+        raise NotImplementedError

-class ReLUQuant(Cell):
+class ReLUQuant(_QuantActivation):
     r"""
     ReLUQuant activation function. Add Fake Quant OP after Relu OP.

@@ -828,8 +834,11 @@ class ReLUQuant(Cell):
         x = self.fake_quant_act(x)
         return x

+    def get_origin(self):
+        return self.relu

-class ReLU6Quant(Cell):
+class ReLU6Quant(_QuantActivation):
     r"""
     ReLU6Quant activation function.

@@ -878,8 +887,10 @@ class ReLU6Quant(Cell):
         x = self.fake_quant_act(x)
         return x

+    def get_origin(self):
+        return self.relu6

-class HSwishQuant(Cell):
+class HSwishQuant(_QuantActivation):
     r"""
     HSwishQuant activation function. Add Fake Quant OP after HSwish OP.

@@ -935,8 +946,10 @@ class HSwishQuant(Cell):
         x = self.fake_quant_act_after(x)
         return x

+    def get_origin(self):
+        return self.act

-class HSigmoidQuant(Cell):
+class HSigmoidQuant(_QuantActivation):
     r"""
     HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.

@@ -991,6 +1004,8 @@ class HSigmoidQuant(Cell):
         x = self.fake_quant_act_after(x)
         return x

+    def get_origin(self):
+        return self.act

 class TensorAddQuant(Cell):
     r"""

@@ -1083,3 +1098,77 @@ class MulQuant(Cell):
         x = self.mul(x1, x2)
         x = self.fake_quant_act(x)
         return x
+
+
+class QuantBlock(Cell):
+    r"""
+    A quant block of Conv/Dense, activation layer for Ascend deploy.
+
+    Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant.
+
+    Notes:
+        This block is only for deploy, and not trainable.
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+        activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
+        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
+        activation (string): Specifies activation type. The optional values are as following:
+            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
+            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
+
+    Outputs:
+        Tensor of shape :math:`(N, out\_channels)`.
+
+    Examples:
+        >>> net = nn.Dense(3, 4)
+        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
+        >>> net(input)
+    """
+
+    def __init__(self,
+                 core_op,
+                 weight,
+                 quant_op,
+                 dequant_op,
+                 dequant_scale,
+                 bias=None,
+                 activation=None):
+        super(QuantBlock, self).__init__()
+        self.core_op = core_op
+        self.weight = weight
+        self.quant = quant_op
+        self.dequant = dequant_op
+        self.dequant_scale = dequant_scale
+        self.bias = bias
+        self.has_bias = bias is None
+        self.activation = activation
+        self.has_act = activation is None
+
+    def construct(self, x):
+        x = self.quant(x)
+        x = self.core_op(x, self.weight)
+        if self.has_bias:
+            output = self.bias_add(output, self.bias)
+        if self.has_act:
+            x = self.activation(x)
+        x = self.dequant(x, self.dequant_scale)
+        return x
+
+    def extend_repr(self):
+        str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
+        if self.has_bias:
+            str_info = str_info + f', bias={self.bias}'
+        if self.has_act:
+            str_info = str_info + f', activation={self.activation}'
+        str_info = str_info + f', dequant={self.dequant}'
+        return str_info
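QuantBlock's construct follows the usual quantize, integer core op, dequantize dataflow. The numpy snippet below is only an illustrative model of that arithmetic with made-up per-layer scales and zero point 0; it is not code from this commit and does not reproduce the exact AscendQuant/AscendDequant semantics.

    import numpy as np

    np.random.seed(0)
    scale_in = scale_w = 0.05                     # hypothetical per-layer scales, zero point 0
    dequant_scale = scale_in * scale_w

    x = np.random.randn(2, 4).astype(np.float32)
    w = np.random.randn(3, 4).astype(np.float32)

    w_int8 = np.clip(np.round(w / scale_w), -128, 127).astype(np.int8)    # cf. weight2int
    x_int8 = np.clip(np.round(x / scale_in), -128, 127).astype(np.int8)   # the "quant" step
    acc_int32 = x_int8.astype(np.int32) @ w_int8.astype(np.int32).T       # integer Dense core op
    y = acc_int32 * dequant_scale                                         # the "dequant" step

    print(np.abs(y - x @ w.T).max())   # small: the int8 path approximates the float result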
mindspore/ops/operations/math_ops.py  (+2 -0)

@@ -584,6 +584,8 @@ class MatMul(PrimitiveWithInfer):
     def infer_dtype(self, x, y):
         args = {"x": x, "y": y}
         validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
+        if x.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
         return x
mindspore/ops/operations/nn_ops.py  (+3 -1)

@@ -800,7 +800,7 @@ class Conv2D(PrimitiveWithInfer):
     def infer_shape(self, x_shape, w_shape):
         validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
         validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
-        validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
+        validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
         validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
         validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)

@@ -846,6 +846,8 @@ class Conv2D(PrimitiveWithInfer):
         args = {'x': x_dtype, 'w': w_dtype}
         valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
         validator.check_tensor_type_same(args, valid_types, self.name)
+        if x_dtype.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
         return x_dtype
mindspore/ops/primitive.py  (+15 -3)

@@ -43,11 +43,12 @@ class Primitive(Primitive_):
     >>> # init a Primitive obj with attr1=1 and attr2=2
     >>> add = Add(attr1=1, attr2=2)
     """
+    _repr_ignore_list = ['input_names', 'output_names']

     def __init__(self, name):
         self.name = name
         self.attrs = {}
-        self.init_attrs = {}
+        self.init_attrs = {"name": name}
         Primitive_.__init__(self, name, self)
         if hasattr(self.__class__, '__mindspore_signature__'):
             sig = self._fill_signature(self.__class__.__mindspore_signature__)

@@ -165,6 +166,16 @@ class Primitive(Primitive_):
     def __setstate__(self, d):
         self.__dict__.update(d)

+    def __deepcopy__(self, memo):
+        return type(self)(**self.init_attrs)
+
+    def __repr__(self):
+        attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if not k in Primitive._repr_ignore_list])
+        info_str = f'Prim[{self.name}]'
+        if attr:
+            info_str += f'<{attr}>'
+        return info_str
+
     def init_prim_io_names(self, inputs, outputs):
         """
         Initializes inputs and outpus name of Tensor or attributes.

@@ -185,8 +196,8 @@ class PrimitiveWithInfer(Primitive):
     There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(),
     infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
-    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describle shape
-    and type infer logic. The infer_value() is used for constant propogation.
+    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe shape
+    and type infer logic. The infer_value() is used for constant propagation.

     Args:
         name (str): Name for current Primitive.

@@ -288,6 +299,7 @@ def prim_attr_register(fn):
         bound_args.apply_defaults()
         arguments = bound_args.arguments
         del arguments['self']
+        del self.init_attrs['name']
         for name in arguments:
             value = arguments[name]
             self.add_prim_attr(name, value)
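The __repr__ added above prints a primitive as Prim[name] followed by its registered attributes, skipping the names in _repr_ignore_list. A rough, hypothetical illustration of the resulting format (exact attribute ordering may differ):

    >>> from mindspore.ops import operations as P
    >>> P.MatMul(transpose_a=False, transpose_b=True)
    Prim[MatMul]<transpose_a=False, transpose_b=True>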
mindspore/train/quant/quant.py  (+170 -19)

@@ -14,12 +14,23 @@
 # ============================================================================
 """aware quantization."""

+import copy
 import re
-from ... import nn
-from ... import ops
+
+import numpy as np
+
+from ... import log as logger
+from ... import nn, ops
 from ..._checkparam import ParamValidator as validator
 from ..._checkparam import Rel
+from ...common import Tensor
+from ...common import dtype as mstype
+from ...common.api import _executor
 from ...nn.layer import quant
+from ...ops import functional as F
+from ...ops.operations import _inner_ops as inner
+from ...train import serialization
+from . import quant_utils

 _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                    nn.ReLU6: quant.ReLU6Quant,

@@ -27,25 +38,21 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                    nn.HSwish: quant.HSwishQuant}


-class _AddFakeQuantInputOutput(nn.Cell):
+class _AddFakeQuantInput(nn.Cell):
     """
     Add FakeQuant at input and output of the Network. Only support one input and one output case.
     """

     def __init__(self, network, quant_delay=0):
-        super(_AddFakeQuantInputOutput, self).__init__(auto_prefix=False)
+        super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
         self.network = network
         self.fake_quant_input = quant.FakeQuantWithMinMax(
             min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
         self.fake_quant_input.update_parameters_name('fake_quant_input')
-        self.fake_quant_output = quant.FakeQuantWithMinMax(
-            min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
-        self.fake_quant_output.update_parameters_name('fake_quant_output')

     def construct(self, data):
         data = self.fake_quant_input(data)
         output = self.network(data)
-        output = self.fake_quant_output(output)
         return output

@@ -99,6 +106,8 @@ class ConvertToQuantNetwork:
         self.per_channel = validator.check_bool("per channel", per_channel)
         self.symmetric = validator.check_bool("symmetric", symmetric)
         self.narrow_range = validator.check_bool("narrow range", narrow_range)
+        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
+                                    quant.DenseBnAct: self._convert_dense}

     def _convert_op_name(self, name):
         pattern = re.compile(r'([A-Z]{1})')

@@ -110,6 +119,7 @@ class ConvertToQuantNetwork:
     def run(self):
         self.network.update_cell_prefix()
         network = self._convert_subcells2quant(self.network)
+        network = _AddFakeQuantInput(network)
         return network

     def _convert_subcells2quant(self, network):

@@ -122,15 +132,9 @@ class ConvertToQuantNetwork:
             subcell = cells[name]
             if subcell == network:
                 continue
-            elif isinstance(subcell, quant.Conv2dBnAct):
+            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
                 prefix = subcell.param_prefix
-                new_subcell = self._convert_conv(subcell)
-                new_subcell.update_parameters_name(prefix + '.')
-                network.insert_child_to_cell(name, new_subcell)
-                change = True
-            elif isinstance(subcell, quant.DenseBnAct):
-                prefix = subcell.param_prefix
-                new_subcell = self._convert_dense(subcell)
+                new_subcell = self._convert_method_map[type(subcell)](subcell)
                 new_subcell.update_parameters_name(prefix + '.')
                 network.insert_child_to_cell(name, new_subcell)
                 change = True

@@ -199,10 +203,12 @@ class ConvertToQuantNetwork:
                                              symmetric=self.symmetric,
                                              narrow_range=self.narrow_range)
         subcell.conv = conv_inner
-        if subcell.activation is not None:
+        if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
         else:
-            subcell = _AddFakeQuantAfterSubCell(subcell)
+            subcell.has_act = True
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
+                                                           quant_delay=self.quant_delay)
         return subcell

     def _convert_dense(self, subcell):

@@ -217,8 +223,12 @@ class ConvertToQuantNetwork:
                                          per_channel=self.per_channel,
                                          num_bits=self.weight_bits)
         subcell.dense = dense_inner
-        if subcell.activation is not None:
+        if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
+        else:
+            subcell.has_act = True
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
+                                                           quant_delay=self.quant_delay)
         return subcell

     def _convert_activation(self, activation):

@@ -229,6 +239,147 @@ class ConvertToQuantNetwork:
         return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay)


+class ExportQuantNetworkDeploy:
+    """
+    Convert quantization aware network to deploy network.
+
+    Args:
+        network (Cell): MindSpore network produced by `convert_quant_network`.
+        inputs (Tensor): Inputs of the `network`.
+
+    Returns:
+        Cell, converted network.
+    """
+    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
+
+    def __init__(self, network, *inputs):
+        network = validator.check_isinstance('network', network, (nn.Cell,))
+        self.data_type = mstype.int8
+        self.network = copy.deepcopy(network)
+        self.all_paramters = {p.name: p for p in self.network.get_parameters()}
+        self.get_inputs_table(inputs)
+
+    def get_inputs_table(self, inputs):
+        """Get the support info for quant export."""
+        phase_name = 'export_quant'
+        graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
+        self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)
+
+    def run(self):
+        """Start to convert."""
+        self.network.update_cell_prefix()
+        network = self.network
+        if isinstance(network, _AddFakeQuantInput):
+            network = network.network
+        network = self._convert_quant2deploy(network)
+        return network
+
+    def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
+        """convet network's quant subcell to deploy subcell"""
+        # Calculate the scale and zero point
+        w_minq_name = cell_core.fake_quant_weight.minq.name
+        np_type = mstype.dtype_to_nptype(self.data_type)
+        scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
+        scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
+        info = self.quant_info_table.get(w_minq_name, None)
+        if info:
+            fack_quant_a_in_op, minq_name = info
+            maxq = self.all_paramters[minq_name[:-4] + "maxq"]
+            minq = self.all_paramters[minq_name]
+            scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
+        else:
+            logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
+            return None
+
+        # Build the `Quant` `Dequant` op.
+        # AscendQuant only support perlayer version. Need check here.
+        quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
+        sqrt_mode = False
+        scale_deq = scale_a_out * scale_w
+        if scale_deq < 2 ** -14:
+            scale_deq = np.sqrt(scale_deq)
+            sqrt_mode = True
+        dequant_op = inner.AscendDequant(sqrt_mode)
+
+        # get op
+        op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv
+        if isinstance(activation, _AddFakeQuantAfterSubCell):
+            activation = activation.subcell
+        elif hasattr(activation, "get_origin"):
+            activation = activation.get_origin()
+
+        # get the `weight` and `bias`
+        weight = cell_core.weight.data.asnumpy()
+        bias = None
+        if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
+            if cell_core.has_bias:
+                bias = cell_core.bias.data.asnumpy()
+        elif isinstance(cell_core, quant.Conv2dBatchNormQuant):
+            weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
+
+        # apply the quant
+        weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type)
+        if bias is not None:
+            bias = Tensor(scale_a_in * scale_w * bias, mstype.int32)
+        scale_deq = Tensor(scale_deq, mstype.float16)
+        block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
+        return block
+
+    def _convert_quant2deploy(self, network):
+        """Convet network's all quant subcell to deploy subcell."""
+        cells = network.name_cells()
+        change = False
+        for name in cells:
+            subcell = cells[name]
+            if subcell == network:
+                continue
+            cell_core = None
+            fake_quant_act = None
+            activation = None
+            if isinstance(subcell, quant.Conv2dBnAct):
+                cell_core = subcell.conv
+                activation = subcell.activation
+                fake_quant_act = activation.fake_quant_act
+            elif isinstance(subcell, quant.DenseBnAct):
+                cell_core = subcell.dense
+                activation = subcell.activation
+                fake_quant_act = activation.fake_quant_act
+            if cell_core is not None:
+                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
+                if new_subcell:
+                    prefix = subcell.param_prefix
+                    new_subcell.update_parameters_name(prefix + '.')
+                    network.insert_child_to_cell(name, new_subcell)
+                    change = True
+            elif isinstance(subcell, _AddFakeQuantAfterSubCell):
+                op = subcell.subcell
+                if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive):
+                    network.__delattr__(name)
+                    network.__setattr__(name, op)
+                    change = True
+            else:
+                self._convert_quant2deploy(subcell)
+        if isinstance(network, nn.SequentialCell) and change:
+            network.cell_list = list(network.cells())
+        return network
+
+
+def export_geir(network, *inputs, file_name):
+    """
+    Exports MindSpore quant predict model to deploy with GEIR.
+
+    Args:
+        network (Cell): MindSpore network produced by `convert_quant_network`.
+        inputs (Tensor): Inputs of the `network`.
+        file_name (str): File name of model to export.
+    """
+    exporter = ExportQuantNetworkDeploy(network, *inputs)
+    deploy_net = exporter.run()
+    serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR")
+
+
 def convert_quant_network(network,
                           quant_delay=0,
                           bn_fold=False,
mindspore/train/quant/quant_utils.py  (+95 -22)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""quantization utils."""
+"""Quantization utils."""

 import numpy as np

@@ -24,22 +24,19 @@ def cal_quantization_params(input_min,
                             symmetric=False,
                             narrow_range=False):
     r"""
-    calculate quantization params for scale and zero point.
+    Calculate quantization params for scale and zero point.

     Args:
-        input_min (int, list): The dimension of channel or 1.
-        input_max (int, list): The dimension of channel or 1.
+        input_min (numpy.ndarray): The dimension of channel or 1.
+        input_max (numpy.ndarray): The dimension of channel or 1.
         data_type (numpy type) : Can ben numpy int8, numpy uint8.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.

-    Outputs:
-        scale (int, list): quantization param.
-        zero point (int, list): quantization param.
-
-    Examples:
-        >>> scale, zp = cal_quantization_params([1, 2, 1], [-2, 0, -1], 8, False, False)
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
     """
     input_max = np.maximum(0.0, input_max)
     input_min = np.minimum(0.0, input_min)

@@ -92,27 +89,103 @@ def weight2int(data,
                scale,
                zero_point):
     r"""
-    calculate int8/uint8 weight from fp32. the formula is defined as:
+    Calculate int8/uint8 weight from fp32. the formula is defined as:

     .. math::
         int8/uint8 = round(float/scale) + offset

     Args:
-        data (int, list): The dimension of channel or 1. Should be NCHW.
-        scale (int, list): The dimension of channel or 1.
-        zero_point (int, list): The dimension of channel or 1.
+        data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
+        scale (numpy.ndarray): The dimension of channel or 1.
+        zero_point (numpy.ndarray): The dimension of channel or 1.

-    Outputs:
-        weight (int, list): The dimension of channel or 1.
-
-    Examples:
-        >>> weight = weight2int([1, 2, 1], 1, 0)
+    Returns:
+        weight (numpy.ndarray): The dimension of channel or 1.
     """
     if scale.shape != zero_point.shape:
         raise ValueError("scale and zero_point should have the same shape.")
     if scale.shape[0] > 0:
-        scale = scale.reshape(1, -1, 1, 1)
-        zero_point = zero_point.reshape(1, -1, 1, 1)
+        scale = scale.reshape(1, -1)
+        zero_point = zero_point.reshape(1, -1)
     return np.round((data / scale) + zero_point)


+def scale_zp_from_fack_quant_cell(cell, data_type):
+    r"""
+    Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.
+
+    Args:
+        cell (Cell): `mindspore.nn.layer.FakeQuantWithMinMax`
+        data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
+
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
+    """
+    minq = cell.minq.data.asnumpy()
+    maxq = cell.maxq.data.asnumpy()
+    op = cell.fake_quant
+
+    scale, zp = cal_quantization_params(
+        minq, maxq,
+        data_type,
+        num_bits=op.num_bits,
+        symmetric=op.symmetric,
+        narrow_range=op.narrow_range)
+    return scale, zp
+
+
+def scale_zp_from_data(op, minq, maxq, data_type):
+    r"""
+    Get calculate quantization params for scale and zero point.
+
+    Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
+
+    Args:
+        op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
+            `mindspore.ops.operation.FakeQuantPerChannel`
+        minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
+        maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
+        data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
+
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
+    """
+    minq = minq.data.asnumpy()
+    maxq = maxq.data.asnumpy()
+
+    scale, zp = cal_quantization_params(
+        minq, maxq,
+        data_type,
+        num_bits=op.num_bits,
+        symmetric=op.symmetric,
+        narrow_range=op.narrow_range)
+    return scale, zp
+
+
+def fold_batchnorm(weight, cell_quant):
+    r"""
+    Fold the batchnorm in `Conv2dBatchNormQuant` to weight.
+
+    Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
+
+    Args:
+        weight (numpy.ndarray): Weight of `cell_quant`.
+        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`.
+
+    Returns:
+        weight (numpy.ndarray): Folded weight.
+        bias (numpy.ndarray): Folded bias.
+    """
+    variance = cell_quant.moving_variance.data.asnumpy()
+    mean = cell_quant.moving_mean.data.asnumpy()
+    gamma = cell_quant.gamma.data.asnumpy()
+    beta = cell_quant.beta.data.asnumpy()
+    epsilon = cell_quant.eps
+    sigma = np.sqrt(variance + epsilon)
+
+    gamma = gamma.reshape(-1, 1, 1, 1)
+    sigma = sigma.reshape(-1, 1, 1, 1)
+    mean = mean.reshape(-1, 1, 1, 1)
+
+    weight = weight * gamma / sigma
+    bias = beta - gamma * mean / sigma
+    return weight, bias
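fold_batchnorm applies the standard batch-norm folding identity: the folded weight is w * gamma / sqrt(var + eps) per output channel, and the folded bias is beta - gamma * mean / sqrt(var + eps). The numpy check below is a small illustration of that identity only (all shapes and values are made up, and it works on a single per-channel output value rather than running a convolution):

    import numpy as np

    np.random.seed(0)
    c_out = 4
    gamma = np.random.randn(c_out).astype(np.float32)
    beta = np.random.randn(c_out).astype(np.float32)
    mean = np.random.randn(c_out).astype(np.float32)
    var = (np.random.rand(c_out) + 0.1).astype(np.float32)
    eps = 1e-5

    sigma = np.sqrt(var + eps)
    b_folded = beta - gamma * mean / sigma          # the bias fold_batchnorm would produce

    y = np.random.randn(c_out).astype(np.float32)   # one conv-output value per channel, pre-BN
    bn_y = gamma * (y - mean) / sigma + beta        # what BatchNorm would compute
    folded_y = (y * gamma / sigma) + b_folded       # scaled output plus folded bias
    print(np.allclose(bn_y, folded_y))              # True: folding preserves the result

Scaling each output channel by gamma / sigma is the same as scaling that channel's filter by the factor, which is what the `weight * gamma / sigma` line above captures.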
tests/st/model_zoo_tests/yolov3/test_yolov3.py  (+1 -1)

@@ -55,7 +55,7 @@ def init_net_param(network, init_value='ones'):
     params = network.trainable_params()
     for p in params:
         if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
-            p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
+            p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype))


 class ModelCallback(Callback):
     def __init__(self):
tests/ut/python/train/quant/test_quant.py  (+34 -7)

@@ -13,9 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """ tests for quant """
+import numpy as np
+import pytest
+
 import mindspore.context as context
+from mindspore import Tensor
 from mindspore import nn
+from mindspore.train.quant import quant as qat
+from mobilenetv2_combined import MobileNetV2

 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

@@ -37,23 +42,45 @@ class LeNet5(nn.Cell):
     def __init__(self, num_class=10):
         super(LeNet5, self).__init__()
         self.num_class = num_class

-        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6')
-        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu')
+        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
+        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
         self.fc3 = nn.DenseBnAct(84, self.num_class)

         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.flattern = nn.Flatten()
+        self.flatten = nn.Flatten()

     def construct(self, x):
         x = self.conv1(x)
-        x = self.bn(x)
-        x = self.relu(x)
         x = self.max_pool2d(x)
         x = self.conv2(x)
         x = self.max_pool2d(x)
-        x = self.flattern(x)
+        x = self.flatten(x)
         x = self.fc1(x)
         x = self.fc2(x)
         x = self.fc3(x)
         return x
+
+
+@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
+def test_qat_lenet():
+    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
+    net = LeNet5()
+    net = qat.convert_quant_network(
+        net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
+    # should load the checkpoint. mock here
+    for param in net.get_parameters():
+        param.init_data()
+    qat.export_geir(net, img, file_name="quant.pb")
+
+
+@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
+def test_qat_mobile():
+    net = MobileNetV2()
+    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
+    net = qat.convert_quant_network(
+        net, quant_delay=0, bn_fold=True, freeze_bn=10000, weight_bits=8, act_bits=8)
+    # should load the checkpoint. mock here
+    for param in net.get_parameters():
+        param.init_data()
+    qat.export_geir(net, img, file_name="quant.pb")