magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 28d1d370
Authored on Jul 16, 2020 by Wei Luning
End at validate when export.
Parent: 536f7533
Showing 10 changed files with 90 additions and 88 deletions.
mindspore/ccsrc/pipeline/jit/pipeline.cc       +37 -48
mindspore/ccsrc/pipeline/jit/pipeline.h         +0 -1
mindspore/ccsrc/transform/graph_ir/convert.cc   +2 -2
mindspore/nn/layer/quant.py                     +1 -1
mindspore/ops/operations/_inner_ops.py          +4 -4
mindspore/train/quant/quant.py                 +13 -10
mindspore/train/quant/quant_utils.py            +1 -1
mindspore/train/serialization.py                +4 -3
tests/ut/python/ops/test_ops.py                +16 -16
tests/ut/python/train/quant/test_quant.py      +12 -2
mindspore/ccsrc/pipeline/jit/pipeline.cc

@@ -383,16 +383,6 @@ void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
   MS_LOG(INFO) << "End save compiled func graph!";
 }
 
-bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const {
-  std::string phase_prefix = GetPhasePrefix(phase_s);
-  if (use_vm && phase_prefix == "export") {
-    MS_LOG(INFO) << "Use ge backend to export geir";
-    use_vm = false;
-  }
-  return use_vm;
-}
-
 void ExecutorPy::GetGeBackendPolicy() const {
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
@@ -402,6 +392,40 @@ void ExecutorPy::GetGeBackendPolicy() const {
   }
 }
 
+bool IsPhaseExportGeir(const std::string &phase_s) {
+  auto phase_to_export = "export.geir";
+  return phase_s.rfind(phase_to_export, 0) != std::string::npos;
+}
+
+std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {
+  bool is_geir = IsPhaseExportGeir(phase_s);
+  std::string backend = MsContext::GetInstance()->backend_policy();
+#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
+  if (mindspore::parallel::ps::Util::IsParamServerMode()) {
+    mindspore::parallel::ps::Util::SetInternalEnvVar();
+  }
+  if (parallel::ps::Util::IsRoleOfPServer()) {
+    resource->results()[kBackend] = compile::CreateBackend();
+    return PServerPipeline();
+  }
+  if (parallel::ps::Util::IsRoleOfScheduler()) {
+    return PSchedulerPipeline();
+  }
+#endif
+  if (use_vm && backend != "ge" && !is_geir) {
+    // Create backend and session
+    auto backend_ptr = compile::CreateBackend();
+    // Connect session to debugger
+    backend_ptr->SetDebugger();
+    resource->results()[kBackend] = backend_ptr;
+    return VmPipeline();
+  }
+  return GePipeline();
+}
+
 bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) {
   MS_LOG(DEBUG) << "Start ExecutorPy compile!";
   if ((!py::isinstance<py::str>(phase))) {
@@ -420,43 +444,8 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
   std::string phase_s = py::cast<std::string>(phase);
   MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!";
   ResourcePtr resource = std::make_shared<Resource>(obj);
-  std::vector<ActionItem> p_actions;
-
-  use_vm = ChangeExportGeirUseVmFlag(use_vm, phase_s);
-  std::string backend = MsContext::GetInstance()->backend_policy();
-#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
-  if (mindspore::parallel::ps::Util::IsParamServerMode()) {
-    mindspore::parallel::ps::Util::SetInternalEnvVar();
-  }
-  if (parallel::ps::Util::IsRoleOfPServer()) {
-    resource->results()[kBackend] = compile::CreateBackend();
-    p_actions = PServerPipeline();
-  } else if (parallel::ps::Util::IsRoleOfScheduler()) {
-    p_actions = PSchedulerPipeline();
-  } else if (use_vm && backend != "ge") {
-    // Create backend and session
-    auto backend_ptr = compile::CreateBackend();
-    // Connect session to debugger
-    backend_ptr->SetDebugger();
-    resource->results()[kBackend] = backend_ptr;
-    p_actions = VmPipeline();
-  } else {
-    p_actions = GePipeline();
-  }
-#else
-  if (use_vm && backend != "ge") {
-    // Create backend and session
-    auto backend_ptr = compile::CreateBackend();
-    // Connect session to debugger
-    backend_ptr->SetDebugger();
-    resource->results()[kBackend] = backend_ptr;
-    p_actions = VmPipeline();
-  } else {
-    p_actions = GePipeline();
-  }
-#endif
+  auto p_actions = GetPipline(resource, phase_s, use_vm);
   std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, FilterActions(p_actions, phase_s));
 
   // get the parameters items and add the value to args_spec
@@ -490,8 +479,8 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
 }
 
 std::vector<ActionItem> ExecutorPy::FilterActions(const std::vector<ActionItem> &actions, const std::string &phase) {
-  // phase does not contain 'export_onnx'
-  if (GetPhasePrefix(phase).find("export_onnx") == std::string::npos) {
+  // filter action after validate when 'export'.
+  if (GetPhasePrefix(phase).rfind("export", 0) == std::string::npos) {
     return actions;
   }
   MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'";
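In short, the commit replaces the old `ChangeExportGeirUseVmFlag` special case with a uniform rule: any compile phase whose prefix starts with `export` runs the pipeline only up to the 'validate' action (hence the commit title). The sketch below is a minimal Python rendering of that rule, not the C++ implementation; the action names and the dot-separated phase-prefix convention are illustrative assumptions.

```python
# Illustrative sketch of the new FilterActions rule; action names are examples,
# not MindSpore's real pipeline, and the phase-prefix split is an assumption.

def filter_actions(actions, phase):
    """Keep actions up to and including 'validate' when the phase is an export phase."""
    prefix = phase.split(".")[0]           # e.g. 'export.geir' -> 'export' (assumed)
    if not prefix.startswith("export"):    # mirrors rfind("export", 0) == npos
        return actions
    names = [name for name, _ in actions]
    end = names.index("validate") + 1      # assumes 'validate' is always present
    return actions[:end]

pipeline = [("parse", ...), ("abstract_specialize", ...), ("validate", ...),
            ("task_emit", ...), ("execute", ...)]
assert [n for n, _ in filter_actions(pipeline, "export.geir")] == \
       ["parse", "abstract_specialize", "validate"]
assert len(filter_actions(pipeline, "train")) == len(pipeline)
```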
mindspore/ccsrc/pipeline/jit/pipeline.h

@@ -101,7 +101,6 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
  private:
   ExecutorPy();
   void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);
-  bool ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const;
   void GetGeBackendPolicy() const;
   // filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after
   // 'validate' stage
mindspore/ccsrc/transform/graph_ir/convert.cc

@@ -205,8 +205,8 @@ const char kNameL2Loss[] = "L2Loss";
 const char kNameCTCLoss[] = "CTCLoss";
 const char kNameRange[] = "Range";
 const char kNameSquareSumAll[] = "SquareSumAll";
-const char kNameAscendQuant[] = "AscendQuant";
-const char kNameAscendDequant[] = "AscendDequant";
+const char kNameAscendQuant[] = "Quant";
+const char kNameAscendDequant[] = "Dequant";
 const char kNameCase[] = "Case";
 
 // -----------------OpAdapter initialization--------------
mindspore/nn/layer/quant.py

@@ -1107,7 +1107,7 @@ class QuantBlock(Cell):
     r"""
     A quant block of Conv/Dense, activation layer for Ascend deploy.
 
-    Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant.
+    Calculate Conv or Dense in Int8, with Quant and DeQuant.
 
     Notes:
         This block is only for deploy, and not trainable.
mindspore/ops/operations/_inner_ops.py

@@ -160,7 +160,7 @@ class Range(PrimitiveWithInfer):
         return x_dtype
 
 
-class AscendQuant(PrimitiveWithInfer):
+class Quant(PrimitiveWithInfer):
     r"""
     Returns the quantized value of input_x.
@@ -192,7 +192,7 @@ class AscendQuant(PrimitiveWithInfer):
     Examples:
         >>> input_x = Tensor([100.0, 150.0], mstype.float32)
-        >>> quant = P.AscendQuant(80.0, 0.0, False, "Round")
+        >>> quant = P.Quant(80.0, 0.0, False, "Round")
         >>> y = quant(input_x)
     """
@@ -213,7 +213,7 @@ class AscendQuant(PrimitiveWithInfer):
         return mstype.int8
 
 
-class AscendDequant(PrimitiveWithInfer):
+class Dequant(PrimitiveWithInfer):
     r"""
     Returns the dequantized value of input_x.
 
     This operation will do ReLU to the dequantized value if `relu_flag` is True.
@@ -245,7 +245,7 @@ class AscendDequant(PrimitiveWithInfer):
     Examples:
         >>> input_x = Tensor([100.0, 150.0], mstype.float32)
-        >>> dequant = P.AscendDequant(False, False)
+        >>> dequant = P.Dequant(False, False)
         >>> y = dequant(input_x)
     """
 
     @prim_attr_register
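The doctests above only fix the call signature. As a rough mental model of what the renamed primitives compute, here is a hypothetical numpy reference assuming the usual affine int8 mapping; the formula and the `sqrt_mode` interpretation are our assumptions, not taken from this diff.

```python
import numpy as np

def quant_ref(x, scale, offset, sqrt_mode=False, round_mode="Round"):
    """Hypothetical reference for inner.Quant: y = round(x * scale + offset),
    saturated to int8. Assumed: sqrt_mode means the scale is applied twice."""
    if sqrt_mode:
        scale = scale * scale
    rounders = {"Round": np.round, "Floor": np.floor,
                "Ceil": np.ceil, "Trunc": np.trunc}
    y = rounders[round_mode](x * scale + offset)
    return np.clip(y, -128, 127).astype(np.int8)

# Mirrors the doctest P.Quant(80.0, 0.0, False, "Round") on [100.0, 150.0]:
print(quant_ref(np.array([100.0, 150.0]), 80.0, 0.0))  # both saturate to 127
```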
mindspore/train/quant/quant.py

@@ -329,14 +329,14 @@ class ExportToQuantInferNetwork:
             return None
 
         # Build the `Quant` `Dequant` op.
-        # AscendQuant only support perlayer version. Need check here.
-        quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
+        # Quant only support perlayer version. Need check here.
+        quant_op = inner.Quant(float(scale_a_in), float(zp_a_in))
         sqrt_mode = False
         scale_deq = scale_a_out * scale_w
         if (scale_deq < 2 ** -14).all():
             scale_deq = np.sqrt(scale_deq)
             sqrt_mode = True
-        dequant_op = inner.AscendDequant(sqrt_mode)
+        dequant_op = inner.Dequant(sqrt_mode)
 
         # get op
         op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv
@@ -411,11 +411,15 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
         file_name (str): File name of model to export.
         mean (int): Input data mean. Default: 127.5.
         std_dev (int, float): Input data variance. Default: 127.5.
-        file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.
-            - GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model.
+        file_format (str): MindSpore currently supports 'GEIR', 'ONNX' and 'BINARY' format for exported
+            quantization aware model. Default: 'GEIR'.
+
+            - GEIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
+              Ascend model.
+            - BINARY: Binary format for model. An intermidiate representation format for models.
     """
     supported_device = ["Ascend"]
-    supported_formats = ['GEIR']
+    supported_formats = ['GEIR', 'BINARY']
 
     mean = validator.check_type("mean", mean, (int, float))
     std_dev = validator.check_type("std_dev", std_dev, (int, float))
@@ -428,10 +432,9 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
     network.set_train(False)
-    if file_format == 'GEIR':
-        exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
-        deploy_net = exporter.run()
-        serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
+    exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
+    deploy_net = exporter.run()
+    serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
 
 
 def convert_quant_network(network,
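A note on the `sqrt_mode` branch above: when every element of `scale_deq = scale_a_out * scale_w` falls below 2**-14 (about the smallest normal fp16 magnitude; our reading, the diff does not state the rationale), the code stores `sqrt(scale_deq)` and tells `Dequant` to apply it twice. A quick numpy check of that round trip, with toy scale values:

```python
import numpy as np

scale_a_out = np.float32(1e-3)
scale_w = np.full(4, 5e-3, dtype=np.float32)   # toy per-channel weight scales
scale_deq = scale_a_out * scale_w               # 5e-6, below 2**-14 everywhere

sqrt_mode = False
if (scale_deq < 2 ** -14).all():                # same test as the diff
    scale_deq = np.sqrt(scale_deq)              # ~2.2e-3, comfortably representable
    sqrt_mode = True

# With sqrt_mode, applying the stored scale twice recovers the original factor.
recovered = scale_deq * scale_deq if sqrt_mode else scale_deq
assert np.allclose(recovered, scale_a_out * scale_w)
```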
mindspore/train/quant/quant_utils.py

@@ -104,7 +104,7 @@ def weight2int(data, scale, zero_point):
         raise ValueError("`scale` and `zero_point` should have the same shape.")
     if scale.shape[0] < 0:
         raise ValueError("`scale` and `zero_point` shape should greater than zero.")
-    if len(scale.shape) > 1:
+    if len(scale.shape) >= 1 and scale.shape[0] > 1:
         # for perchannel
         if scale.shape[0] == data.shape[0]:
             # `Conv2d` or `Dense` op weight
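The one-line change to `weight2int` is a real bug fix: per-channel scales are typically 1-D arrays, so the old `len(scale.shape) > 1` test routed them to the per-layer path. A self-contained illustration of the two predicates (the rest of `weight2int` is elided):

```python
import numpy as np

def is_per_channel_old(scale):
    return len(scale.shape) > 1                      # old check from the diff

def is_per_channel_new(scale):
    return len(scale.shape) >= 1 and scale.shape[0] > 1  # new check from the diff

per_layer = np.array([0.02])                 # one scale for the whole tensor
per_channel = np.array([0.02, 0.01, 0.04])   # one scale per output channel

assert not is_per_channel_old(per_channel)   # old check: 1-D wrongly per-layer
assert is_per_channel_new(per_channel)       # new check: correctly per-channel
assert not is_per_channel_new(per_layer)     # scalar-like scale stays per-layer
```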
mindspore/train/serialization.py

@@ -451,19 +451,20 @@ def export(net, *inputs, file_name, file_format='GEIR'):
     # export model
     net.init_parameters_data()
     if file_format == 'GEIR':
-        _executor.compile(net, *inputs, phase='export')
+        phase_name = 'export.geir'
+        _executor.compile(net, *inputs, phase=phase_name)
         _executor.export(net, file_name, file_format)
     elif file_format == 'ONNX':  # file_format is 'ONNX'
         # NOTICE: the pahse name `export_onnx` is used for judging whether is exporting onnx in the compile pipeline,
         # do not change it to other values.
-        phase_name = 'export_onnx'
+        phase_name = 'export.onnx'
         graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
         onnx_stream = _executor._get_func_graph_proto(graph_id)
         with open(file_name, 'wb') as f:
             os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
             f.write(onnx_stream)
     elif file_format == 'BINARY':  # file_format is 'BINARY'
-        phase_name = 'export_binary'
+        phase_name = 'export.binary'
         graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
         onnx_stream = _executor._get_func_graph_proto(graph_id, 'binary_ir')
         with open(file_name, 'wb') as f:
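With this change every export path names its phase `export.<format>`, so the prefix checks added in pipeline.cc ('export' in `FilterActions`, 'export.geir' in `IsPhaseExportGeir`) cover GEIR, ONNX and BINARY uniformly. A hedged usage sketch against the signature shown above; the small `nn.Dense` network is a placeholder, and GEIR export in practice needs an Ascend/GE environment:

```python
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import serialization

net = nn.Dense(32, 10)                            # placeholder network
x = Tensor(np.ones((1, 32)).astype(np.float32))   # placeholder input

# Compiles with phase 'export.geir'; actions after 'validate' are filtered out.
serialization.export(net, x, file_name="net.geir", file_format='GEIR')

# Phase 'export.onnx': compiled with do_convert=False, then the ONNX stream is
# written to file with owner read/write permissions only.
serialization.export(net, x, file_name="net.onnx", file_format='ONNX')
```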
tests/ut/python/ops/test_ops.py

@@ -2180,36 +2180,36 @@ test_case_other_ops = [
 ]
 
 test_case_quant_ops = [
-    ('AscendQuant_1', {
-        'block': inner.AscendQuant(0.5, 0.0, False, "Round"),
+    ('Quant_1', {
+        'block': inner.Quant(0.5, 0.0, False, "Round"),
         'desc_inputs': [Tensor(np.random.rand(1, 2, 4, 4), mstype.float32)],
         'skip': ['backward']}),
-    ('AscendQuant_2', {
-        'block': inner.AscendQuant(80.0, 10.0, True, "Round"),
+    ('Quant_2', {
+        'block': inner.Quant(80.0, 10.0, True, "Round"),
         'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)],
         'skip': ['backward']}),
-    ('AscendQuant_3', {
-        'block': inner.AscendQuant(80.0, 0.0, False, "Floor"),
+    ('Quant_3', {
+        'block': inner.Quant(80.0, 0.0, False, "Floor"),
         'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)],
         'skip': ['backward']}),
-    ('AscendQuant_4', {
-        'block': inner.AscendQuant(80.0, 0.0, False, "Ceil"),
+    ('Quant_4', {
+        'block': inner.Quant(80.0, 0.0, False, "Ceil"),
         'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)],
         'skip': ['backward']}),
-    ('AscendQuant_5', {
-        'block': inner.AscendQuant(80.0, 0.0, False, "Trunc"),
+    ('Quant_5', {
+        'block': inner.Quant(80.0, 0.0, False, "Trunc"),
         'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)],
         'skip': ['backward']}),
-    ('AscendQuant_6', {
-        'block': inner.AscendQuant(-80.0, 10.0, False, "Round"),
+    ('Quant_6', {
+        'block': inner.Quant(-80.0, 10.0, False, "Round"),
         'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)],
         'skip': ['backward']}),
-    ('AscendQuant_7', {
-        'block': inner.AscendQuant(80.0, -10.0, False, "Round"),
+    ('Quant_7', {
+        'block': inner.Quant(80.0, -10.0, False, "Round"),
         'desc_inputs': [Tensor([100.0, 200.0], mstype.float32)],
         'skip': ['backward']}),
-    ('AscendQuant_8', {
-        'block': inner.AscendQuant(80.0, 10.0, False, "Round"),
+    ('Quant_8', {
+        'block': inner.Quant(80.0, 10.0, False, "Round"),
         'desc_inputs': [Tensor([100.0, 200.0], mstype.float16)],
         'skip': ['backward']}),
 ]
tests/ut/python/train/quant/test_quant.py

@@ -75,10 +75,20 @@ def test_qat_lenet():
 @pytest.mark.skip(reason="no `te.lang.cce` in ut env")
-def test_qat_mobile():
+def test_qat_mobile_per_channel_tf():
     network = mobilenetV2(num_classes=1000)
     img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
-    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
+    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, True], symmetric=[True, False])
+    # should load the checkpoint. mock here
+    for param in network.get_parameters():
+        param.init_data()
+    qat.export(network, img, file_name="quant.pb")
+
+
+@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
+def test_qat_mobile_per_channel_ff():
+    network = mobilenetV2(num_classes=1000)
+    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
+    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, False], symmetric=[True, False])
     # should load the checkpoint. mock here
     for param in network.get_parameters():
         param.init_data()