Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
a3012b87
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
4
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a3012b87
编写于
1月 06, 2020
作者:
X
Xiaoyao Xi
提交者:
GitHub
1月 06, 2020
浏览文件
操作
浏览文件
下载
差异文件
fix multi-dev predict
fix bugs
上级
f3d75f5c
6c1a4885
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
30 addition
and
29 deletion
+30
-29
paddlepalm/distribute/__init__.py
paddlepalm/distribute/__init__.py
+1
-1
paddlepalm/distribute/reader.py
paddlepalm/distribute/reader.py
+12
-1
paddlepalm/mtl_controller.py
paddlepalm/mtl_controller.py
+17
-27
未找到文件。
paddlepalm/distribute/__init__.py
浏览文件 @
a3012b87
...
...
@@ -5,5 +5,5 @@ import multiprocessing
gpu_dev_count
=
int
(
fluid
.
core
.
get_cuda_device_count
())
cpu_dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
from
reader
import
yield_pieces
,
data_feeder
from
reader
import
yield_pieces
,
data_feeder
,
decode_fake
paddlepalm/distribute/reader.py
浏览文件 @
a3012b87
...
...
@@ -58,7 +58,6 @@ def yield_pieces(data, distribute_strategy, batch_size):
yield
temp
def
data_feeder
(
reader
,
postprocess_fn
=
None
,
prefetch_steps
=
2
,
phase
=
'train'
):
if
postprocess_fn
is
None
:
def
postprocess_fn
(
batch
):
return
batch
...
...
@@ -108,3 +107,15 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
queue
.
join
()
def
decode_fake
(
nums
,
mask
,
bs
):
n_t
=
0
for
flag
in
mask
:
if
not
flag
:
break
n_t
=
n_t
+
1
n_f
=
len
(
mask
)
-
n_t
p1
=
nums
-
(
n_t
-
1
)
*
bs
each_f
=
p1
/
(
n_f
+
1
)
return
each_f
*
n_f
paddlepalm/mtl_controller.py
浏览文件 @
a3012b87
...
...
@@ -31,7 +31,7 @@ from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint
from
paddlepalm.utils.config_helper
import
PDConfig
from
paddlepalm.utils.print_helper
import
print_dict
from
paddlepalm.utils.reader_helper
import
create_net_inputs
,
create_iterator_fn
,
create_joint_iterator_fn
,
merge_input_attrs
from
paddlepalm.distribute
import
data_feeder
from
paddlepalm.distribute
import
data_feeder
,
decode_fake
from
default_settings
import
*
from
task_instance
import
TaskInstance
,
check_instances
...
...
@@ -228,6 +228,7 @@ class Controller(object):
exe
,
dev_count
=
_init_env
(
use_gpu
=
mtl_conf
.
get
(
'use_gpu'
,
True
))
self
.
exe
=
exe
self
.
dev_count
=
dev_count
self
.
batch_size
=
mtl_conf
.
get
(
'batch_size'
)
print_dict
(
mtl_conf
,
title
=
'global configuration'
)
...
...
@@ -350,7 +351,7 @@ class Controller(object):
dev_count
=
self
.
dev_count
num_instances
=
len
(
instances
)
mrs
=
self
.
mrs
branch
=
fluid
.
data
(
name
=
"branch"
,
shape
=
[
1
],
dtype
=
'int
32
'
)
branch
=
fluid
.
data
(
name
=
"branch"
,
shape
=
[
1
],
dtype
=
'int
64
'
)
# set first_target/main task instance
main_inst
=
None
...
...
@@ -536,9 +537,8 @@ class Controller(object):
# prepare for train
self
.
train_backbone
=
train_backbone
#
self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
self
.
train_program
=
fluid
.
CompiledProgram
(
fluid
.
default_main_program
()).
with_data_parallel
(
loss_name
=
loss
.
name
)
self
.
saver_program
=
fluid
.
default_main_program
()
self
.
train_program
=
self
.
saver_program
self
.
main_inst
=
main_inst
self
.
has_init_train
=
True
...
...
@@ -564,7 +564,7 @@ class Controller(object):
insert_taskid
=
False
,
insert_batchsize
=
False
,
insert_seqlen
=
False
,
insert_batchsize_x_seqlen
=
False
)
pred_prog
=
inst
.
load
(
infer_model_path
)
#
pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
pred_prog
=
fluid
.
CompiledProgram
(
pred_prog
).
with_data_parallel
()
if
inst
.
reader
[
'pred'
]
is
None
:
pred_reader
=
inst
.
Reader
(
inst
.
config
,
phase
=
'pred'
)
inst
.
reader
[
'pred'
]
=
pred_reader
...
...
@@ -628,8 +628,8 @@ class Controller(object):
while
not
train_finish
():
feed
,
mask
,
id
=
next
(
distribute_feeder
)
feed
[
0
].
update
({
'branch'
:
np
.
array
([
id
],
dtype
=
'int32
'
)})
for
i
in
range
(
self
.
dev_count
):
feed
[
i
].
update
({
'branch'
:
np
.
array
([
id
],
dtype
=
'int64
'
)})
fetch_list
.
append
(
self
.
_switched_loss
)
rt_outputs
=
self
.
exe
.
run
(
train_program
,
feed
=
feed
,
fetch_list
=
fetch_list
)
rt_loss
=
rt_outputs
.
pop
()
...
...
@@ -714,33 +714,23 @@ class Controller(object):
buf
=
[]
for
feed
,
mask
,
id
in
distribute_feeder
:
# print('before run')
rt_outputs
=
self
.
exe
.
run
(
pred_prog
,
feed
,
fetch_vars
)
# print('after run')
splited_rt_outputs
=
[]
for
item
in
rt_outputs
:
splited_rt_outputs
.
append
(
np
.
split
(
item
,
len
(
mask
)))
# assert len(rt_outputs) == len(mask), [len(rt_outputs), len(mask)]
# print(mask)
while
mask
.
pop
()
==
False
:
print
(
mask
)
for
item
in
splited_rt_outputs
:
nums_fake
=
decode_fake
(
len
(
rt_outputs
[
0
]),
mask
,
self
.
batch_size
)
while
nums_fake
:
for
item
in
rt_outputs
:
item
.
pop
()
rt_outputs
=
[]
# print('cancat')
for
item
in
splited_rt_outputs
:
rt_outputs
.
append
(
np
.
concatenate
(
item
))
nums_fake
=
nums_fake
-
1
rt_outputs
=
{
k
:
v
for
k
,
v
in
zip
(
fetch_names
,
rt_outputs
)}
inst
.
postprocess
(
rt_outputs
,
phase
=
'pred'
)
# print('leave feeder')
if
inst
.
task_layer
[
'pred'
].
epoch_inputs_attrs
:
reader_outputs
=
inst
.
reader
[
'pred'
].
get_epoch_outputs
()
else
:
reader_outputs
=
None
# print('epoch postprocess')
inst
.
epoch_postprocess
({
'reader'
:
reader_outputs
},
phase
=
'pred'
)
...
...
@@ -754,4 +744,4 @@ if __name__ == '__main__':
__all__
=
[
"Controller"
]
\ No newline at end of file
__all__
=
[
"Controller"
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录