Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
391a9bbb
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
8
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
391a9bbb
编写于
12月 04, 2019
作者:
X
xixiaoyao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix standard mtl
上级
ca5921da
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
73 addition
and
5 deletion
+73
-5
paddlepalm/mtl_controller.py
paddlepalm/mtl_controller.py
+63
-3
paddlepalm/utils/reader_helper.py
paddlepalm/utils/reader_helper.py
+10
-2
未找到文件。
paddlepalm/mtl_controller.py
浏览文件 @
391a9bbb
...
...
@@ -35,6 +35,9 @@ from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn
from
paddlepalm.default_settings
import
*
from
task_instance
import
TaskInstance
,
check_instances
import
Queue
from
threading
import
Thread
DEBUG
=
False
VERBOSE
=
0
...
...
@@ -399,11 +402,14 @@ class Controller(object):
prefixes
.
append
(
inst
.
name
)
mrs
.
append
(
inst
.
mix_ratio
)
joint_iterator_fn
=
create_joint_iterator_fn
(
iterators
,
prefixes
,
joint_shape_and_dtypes
,
mrs
,
name_to_position
,
dev_count
=
dev_count
,
verbose
=
VERBOSE
)
joint_iterator_fn
=
create_joint_iterator_fn
(
iterators
,
prefixes
,
joint_shape_and_dtypes
,
mrs
,
name_to_position
,
dev_count
=
dev_count
,
verbose
=
VERBOSE
,
return_type
=
'dict'
)
self
.
_joint_iterator_fn
=
joint_iterator_fn
input_attrs
=
[[
i
,
j
,
k
]
for
i
,
(
j
,
k
)
in
zip
(
joint_input_names
,
joint_shape_and_dtypes
)]
pred_input_attrs
=
[[
i
,
j
,
k
]
for
i
,
(
j
,
k
)
in
zip
(
pred_joint_input_names
,
pred_joint_shape_and_dtypes
)]
net_inputs
=
create_net_inputs
(
input_attrs
,
async
=
True
,
iterator_fn
=
joint_iterator_fn
,
dev_count
=
dev_count
,
n_prefetch
=
3
)
# net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3)
net_inputs
=
create_net_inputs
(
input_attrs
,
async
=
False
)
self
.
_net_inputs
=
net_inputs
# build backbone and task layers
train_prog
=
fluid
.
default_main_program
()
...
...
@@ -568,6 +574,18 @@ class Controller(object):
return
False
return
True
def
pack_multicard_feed
(
iterator
,
net_inputs
,
dev_count
):
ret
=
[]
mask
=
[]
for
i
in
range
(
dev_count
):
temp
=
{}
content
,
flag
=
next
(
iterator
)
for
q
,
var
in
net_inputs
.
items
():
temp
[
var
.
name
]
=
content
[
q
]
ret
.
append
(
temp
)
mask
.
append
(
1
if
flag
else
0
)
return
ret
,
mask
# do training
fetch_names
,
fetch_list
=
zip
(
*
fetches
.
items
())
...
...
@@ -576,8 +594,50 @@ class Controller(object):
epoch
=
0
time_begin
=
time
.
time
()
backbone_buffer
=
[]
def
multi_dev_reader
(
reader
,
dev_count
):
def
worker
(
reader
,
dev_count
,
queue
):
dev_batches
=
[]
for
index
,
data
in
enumerate
(
reader
()):
if
len
(
dev_batches
)
<
dev_count
:
dev_batches
.
append
(
data
)
if
len
(
dev_batches
)
==
dev_count
:
queue
.
put
((
dev_batches
,
0
))
dev_batches
=
[]
# For the prediction of the remained batches, pad more batches to
# the number of devices and the padded samples would be removed in
# prediction outputs.
if
len
(
dev_batches
)
>
0
:
num_pad
=
dev_count
-
len
(
dev_batches
)
for
i
in
range
(
len
(
dev_batches
),
dev_count
):
dev_batches
.
append
(
dev_batches
[
-
1
])
queue
.
put
((
dev_batches
,
num_pad
))
queue
.
put
(
None
)
queue
=
Queue
.
Queue
(
dev_count
*
2
)
p
=
Thread
(
target
=
worker
,
args
=
(
reader
,
dev_count
,
queue
))
p
.
daemon
=
True
p
.
start
()
while
True
:
ret
=
queue
.
get
()
if
ret
is
not
None
:
batches
,
num_pad
=
ret
queue
.
task_done
()
for
batch
in
batches
:
flag
=
num_pad
==
0
if
num_pad
>
0
:
num_pad
-=
1
yield
batch
,
flag
else
:
break
queue
.
join
()
joint_iterator
=
multi_dev_reader
(
self
.
_joint_iterator_fn
,
self
.
dev_count
)
while
not
train_finish
():
rt_outputs
=
self
.
exe
.
run
(
train_program
,
fetch_list
=
fetch_list
)
feed
,
mask
=
pack_multicard_feed
(
joint_iterator
,
self
.
_net_inputs
,
self
.
dev_count
)
rt_outputs
=
self
.
exe
.
run
(
train_program
,
feed
=
feed
,
fetch_list
=
fetch_list
)
rt_outputs
=
{
k
:
v
for
k
,
v
in
zip
(
fetch_names
,
rt_outputs
)}
rt_task_id
=
np
.
squeeze
(
rt_outputs
[
'__task_id'
]).
tolist
()
rt_task_id
=
rt_task_id
[
0
]
if
isinstance
(
rt_task_id
,
list
)
else
rt_task_id
...
...
paddlepalm/utils/reader_helper.py
浏览文件 @
391a9bbb
...
...
@@ -105,11 +105,13 @@ def create_iterator_fn(iterator, iterator_prefix, shape_and_dtypes, outname_to_p
return
iterator
def
create_joint_iterator_fn
(
iterators
,
iterator_prefixes
,
joint_shape_and_dtypes
,
mrs
,
outname_to_pos
,
dev_count
=
1
,
keep_one_task
=
True
,
verbose
=
0
):
def
create_joint_iterator_fn
(
iterators
,
iterator_prefixes
,
joint_shape_and_dtypes
,
mrs
,
outname_to_pos
,
dev_count
=
1
,
keep_one_task
=
True
,
verbose
=
0
,
return_type
=
'list'
):
"""
joint_shape_and_dtypes: 本质上是根据bb和parad的attr设定的,并且由reader中的attr自动填充-1(可变)维度得到,因此通过与iterator的校验可以完成runtime的batch正确性检查
"""
pos_to_outname
=
{
j
:
i
for
i
,
j
in
outname_to_pos
.
items
()}
task_ids
=
range
(
len
(
iterators
))
weights
=
[
mr
/
float
(
sum
(
mrs
))
for
mr
in
mrs
]
if
not
keep_one_task
:
...
...
@@ -202,7 +204,13 @@ def create_joint_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtype
print
(
np
.
shape
(
i
))
print
(
''
)
v
-=
1
yield
results
if
return_type
==
'list'
:
yield
results
elif
return_type
==
'dict'
:
temp
=
{}
for
pos
,
i
in
enumerate
(
results
):
temp
[
pos_to_outname
[
pos
]]
=
i
yield
temp
return
iterator
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录