Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
ac321194
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ac321194
编写于
6月 29, 2020
作者:
littletomatodonkey
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix fetchs
上级
7713736e
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
12 addition
and
45 deletion
+12
-45
tools/program.py
tools/program.py
+12
-45
未找到文件。
tools/program.py
浏览文件 @
ac321194
...
...
@@ -49,7 +49,7 @@ def create_dataloader():
dataloader(fluid dataloader):
"""
trainer_num
=
int
(
os
.
environ
.
get
(
'PADDLE_TRAINERS_NUM'
,
1
))
capacity
=
64
if
trainer_num
<
=
1
else
8
capacity
=
64
if
trainer_num
=
=
1
else
8
dataloader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
capacity
,
use_double_buffer
=
True
,
iterable
=
True
)
...
...
@@ -163,15 +163,7 @@ def create_metric(out,
return
fetchs
def
create_fetchs
(
feeds
,
out
,
config
,
architecture
,
topk
=
5
,
classes_num
=
1000
,
epsilon
=
None
,
use_mix
=
False
,
use_distillation
=
False
):
def
create_fetchs
(
feeds
,
net
,
config
,
mode
=
"train"
):
"""
Create fetchs as model outputs(included loss and measures),
will call create_loss and create_metric(if use_mix).
...
...
@@ -190,6 +182,15 @@ def create_fetchs(feeds,
Returns:
fetchs(dict): dict of model outputs(included loss and measures)
"""
architecture
=
config
.
ARCHITECTURE
topk
=
config
.
topk
classes_num
=
config
.
classes_num
epsilon
=
config
.
get
(
'ls_epsilon'
)
use_mix
=
config
.
get
(
'use_mix'
)
and
mode
==
'train'
use_distillation
=
config
.
get
(
'use_distillation'
)
out
=
net
(
feeds
[
"image"
])
fetchs
=
OrderedDict
()
fetchs
[
'loss'
]
=
create_loss
(
feeds
,
out
,
architecture
,
classes_num
,
epsilon
,
use_mix
,
use_distillation
)
...
...
@@ -276,40 +277,6 @@ def mixed_precision_optimizer(config, optimizer):
return
optimizer
def
compute
(
feeds
,
net
,
config
,
mode
=
'train'
):
"""
Build a program using a model and an optimizer
1. create feeds
2. create a dataloader
3. create a model
4. create fetchs
5. create an optimizer
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
is_train(bool): train or valid
Returns:
dataloader(): a bridge between the model and the data
fetchs(dict): dict of model outputs(included loss and measures)
"""
out
=
net
(
feeds
[
"image"
])
fetchs
=
create_fetchs
(
feeds
,
out
,
config
,
config
.
ARCHITECTURE
,
config
.
topk
,
config
.
classes_num
,
epsilon
=
config
.
get
(
'ls_epsilon'
),
use_mix
=
config
.
get
(
'use_mix'
)
and
mode
==
'train'
,
use_distillation
=
config
.
get
(
'use_distillation'
))
return
fetchs
def
create_feeds
(
batch
,
use_mix
):
image
=
to_variable
(
batch
[
0
].
numpy
().
astype
(
"float32"
))
if
use_mix
:
...
...
@@ -360,7 +327,7 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
for
idx
,
batch
in
enumerate
(
dataloader
()):
batch_size
=
len
(
batch
[
0
])
feeds
=
create_feeds
(
batch
,
use_mix
)
fetchs
=
c
ompute
(
feeds
,
net
,
config
,
mode
)
fetchs
=
c
reate_fetchs
(
feeds
,
net
,
config
,
mode
)
if
mode
==
'train'
:
avg_loss
=
net
.
scale_loss
(
fetchs
[
'loss'
])
avg_loss
.
backward
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录