Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
f3fd7a55
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f3fd7a55
编写于
4年前
作者:
M
mindspore-ci-bot
提交者:
Gitee
4年前
浏览文件
操作
浏览文件
下载
差异文件
!5073 Add checks and exception handling DS callback
Merge pull request !5073 from h.farahat/map_callback_end
上级
37bae5bf
8eeceb26
master
无相关合并请求
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
40 addition
and
10 deletion
+40
-10
mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc
mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc
+5
-1
mindspore/dataset/callback/ds_callback.py
mindspore/dataset/callback/ds_callback.py
+18
-9
tests/ut/python/dataset/test_callbacks.py
tests/ut/python/dataset/test_callbacks.py
+17
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/callback/py_ds_callback.cc
浏览文件 @
f3fd7a55
...
...
@@ -53,7 +53,11 @@ Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param
if
(
Py_IsInitialized
()
==
0
)
{
return
Status
(
StatusCode
::
kPythonInterpreterFailure
,
"Python Interpreter is finalized"
);
}
f
(
cb_param
);
try
{
f
(
cb_param
);
}
catch
(
const
py
::
error_already_set
&
e
)
{
return
Status
(
StatusCode
::
kPyFuncException
,
e
.
what
());
}
}
return
Status
::
OK
();
}
...
...
This diff is collapsed.
Click to expand it.
mindspore/dataset/callback/ds_callback.py
浏览文件 @
f3fd7a55
...
...
@@ -144,6 +144,8 @@ class WaitedDSCallback(Callback, DSCallback):
self
.
epoch_event
=
threading
.
Event
()
self
.
epoch_run_context
=
None
self
.
training_ended
=
False
def
sync_epoch_begin
(
self
,
train_run_context
,
ds_run_context
):
"""
Called before a new dataset epoch is started and after the previous training epoch is ended.
...
...
@@ -180,10 +182,11 @@ class WaitedDSCallback(Callback, DSCallback):
ds_run_context: Include some information of the pipeline.
"""
if
ds_run_context
.
cur_epoch_num
>
1
:
success
=
self
.
epoch_event
.
wait
(
timeout
=
ds
.
config
.
get_callback_timeout
())
self
.
epoch_event
.
clear
()
if
not
success
:
raise
RuntimeError
(
f
"ds_epoch_begin timed out after
{
ds
.
config
.
get_callback_timeout
()
}
second(s)"
)
if
not
self
.
training_ended
:
success
=
self
.
epoch_event
.
wait
(
timeout
=
ds
.
config
.
get_callback_timeout
())
self
.
epoch_event
.
clear
()
if
not
success
:
raise
RuntimeError
(
f
"ds_epoch_begin timed out after
{
ds
.
config
.
get_callback_timeout
()
}
second(s)"
)
# by the time this thread wakes up, self.epoch_run_context is already available
self
.
sync_epoch_begin
(
self
.
epoch_run_context
,
ds_run_context
)
...
...
@@ -205,11 +208,12 @@ class WaitedDSCallback(Callback, DSCallback):
ds_run_context: Include some information of the pipeline.
"""
if
ds_run_context
.
cur_step_num
>
self
.
step_size
:
success
=
self
.
step_event
.
wait
(
timeout
=
ds
.
config
.
get_callback_timeout
())
self
.
step_event
.
clear
()
if
not
success
:
raise
RuntimeError
(
f
"ds_step_begin timed out after
{
ds
.
config
.
get_callback_timeout
()
}
second(s)"
)
# by the time this thread wakes up, self.epoch_run_context is already available
if
not
self
.
training_ended
:
success
=
self
.
step_event
.
wait
(
timeout
=
ds
.
config
.
get_callback_timeout
())
self
.
step_event
.
clear
()
if
not
success
:
raise
RuntimeError
(
f
"ds_step_begin timed out after
{
ds
.
config
.
get_callback_timeout
()
}
second(s)"
)
# by the time this thread wakes up, self.epoch_run_context is already available
self
.
sync_step_begin
(
self
.
step_run_context
,
ds_run_context
)
def
create_runtime_obj
(
self
):
...
...
@@ -233,3 +237,8 @@ class WaitedDSCallback(Callback, DSCallback):
raise
AttributeError
(
"Provided Callback class did not override any of the 2 callback methods."
)
return
c_cb
def
end
(
self
,
run_context
):
self
.
epoch_end
(
run_context
)
self
.
step_end
(
run_context
)
self
.
training_ended
=
True
This diff is collapsed.
Click to expand it.
tests/ut/python/dataset/test_callbacks.py
浏览文件 @
f3fd7a55
...
...
@@ -410,6 +410,22 @@ def test_callbacks_exceptions():
assert
"RuntimeError: Bad begin"
in
str
(
err
.
value
)
def
test_callbacks_train_end
():
logger
.
info
(
"test_callback_sink_simulation"
)
# No asserts are needed, just test there is no deadlock or exceptions
events
=
[]
epochs
=
2
my_cb
=
MyWaitedCallback
(
events
,
1
)
data
=
ds
.
NumpySlicesDataset
([
1
,
2
,
3
,
4
],
shuffle
=
False
)
data
=
data
.
map
(
operations
=
(
lambda
x
:
x
),
callbacks
=
[
my_cb
])
data
=
data
.
to_device
()
data
.
send
(
num_epochs
=
epochs
)
time
.
sleep
(
0.5
)
my_cb
.
end
(
run_context
=
{})
time
.
sleep
(
0.5
)
def
test_callbacks_one_cb
():
logger
.
info
(
"test_callbacks_one_cb"
)
...
...
@@ -458,3 +474,4 @@ if __name__ == '__main__':
test_callbacks_non_sink
()
test_callbacks_one_cb
()
test_callbacks_non_sink_mismatch_size
()
test_callbacks_train_end
()
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部