Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
7d1ea67d
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7d1ea67d
编写于
4月 01, 2020
作者:
G
Guo Sheng
提交者:
GitHub
4月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15 from guoshengCS/fix-data-train
Reorganize data from data_loader into inputs and labels.
上级
4d22fee0
863897ce
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
28 addition
and
18 deletion
+28
-18
model.py
model.py
+26
-16
progressbar.py
progressbar.py
+2
-2
未找到文件。
model.py
浏览文件 @
7d1ea67d
...
...
@@ -29,6 +29,7 @@ from paddle.fluid.executor import global_scope
from
paddle.fluid.io
import
is_belong_to_optimizer
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.fluid.layers.utils
import
flatten
from
paddle.fluid.incubate.fleet.collective
import
fleet
,
DistributedStrategy
from
paddle.fluid.incubate.fleet.base
import
role_maker
from
paddle.fluid.io
import
DataLoader
,
Dataset
...
...
@@ -414,13 +415,7 @@ class StaticGraphAdapter(object):
losses
=
[]
metrics
=
[]
with
fluid
.
program_guard
(
prog
,
self
.
_startup_prog
):
if
isinstance
(
self
.
model
.
_inputs
,
dict
):
ins
=
[
self
.
model
.
_inputs
[
n
]
for
n
in
extract_args
(
self
.
model
.
forward
)
if
n
!=
'self'
]
else
:
ins
=
self
.
model
.
_inputs
ins
=
self
.
model
.
_inputs
lbls
=
self
.
model
.
_labels
if
self
.
model
.
_labels
else
[]
inputs
=
[
k
.
forward
()
for
k
in
to_list
(
ins
)]
labels
=
[
k
.
forward
()
for
k
in
to_list
(
lbls
)]
...
...
@@ -867,8 +862,10 @@ class Model(fluid.dygraph.Layer):
metric
.
__class__
.
__name__
)
self
.
_metrics
=
to_list
(
metrics
)
self
.
_inputs
=
inputs
self
.
_labels
=
labels
self
.
_inputs
=
to_list
(
inputs
)
if
not
isinstance
(
inputs
,
dict
)
else
[
inputs
[
n
]
for
n
in
extract_args
(
self
.
forward
)
if
n
!=
'self'
]
self
.
_labels
=
to_list
(
labels
)
if
not
in_dygraph_mode
():
self
.
_adapter
.
prepare
()
...
...
@@ -1174,17 +1171,30 @@ class Model(fluid.dygraph.Layer):
callbacks
.
on_epoch_begin
(
epoch
)
for
step
,
data
in
enumerate
(
data_loader
):
if
not
fluid
.
in_dygraph_mode
():
data
=
data
[
0
]
batch_size
=
data
[
0
].
shape
()[
0
]
else
:
batch_size
=
data
[
0
].
shape
[
0
]
# data might come from different types of data_loader and have
# different format, as following:
# 1. DataLoader in static graph:
# [[input1, input2, ..., label1, lable2, ...]]
# 2. DataLoader in dygraph
# [input1, input2, ..., label1, lable2, ...]
# 3. custumed iterator yield concated inputs and labels:
# [input1, input2, ..., label1, lable2, ...]
# 4. custumed iterator yield seperated inputs and labels:
# ([input1, input2, ...], [label1, lable2, ...])
# To handle all of these, flatten (nested) list to list.
data
=
flatten
(
data
)
# LoDTensor.shape is callable, where LoDTensor comes from
# DataLoader in static graph
batch_size
=
data
[
0
].
shape
()[
0
]
if
callable
(
data
[
0
].
shape
)
else
data
[
0
].
shape
[
0
]
callbacks
.
on_batch_begin
(
mode
,
step
,
logs
)
if
mode
==
'train'
:
outs
=
self
.
train
(
*
data
)
outs
=
self
.
train
(
data
[:
len
(
self
.
_inputs
)],
data
[
len
(
self
.
_inputs
):])
else
:
outs
=
self
.
eval
(
*
data
)
outs
=
self
.
eval
(
data
[:
len
(
self
.
_inputs
)],
data
[
len
(
self
.
_inputs
):])
# losses
loss
=
outs
[
0
]
if
self
.
_metrics
else
outs
...
...
progressbar.py
浏览文件 @
7d1ea67d
...
...
@@ -107,7 +107,7 @@ class ProgressBar(object):
eta
=
time_per_unit
*
(
self
.
_num
-
current_num
)
if
eta
>
3600
:
eta_format
=
'%d:%02d:%02d'
%
(
eta
//
3600
,
(
eta
%
3600
)
//
60
,
eta
%
60
)
60
,
eta
%
60
)
elif
eta
>
60
:
eta_format
=
'%d:%02d'
%
(
eta
//
60
,
eta
%
60
)
else
:
...
...
@@ -148,7 +148,7 @@ class ProgressBar(object):
else
:
info
+=
' %.4e'
%
v
elif
isinstance
(
v
,
np
.
ndarray
)
and
\
isinstance
(
v
.
size
,
1
)
and
\
v
.
size
==
1
and
\
isinstance
(
v
.
dtype
,
(
np
.
float32
,
np
.
float64
)):
if
abs
(
v
[
0
])
>
1e-3
:
info
+=
' %.4f'
%
v
[
0
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录