Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
939e6129
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
939e6129
编写于
7月 24, 2020
作者:
P
panfengfeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix get daataset size error
上级
21edd691
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
41 addition
and
23 deletion
+41
-23
mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc
...src/minddata/dataset/engine/datasetops/device_queue_op.cc
+2
-2
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+1
-1
mindspore/train/dataset_helper.py
mindspore/train/dataset_helper.py
+20
-15
mindspore/train/model.py
mindspore/train/model.py
+11
-3
tests/dataset_mock.py
tests/dataset_mock.py
+1
-1
tests/st/networks/models/resnet50/src_thor/dataset_helper.py
tests/st/networks/models/resnet50/src_thor/dataset_helper.py
+6
-1
未找到文件。
mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc
浏览文件 @
939e6129
...
...
@@ -212,12 +212,12 @@ Status DeviceQueueOp::SendDataToGPU() {
RETURN_IF_NOT_OK
(
RetryPushGPUData
(
data_size
,
curr_row
,
handle
));
total_batch
++
;
}
if
(
!
TaskManager
::
FindMe
()
->
Interrupted
())
if
(
!
TaskManager
::
FindMe
()
->
Interrupted
()
&&
!
GpuBufferMgr
::
GetInstance
().
IsClosed
()
)
RETURN_IF_NOT_OK
(
GetNextInput
(
&
current_buffer
));
else
is_break_loop
=
true
;
}
if
(
!
TaskManager
::
FindMe
()
->
Interrupted
())
if
(
!
TaskManager
::
FindMe
()
->
Interrupted
()
&&
!
GpuBufferMgr
::
GetInstance
().
IsClosed
()
)
RETURN_IF_NOT_OK
(
GetNextInput
(
&
current_buffer
));
else
is_break_loop
=
true
;
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
939e6129
...
...
@@ -2401,7 +2401,7 @@ class TransferDataset(DatasetOp):
# need to keep iterator alive so the executionTree is not destroyed
if
self
.
_noop_mode
():
return
self
.
iterator
=
TupleIterator
(
self
,
num_epochs
=
-
1
)
self
.
iterator
=
TupleIterator
(
self
,
num_epochs
=
num_epochs
)
def
stop_send
(
self
):
self
.
iterator
.
depipeline
.
StopSend
()
...
...
mindspore/train/dataset_helper.py
浏览文件 @
939e6129
...
...
@@ -24,13 +24,18 @@ from ..nn.wrap import GetNextSingleOp
from
..parallel._utils
import
_get_device_num
,
_get_global_rank
,
_need_to_full
def
_send_data
(
dataset
):
def
_send_data
(
dataset
,
epoch_num
):
"""Engine dataset to write data to tdt queue."""
if
not
hasattr
(
dataset
,
'__has_sent__'
):
exec_dataset
=
dataset
.
__TRANSFER_DATASET__
exec_dataset
.
send
()
exec_dataset
.
send
(
epoch_num
)
dataset
.
__has_sent__
=
True
def
_send_data_no_flag
(
dataset
,
epoch_num
):
"""Engine dataset to write data to tdt queue directly."""
exec_dataset
=
dataset
.
__TRANSFER_DATASET__
exec_dataset
.
send
(
epoch_num
)
class
DatasetHelper
:
"""
...
...
@@ -54,7 +59,7 @@ class DatasetHelper:
>>> outputs = network(*inputs)
"""
def
__init__
(
self
,
dataset
,
dataset_sink_mode
=
True
,
sink_size
=-
1
):
def
__init__
(
self
,
dataset
,
dataset_sink_mode
=
True
,
sink_size
=-
1
,
epoch_num
=
1
):
check_bool
(
dataset_sink_mode
)
check_int
(
sink_size
)
if
sink_size
<
-
1
or
sink_size
==
0
:
...
...
@@ -74,7 +79,7 @@ class DatasetHelper:
iterclass
=
_DatasetIterMS
elif
context
.
get_context
(
"device_target"
)
==
"CPU"
:
raise
RuntimeError
(
"Currently dataset sink mode is not supported when the device target is CPU."
)
self
.
iter
=
iterclass
(
dataset
,
sink_size
)
self
.
iter
=
iterclass
(
dataset
,
sink_size
,
epoch_num
)
else
:
iterclass
=
_DatasetIterNormal
self
.
iter
=
iterclass
(
dataset
)
...
...
@@ -98,7 +103,7 @@ class DatasetHelper:
class
_DatasetIter
:
"""Base iter for dataset helper"""
def
__init__
(
self
,
dataset
,
sink_size
):
def
__init__
(
self
,
dataset
,
sink_size
,
epoch_num
):
self
.
dataset
=
dataset
self
.
sink_size
=
sink_size
self
.
sink_count
=
1
...
...
@@ -110,9 +115,9 @@ class _DatasetIter:
dataset
.
__ME_INITED__
=
dataset
.
__TRANSFER_DATASET__
.
queue_name
if
not
hasattr
(
dataset
,
'__no_send__'
):
_send_data
(
dataset
)
_send_data
(
dataset
,
epoch_num
)
else
:
_send_data
(
dataset
)
_send_data
_no_flag
(
dataset
,
epoch_num
)
self
.
stop_send
=
dataset
.
__TRANSFER_DATASET__
.
stop_send
self
.
dataset_types
,
self
.
dataset_shapes
=
_get_types_and_shapes
(
dataset
)
...
...
@@ -156,8 +161,8 @@ class _DatasetIter:
class
_DatasetIterGE
(
_DatasetIter
):
"""Iter for GE."""
def
__init__
(
self
,
dataset
,
sink_size
):
super
().
__init__
(
dataset
,
sink_size
)
def
__init__
(
self
,
dataset
,
sink_size
,
epoch_num
):
super
().
__init__
(
dataset
,
sink_size
,
epoch_num
)
self
.
sink_count
=
self
.
get_sink_count
(
dataset
)
batch_expand_num
=
1
if
_need_to_full
():
...
...
@@ -172,8 +177,8 @@ class _DatasetIterGE(_DatasetIter):
class
_DatasetIterMSLoopSink
(
_DatasetIter
):
"""Iter for context (device_target=Ascend)"""
def
__init__
(
self
,
dataset
,
sink_size
):
super
().
__init__
(
dataset
,
sink_size
)
def
__init__
(
self
,
dataset
,
sink_size
,
epoch_num
):
super
().
__init__
(
dataset
,
sink_size
,
epoch_num
)
self
.
sink_count
=
self
.
get_sink_count
(
dataset
)
ms_role
=
os
.
getenv
(
"MS_ROLE"
)
if
ms_role
in
(
"MS_PSERVER"
,
"MS_SCHED"
):
...
...
@@ -193,8 +198,8 @@ class _DatasetIterMSLoopSink(_DatasetIter):
class
_DatasetIterMS
(
_DatasetIter
):
"""Iter for MS(enable_loop_sink=False)."""
def
__init__
(
self
,
dataset
,
sink_size
):
super
().
__init__
(
dataset
,
sink_size
)
def
__init__
(
self
,
dataset
,
sink_size
,
epoch_num
):
super
().
__init__
(
dataset
,
sink_size
,
epoch_num
)
if
sink_size
>
0
:
self
.
sink_count
=
sink_size
else
:
...
...
@@ -206,8 +211,8 @@ class _DatasetIterMS(_DatasetIter):
class
_DatasetIterPSLite
(
_DatasetIter
):
"""Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
def
__init__
(
self
,
dataset
,
sink_size
):
super
().
__init__
(
dataset
,
sink_size
)
def
__init__
(
self
,
dataset
,
sink_size
,
epoch_num
):
super
().
__init__
(
dataset
,
sink_size
,
epoch_num
)
self
.
sink_count
=
1
self
.
sink_size
=
1
self
.
op
=
None
...
...
mindspore/train/model.py
浏览文件 @
939e6129
...
...
@@ -227,7 +227,7 @@ class Model:
scaling_sens
/=
self
.
_device_number
return
scaling_sens
def
_exec_preprocess
(
self
,
network
,
is_train
,
phase
,
dataset
,
dataset_sink_mode
,
sink_size
=-
1
):
def
_exec_preprocess
(
self
,
network
,
is_train
,
phase
,
dataset
,
dataset_sink_mode
,
sink_size
=-
1
,
epoch_num
=
1
):
"""Initializes dataset."""
need_wrap
=
False
if
dataset_sink_mode
:
...
...
@@ -239,7 +239,7 @@ class Model:
if
not
is_train
:
dataset
.
__loop_size__
=
1
dataset_helper
=
DatasetHelper
(
dataset
,
dataset_sink_mode
,
sink_size
)
dataset_helper
=
DatasetHelper
(
dataset
,
dataset_sink_mode
,
sink_size
,
epoch_num
)
# remove later to deal with loop sink
if
need_wrap
:
...
...
@@ -399,12 +399,18 @@ class Model:
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
sink_size (int): Control the amount of data each sink. Default: -1.
"""
if
sink_size
==
-
1
:
epoch_num
=
epoch
else
:
epoch_num
=
epoch
*
sink_size
//
train_dataset
.
get_dataset_size
()
dataset_helper
,
train_network
=
self
.
_exec_preprocess
(
self
.
_train_network
,
is_train
=
True
,
phase
=
'train'
,
dataset
=
train_dataset
,
dataset_sink_mode
=
True
,
sink_size
=
sink_size
)
sink_size
=
sink_size
,
epoch_num
=
epoch_num
)
self
.
_train_network
=
train_network
cb_params
.
train_network
=
self
.
_train_network
cb_params
.
cur_step_num
=
0
...
...
@@ -621,6 +627,8 @@ class Model:
list_callback
.
step_end
(
run_context
)
self
.
_update_metrics
(
outputs
)
valid_dataset
.
reset
()
metrics
=
self
.
_get_metrics
()
cb_params
.
metrics
=
metrics
list_callback
.
end
(
run_context
)
...
...
tests/dataset_mock.py
浏览文件 @
939e6129
...
...
@@ -58,7 +58,7 @@ class MindData:
def
create_tuple_iterator
(
self
):
return
self
.
__iter__
()
def
send
(
self
):
def
send
(
self
,
num_epochs
=-
1
):
pass
def
stop_send
(
self
):
...
...
tests/st/networks/models/resnet50/src_thor/dataset_helper.py
浏览文件 @
939e6129
...
...
@@ -15,11 +15,16 @@
"""Dataset help for minddata dataset"""
from
mindspore._checkparam
import
check_bool
from
mindspore.parallel._utils
import
_get_device_num
,
_get_parallel_mode
from
mindspore.train.dataset_helper
import
_send_data
from
mindspore.train._utils
import
_exec_datagraph
,
_get_types_and_shapes
,
\
_to_full_shapes
from
mindspore.train.parallel_utils
import
ParallelMode
def
_send_data
(
dataset
):
"""Engine dataset to write data to tdt queue."""
if
not
hasattr
(
dataset
,
'__has_sent__'
):
exec_dataset
=
dataset
.
__TRANSFER_DATASET__
exec_dataset
.
send
()
dataset
.
__has_sent__
=
True
class
DatasetHelper
:
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录