Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
58d63fef
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
58d63fef
编写于
9月 27, 2020
作者:
L
LielinJiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix step count and writer create
上级
e6c28438
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
20 addition
and
21 deletion
+20
-21
python/paddle/hapi/callbacks.py
python/paddle/hapi/callbacks.py
+20
-21
未找到文件。
python/paddle/hapi/callbacks.py
浏览文件 @
58d63fef
...
@@ -507,6 +507,7 @@ class VisualDL(Callback):
...
@@ -507,6 +507,7 @@ class VisualDL(Callback):
self
.
log_dir
=
log_dir
self
.
log_dir
=
log_dir
self
.
epochs
=
None
self
.
epochs
=
None
self
.
steps
=
None
self
.
steps
=
None
self
.
epoch
=
0
def
_is_write
(
self
):
def
_is_write
(
self
):
return
ParallelEnv
().
local_rank
==
0
return
ParallelEnv
().
local_rank
==
0
...
@@ -517,20 +518,24 @@ class VisualDL(Callback):
...
@@ -517,20 +518,24 @@ class VisualDL(Callback):
self
.
train_metrics
=
self
.
params
[
'metrics'
]
self
.
train_metrics
=
self
.
params
[
'metrics'
]
assert
self
.
train_metrics
assert
self
.
train_metrics
self
.
_is_fit
=
True
self
.
_is_fit
=
True
self
.
train_step
=
0
def
on_epoch_begin
(
self
,
epoch
=
None
,
logs
=
None
):
def
on_epoch_begin
(
self
,
epoch
=
None
,
logs
=
None
):
visualdl
=
try_import
(
'visualdl'
)
self
.
steps
=
self
.
params
[
'steps'
]
self
.
steps
=
self
.
params
[
'steps'
]
self
.
epoch
=
epoch
self
.
epoch
=
epoch
self
.
train_step
=
0
self
.
train_writer
=
visualdl
.
LogWriter
(
self
.
log_dir
)
def
_updates
(
self
,
logs
,
mode
):
def
_updates
(
self
,
logs
,
mode
):
if
not
self
.
_is_write
():
return
if
not
hasattr
(
self
,
'writer'
):
visualdl
=
try_import
(
'visualdl'
)
self
.
writer
=
visualdl
.
LogWriter
(
self
.
log_dir
)
metrics
=
getattr
(
self
,
'%s_metrics'
%
(
mode
))
metrics
=
getattr
(
self
,
'%s_metrics'
%
(
mode
))
writer
=
getattr
(
self
,
'%s_writer'
%
(
mode
))
current_step
=
getattr
(
self
,
'%s_step'
%
(
mode
))
current_step
=
getattr
(
self
,
'%s_step'
%
(
mode
))
if
mode
==
'train'
:
if
mode
==
'train'
:
total_step
=
self
.
epoch
*
self
.
steps
+
current_step
total_step
=
current_step
else
:
else
:
total_step
=
self
.
epoch
total_step
=
self
.
epoch
...
@@ -544,7 +549,8 @@ class VisualDL(Callback):
...
@@ -544,7 +549,8 @@ class VisualDL(Callback):
temp_value
=
logs
[
k
]
temp_value
=
logs
[
k
]
else
:
else
:
continue
continue
writer
.
add_scalar
(
self
.
writer
.
add_scalar
(
tag
=
temp_tag
,
step
=
total_step
,
value
=
temp_value
)
tag
=
temp_tag
,
step
=
total_step
,
value
=
temp_value
)
def
on_train_batch_end
(
self
,
step
,
logs
=
None
):
def
on_train_batch_end
(
self
,
step
,
logs
=
None
):
...
@@ -552,30 +558,23 @@ class VisualDL(Callback):
...
@@ -552,30 +558,23 @@ class VisualDL(Callback):
self
.
train_step
+=
1
self
.
train_step
+=
1
if
self
.
_is_write
():
if
self
.
_is_write
():
if
self
.
steps
is
None
or
self
.
train_step
<
self
.
steps
:
self
.
_updates
(
logs
,
'train'
)
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
logs
=
logs
or
{}
if
self
.
_is_write
()
and
(
self
.
steps
is
not
None
):
self
.
_updates
(
logs
,
'train'
)
self
.
_updates
(
logs
,
'train'
)
def
on_eval_begin
(
self
,
logs
=
None
):
def
on_eval_begin
(
self
,
logs
=
None
):
visualdl
=
try_import
(
'visualdl'
)
self
.
eval_steps
=
logs
.
get
(
'steps'
,
None
)
self
.
eval_steps
=
logs
.
get
(
'steps'
,
None
)
self
.
eval_metrics
=
logs
.
get
(
'metrics'
,
[])
self
.
eval_metrics
=
logs
.
get
(
'metrics'
,
[])
self
.
eval_step
=
0
self
.
eval_step
=
0
self
.
evaled_samples
=
0
self
.
evaled_samples
=
0
self
.
eval_writer
=
visualdl
.
LogWriter
(
self
.
log_dir
)
def
on_train_end
(
self
,
logs
=
None
):
def
on_train_end
(
self
,
logs
=
None
):
if
hasattr
(
self
,
'train_writer'
):
if
hasattr
(
self
,
'writer'
):
self
.
train_writer
.
close
()
self
.
writer
.
close
()
if
hasattr
(
self
,
'eval_writer'
):
delattr
(
self
,
'writer'
)
self
.
eval_writer
.
close
()
def
on_eval_end
(
self
,
logs
=
None
):
def
on_eval_end
(
self
,
logs
=
None
):
self
.
_updates
(
logs
,
'eval'
)
if
self
.
_is_write
():
self
.
_updates
(
logs
,
'eval'
)
if
(
not
hasattr
(
self
,
'_is_fit'
))
and
hasattr
(
self
,
'eval_writer'
):
if
(
not
hasattr
(
self
,
'_is_fit'
))
and
hasattr
(
self
,
'writer'
):
self
.
eval_writer
.
close
()
self
.
writer
.
close
()
delattr
(
self
,
'writer'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录