Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
1ebf98b7
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
1ebf98b7
编写于
5月 19, 2020
作者:
W
wangnan39@huawei.com
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add model init api to compile df graph before exec
上级
1ba8e052
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
162 addition
and
40 deletion
+162
-40
mindspore/common/api.py
mindspore/common/api.py
+4
-0
mindspore/nn/cell.py
mindspore/nn/cell.py
+22
-3
mindspore/train/model.py
mindspore/train/model.py
+114
-34
tests/ut/python/train/test_training.py
tests/ut/python/train/test_training.py
+22
-3
未找到文件。
mindspore/common/api.py
浏览文件 @
1ebf98b7
...
...
@@ -383,6 +383,10 @@ class _Executor:
obj
.
parameter_layout_dict
=
self
.
_executor
.
get_parameter_layout
(
phase
)
obj
.
load_parameter_slice
(
params
)
# set parallel inputs in sink mode
if
auto_parallel_mode
and
(
args
and
isinstance
(
args
[
0
],
Tensor
)
and
args
[
0
].
virtual_flag
):
obj
.
set_parallel_input_with_inputs
(
*
args
)
# the following GE init process is not needed when use vm or ms backend
if
enable_ge
:
# decide whether to sink based on whether the inputs is virtual or not
...
...
mindspore/nn/cell.py
浏览文件 @
1ebf98b7
...
...
@@ -288,6 +288,15 @@ class Cell:
parallel_inputs_run
.
append
(
new_tensor
)
return
tuple
(
parallel_inputs_run
)
def
set_parallel_input_with_inputs
(
self
,
*
inputs
):
"""
Slice inputs tensors by parallel strategies, and set the sliced inputs to `_parallel_input_run`
Args:
inputs (tuple): inputs of construct method.
"""
self
.
_parallel_inputs_run
=
self
.
_load_inputs
(
*
inputs
)
def
_get_construct_inputs_number_and_name
(
self
):
"""Compute self._construct_inputs_names and self._construct_inputs_num"""
import
inspect
...
...
@@ -304,6 +313,15 @@ class Cell:
self
.
_construct_inputs_names
=
self
.
_construct_inputs_names
[
1
:
self
.
_construct_inputs_num
]
self
.
_construct_inputs_num
=
self
.
_construct_inputs_num
-
1
def
compile
(
self
,
*
inputs
):
"""
Compiles cell.
Args:
inputs (tuple): Input parameters.
"""
_executor
.
compile
(
self
,
*
inputs
,
phase
=
self
.
phase
,
auto_parallel_mode
=
self
.
_auto_parallel_mode
)
def
compile_and_run
(
self
,
*
inputs
):
"""
Compiles and runs cell.
...
...
@@ -314,13 +332,14 @@ class Cell:
Returns:
Object, the result of executing.
"""
_
,
compile_flag
=
_executor
.
compile
(
self
,
*
inputs
,
phase
=
self
.
phase
,
auto_parallel_mode
=
self
.
_auto_parallel_mode
)
_executor
.
compile
(
self
,
*
inputs
,
phase
=
self
.
phase
,
auto_parallel_mode
=
self
.
_auto_parallel_mode
)
if
self
.
_auto_parallel_mode
:
if
inputs
and
isinstance
(
inputs
[
0
],
Tensor
)
and
inputs
[
0
].
virtual_flag
and
(
not
compile_flag
):
if
inputs
and
isinstance
(
inputs
[
0
],
Tensor
)
and
inputs
[
0
].
virtual_flag
:
# get parallel inputs in sink mode, parallel inputs set in _executor.compile
parallel_inputs_run
=
self
.
_parallel_inputs_run
else
:
# set parallel inputs in normal mode
self
.
_parallel_inputs_run
=
self
.
_load_inputs
(
*
inputs
)
parallel_inputs_run
=
self
.
_parallel_inputs_run
return
_executor
(
self
,
*
parallel_inputs_run
,
phase
=
self
.
phase
)
...
...
mindspore/train/model.py
浏览文件 @
1ebf98b7
...
...
@@ -217,6 +217,94 @@ class Model:
scaling_sens
/=
self
.
_device_number
return
scaling_sens
def
_exec_preprocess
(
self
,
network
,
is_train
,
phase
,
dataset
,
dataset_sink_mode
):
"""Initializes dataset."""
need_wrap
=
False
if
dataset_sink_mode
:
# remove later to deal with loop sink
if
not
hasattr
(
dataset
,
'__ME_INITED__'
)
and
context
.
get_context
(
"device_target"
)
==
"Ascend"
\
and
not
context
.
get_context
(
"enable_ge"
):
need_wrap
=
True
if
not
is_train
:
dataset
.
__loop_size__
=
1
dataset_helper
=
DatasetHelper
(
dataset
,
dataset_sink_mode
)
# remove later to deal with loop sink
if
need_wrap
:
network
=
nn
.
DataWrapper
(
network
,
*
(
dataset_helper
.
types_shapes
()),
dataset
.
__ME_INITED__
)
network
.
set_train
(
is_train
)
network
.
phase
=
phase
return
dataset_helper
,
network
def
init
(
self
,
train_dataset
=
None
,
valid_dataset
=
None
):
"""
Initializes compute graphs and data graphs with sink mode.
Note:
Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.
Args:
train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be
initialized. Default: None.
valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will
be initialized, and `metrics` in `Model` can not be None. Default: None.
Examples:
>>> train_dataset = get_train_dataset()
>>> valid_dataset = get_valid_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
>>> model.init(train_dataset, valid_dataset)
>>> model.train(2, train_dataset)
>>> model.eval(valid_dataset)
"""
if
context
.
get_context
(
"mode"
)
!=
context
.
GRAPH_MODE
or
context
.
get_context
(
"device_target"
)
!=
"Ascend"
:
raise
RuntimeError
(
'Pre-init process only supports GRAPH MODE and Ascend target currently.'
)
if
not
train_dataset
and
not
valid_dataset
:
raise
ValueError
(
'Both train_dataset and valid_dataset can not be None or empty.'
)
_device_number_check
(
self
.
_parallel_mode
,
self
.
_device_number
)
if
train_dataset
:
_parameter_broadcast_check
(
self
.
_parallel_mode
,
self
.
_parameter_broadcast
)
self
.
_train_network
.
set_train
()
self
.
_train_network
.
phase
=
'train'
if
self
.
_parameter_broadcast
:
self
.
_train_network
.
set_broadcast_flag
()
train_dataset_helper
,
train_network
=
self
.
_exec_preprocess
(
self
.
_train_network
,
is_train
=
True
,
phase
=
'train'
,
dataset
=
train_dataset
,
dataset_sink_mode
=
True
)
self
.
_train_network
=
train_network
for
inputs
in
train_dataset_helper
:
self
.
_train_network
.
compile
(
*
inputs
)
break
if
valid_dataset
:
if
not
self
.
_metric_fns
:
raise
RuntimeError
(
'If define `valid_dataset`, metric fn can not be None or empty.'
)
self
.
_eval_network
.
set_train
(
False
)
self
.
_eval_network
.
phase
=
'eval'
valid_dataset_helper
,
eval_network
=
self
.
_exec_preprocess
(
self
.
_eval_network
,
is_train
=
False
,
phase
=
'eval'
,
dataset
=
valid_dataset
,
dataset_sink_mode
=
True
)
self
.
_eval_network
=
eval_network
for
inputs
in
valid_dataset_helper
:
self
.
_eval_network
.
compile
(
*
inputs
)
break
def
_train
(
self
,
epoch
,
train_dataset
,
callbacks
=
None
,
dataset_sink_mode
=
True
):
"""
Training.
...
...
@@ -277,21 +365,15 @@ class Model:
list_callback (_ListCallback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
# remove later to deal with loop sink
need_wrap
=
False
if
not
hasattr
(
train_dataset
,
'__ME_INITED__'
)
and
context
.
get_context
(
"device_target"
)
==
"Ascend"
\
and
not
context
.
get_context
(
"enable_ge"
):
need_wrap
=
True
dataset_helper
=
DatasetHelper
(
train_dataset
)
# remove later to deal with loop sink
if
need_wrap
:
self
.
_train_network
=
nn
.
DataWrapper
(
self
.
_train_network
,
*
(
dataset_helper
.
types_shapes
()),
train_dataset
.
__ME_INITED__
)
cb_params
.
train_network
=
self
.
_train_network
self
.
_train_network
.
set_train
()
dataset_helper
,
train_network
=
self
.
_exec_preprocess
(
self
.
_train_network
,
is_train
=
True
,
phase
=
'train'
,
dataset
=
train_dataset
,
dataset_sink_mode
=
True
)
self
.
_train_network
=
train_network
cb_params
.
train_network
=
self
.
_train_network
cb_params
.
cur_step_num
=
0
loop_size
=
dataset_helper
.
loop_size
()
run_context
=
RunContext
(
cb_params
)
list_callback
.
begin
(
run_context
)
...
...
@@ -331,7 +413,11 @@ class Model:
list_callback (_ListCallback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper
=
DatasetHelper
(
train_dataset
,
dataset_sink_mode
=
False
)
dataset_helper
,
_
=
self
.
_exec_preprocess
(
self
.
_train_network
,
is_train
=
True
,
phase
=
'train'
,
dataset
=
train_dataset
,
dataset_sink_mode
=
False
)
cb_params
.
cur_step_num
=
0
run_context
=
RunContext
(
cb_params
)
list_callback
.
begin
(
run_context
)
...
...
@@ -437,26 +523,15 @@ class Model:
Returns:
Dict, returns the loss value & metrics values for the model in test mode.
"""
_device_number_check
(
self
.
_parallel_mode
,
self
.
_device_number
)
run_context
=
RunContext
(
cb_params
)
# remove later to deal with loop sink
need_wrap
=
False
if
not
hasattr
(
valid_dataset
,
'__ME_INITED__'
)
and
context
.
get_context
(
"device_target"
)
==
"Ascend"
\
and
not
context
.
get_context
(
"enable_ge"
):
need_wrap
=
True
valid_dataset
.
__loop_size__
=
1
dataset_helper
=
DatasetHelper
(
valid_dataset
)
# remove later to deal with loop sink
if
need_wrap
:
self
.
_eval_network
=
nn
.
DataWrapper
(
self
.
_eval_network
,
*
(
dataset_helper
.
types_shapes
()),
valid_dataset
.
__ME_INITED__
)
self
.
_eval_network
.
set_train
(
mode
=
False
)
self
.
_eval_network
.
phase
=
'eval'
dataset_helper
,
eval_network
=
self
.
_exec_preprocess
(
self
.
_eval_network
,
is_train
=
False
,
phase
=
'eval'
,
dataset
=
valid_dataset
,
dataset_sink_mode
=
True
)
self
.
_eval_network
=
eval_network
cb_params
.
eval_network
=
self
.
_eval_network
list_callback
.
begin
(
run_context
)
for
inputs
in
dataset_helper
:
...
...
@@ -490,7 +565,11 @@ class Model:
run_context
=
RunContext
(
cb_params
)
list_callback
.
begin
(
run_context
)
dataset_helper
=
DatasetHelper
(
valid_dataset
,
dataset_sink_mode
=
False
)
dataset_helper
,
_
=
self
.
_exec_preprocess
(
self
.
_eval_network
,
is_train
=
False
,
phase
=
'eval'
,
dataset
=
valid_dataset
,
dataset_sink_mode
=
False
)
for
next_element
in
dataset_helper
:
cb_params
.
cur_step_num
+=
1
list_callback
.
step_begin
(
run_context
)
...
...
@@ -532,6 +611,7 @@ class Model:
>>> model.eval(dataset)
"""
check_bool
(
dataset_sink_mode
)
_device_number_check
(
self
.
_parallel_mode
,
self
.
_device_number
)
if
not
self
.
_metric_fns
:
raise
ValueError
(
"metric fn can not be None or empty."
)
...
...
tests/ut/python/train/test_training.py
浏览文件 @
1ebf98b7
...
...
@@ -68,12 +68,12 @@ class LossNet(nn.Cell):
return
out
def
get_model
():
def
get_model
(
metrics
=
None
):
""" get_model """
net
=
Net
()
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
()
optim
=
Momentum
(
net
.
trainable_params
(),
learning_rate
=
0.1
,
momentum
=
0.9
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
None
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
optim
,
metrics
=
metrics
)
return
model
...
...
@@ -215,8 +215,27 @@ def test_model_build_abnormal_string():
assert
err
def
test_model_init
_error
():
def
test_model_init
():
""" test_model_init_error """
train_dataset
=
get_dataset
()
eval_dataset
=
get_dataset
()
with
pytest
.
raises
(
RuntimeError
):
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
)
get_model
().
init
(
train_dataset
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
get_model
().
init
(
train_dataset
)
get_model
(
metrics
=
{
'acc'
}).
init
(
eval_dataset
)
with
pytest
.
raises
(
RuntimeError
):
get_model
().
init
(
train_dataset
,
eval_dataset
)
with
pytest
.
raises
(
ValueError
):
get_model
().
init
()
def
test_init_model_error
():
""" test_init_model_error """
net
=
nn
.
ReLU
()
loss
=
nn
.
SoftmaxCrossEntropyWithLogits
()
with
pytest
.
raises
(
KeyError
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录