Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
1acb2cec
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
4
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1acb2cec
编写于
11月 05, 2019
作者:
X
Xiaoyao Xi
提交者:
GitHub
11月 05, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #20 from xixiaoyao/master
fix bugs
上级
7398cfed
c6e33be8
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
71 addition
and
19 deletion
+71
-19
paddlepalm/interface.py
paddlepalm/interface.py
+6
-2
paddlepalm/mtl_controller.py
paddlepalm/mtl_controller.py
+19
-10
paddlepalm/reader/match.py
paddlepalm/reader/match.py
+0
-4
paddlepalm/task_paradigm/cls.py
paddlepalm/task_paradigm/cls.py
+22
-1
paddlepalm/task_paradigm/match.py
paddlepalm/task_paradigm/match.py
+24
-2
未找到文件。
paddlepalm/interface.py
浏览文件 @
1acb2cec
...
...
@@ -154,7 +154,11 @@ class task_paradigm(object):
raise
NotImplementedError
()
def
build
(
self
,
inputs
):
@
property
def
epoch_inputs_attrs
(
self
):
return
{}
def
build
(
self
,
inputs
,
scope_name
=
""
):
"""建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。
Args:
inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象
...
...
@@ -168,6 +172,6 @@ class task_paradigm(object):
"""每个训练或推理step后针对当前batch的task_layer的runtime计算结果进行相关后处理。注意,rt_outputs除了包含build方法,还自动包含了loss的计算结果。"""
pass
def
post_postprocess
(
self
,
global_buffer
):
def
epoch_postprocess
(
self
,
post_inputs
):
pass
paddlepalm/mtl_controller.py
浏览文件 @
1acb2cec
...
...
@@ -182,21 +182,17 @@ def _fit_attr(conf, fit_attr, strict=False):
class
Controller
(
object
):
def
__init__
(
self
,
config
=
None
,
task_dir
=
'.'
,
for_train
=
True
):
def
__init__
(
self
,
config
,
task_dir
=
'.'
,
for_train
=
True
):
"""
Args:
config: (str|dict) 字符串类型时,给出yaml格式的config配置文件路径;
"""
self
.
_for_train
=
for_train
# default mtl_conf
# if config is None and config_path is None:
# raise ValueError('For config and config_path, at least one of them should be set.')
assert
isinstance
(
config
,
str
)
or
isinstance
(
config
,
dict
),
"a config dict or config file path is required to create a Controller."
if
isinstance
(
config
,
str
):
mtl_conf
=
_parse_yaml
(
config
,
support_cmd_line
=
True
)
# if config is not None:
# mtl_conf = _merge_conf(config, mtl_conf)
else
:
mtl_conf
=
config
...
...
@@ -518,6 +514,11 @@ class Controller(object):
def
_init_pred
(
self
,
instance
,
infer_model_path
):
inst
=
instance
if
'pred_output_path'
not
in
inst
.
config
:
inst
.
config
[
'pred_output_path'
]
=
os
.
path
.
join
(
inst
.
config
.
get
(
'save_path'
,
'.'
),
inst
.
name
)
if
not
os
.
path
.
exists
(
inst
.
config
[
'pred_output_path'
]):
os
.
makedirs
(
inst
.
config
[
'pred_output_path'
])
pred_backbone
=
self
.
Backbone
(
self
.
bb_conf
,
phase
=
'pred'
)
pred_parad
=
inst
.
Paradigm
(
inst
.
config
,
phase
=
'pred'
,
backbone_config
=
self
.
bb_conf
)
...
...
@@ -563,7 +564,12 @@ class Controller(object):
finish
=
[]
for
inst
in
instances
:
if
inst
.
is_target
:
finish
.
append
(
False
)
if
inst
.
expected_train_steps
>
0
:
finish
.
append
(
False
)
else
:
finish
.
append
(
True
)
print
(
inst
.
name
+
': train finished!'
)
inst
.
save
()
def
train_finish
():
for
inst
in
instances
:
...
...
@@ -641,9 +647,11 @@ class Controller(object):
pred_prog
=
self
.
_init_pred
(
instance
,
inference_model_dir
)
inst
=
instance
print
(
inst
.
name
+
": loading data..."
)
inst
.
reader
[
'pred'
].
load_data
()
fetch_names
,
fetch_vars
=
inst
.
pred_fetch_list
print
(
'predicting...'
)
mapper
=
{
k
:
v
for
k
,
v
in
inst
.
pred_input
}
buf
=
[]
for
feed
in
inst
.
reader
[
'pred'
].
iterator
():
...
...
@@ -653,12 +661,13 @@ class Controller(object):
rt_outputs
=
self
.
exe
.
run
(
pred_prog
,
feed
,
fetch_vars
)
rt_outputs
=
{
k
:
v
for
k
,
v
in
zip
(
fetch_names
,
rt_outputs
)}
inst
.
postprocess
(
rt_outputs
,
phase
=
'pred'
)
reader_outputs
=
inst
.
reader
[
'pred'
].
get_epoch_outputs
()
if
inst
.
task_layer
[
'pred'
].
epoch_inputs_attrs
:
reader_outputs
=
inst
.
reader
[
'pred'
].
get_epoch_outputs
()
else
:
reader_outputs
=
None
inst
.
epoch_postprocess
({
'reader'
:
reader_outputs
},
phase
=
'pred'
)
if
__name__
==
'__main__'
:
assert
len
(
sys
.
argv
)
==
2
,
"Usage: python mtl_controller.py <mtl_conf_path>"
conf_path
=
sys
.
argv
[
1
]
...
...
paddlepalm/reader/match.py
浏览文件 @
1acb2cec
...
...
@@ -93,10 +93,6 @@ class Reader(reader):
for
batch
in
self
.
_data_generator
():
yield
list_to_dict
(
batch
)
def
get_epoch_outputs
(
self
):
return
{
'examples'
:
self
.
_reader
.
get_examples
(
self
.
_phase
),
'features'
:
self
.
_reader
.
get_features
(
self
.
_phase
)}
@
property
def
num_examples
(
self
):
return
self
.
_reader
.
get_num_examples
(
phase
=
self
.
_phase
)
...
...
paddlepalm/task_paradigm/cls.py
浏览文件 @
1acb2cec
...
...
@@ -14,8 +14,10 @@
# limitations under the License.
import
paddle.fluid
as
fluid
from
paddlepalm.interface
import
task_paradigm
from
paddle.fluid
import
layers
from
paddlepalm.interface
import
task_paradigm
import
numpy
as
np
import
os
class
TaskParadigm
(
task_paradigm
):
'''
...
...
@@ -35,6 +37,8 @@ class TaskParadigm(task_paradigm):
self
.
_dropout_prob
=
config
[
'dropout_prob'
]
else
:
self
.
_dropout_prob
=
backbone_config
.
get
(
'hidden_dropout_prob'
,
0.0
)
self
.
_pred_output_path
=
config
.
get
(
'pred_output_path'
,
None
)
self
.
_preds
=
[]
@
property
def
inputs_attrs
(
self
):
...
...
@@ -78,3 +82,20 @@ class TaskParadigm(task_paradigm):
else
:
return
{
"logits"
:
logits
}
def
postprocess
(
self
,
rt_outputs
):
if
not
self
.
_is_training
:
logits
=
rt_outputs
[
'logits'
]
preds
=
np
.
argmax
(
logits
,
-
1
)
self
.
_preds
.
extend
(
preds
.
tolist
())
def
epoch_postprocess
(
self
,
post_inputs
):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if
not
self
.
_is_training
:
if
self
.
_pred_output_path
is
None
:
raise
ValueError
(
'argument pred_output_path not found in config. Please add it into config dict/file.'
)
with
open
(
os
.
path
.
join
(
self
.
_pred_output_path
,
'predictions.json'
),
'w'
)
as
writer
:
for
p
in
self
.
_preds
:
writer
.
write
(
str
(
p
)
+
'
\n
'
)
print
(
'Predictions saved at '
+
os
.
path
.
join
(
self
.
_pred_output_path
,
'predictions.json'
))
paddlepalm/task_paradigm/match.py
浏览文件 @
1acb2cec
...
...
@@ -14,8 +14,10 @@
# limitations under the License.
import
paddle.fluid
as
fluid
from
paddlepalm.interface
import
task_paradigm
from
paddle.fluid
import
layers
from
paddlepalm.interface
import
task_paradigm
import
numpy
as
np
import
os
class
TaskParadigm
(
task_paradigm
):
'''
...
...
@@ -35,6 +37,9 @@ class TaskParadigm(task_paradigm):
else
:
self
.
_dropout_prob
=
backbone_config
.
get
(
'hidden_dropout_prob'
,
0.0
)
self
.
_pred_output_path
=
config
.
get
(
'pred_output_path'
,
None
)
self
.
_preds
=
[]
@
property
def
inputs_attrs
(
self
):
...
...
@@ -50,7 +55,7 @@ class TaskParadigm(task_paradigm):
if
self
.
_is_training
:
return
{
"loss"
:
[[
1
],
'float32'
]}
else
:
return
{
"logits"
:
[[
-
1
,
1
],
'float32'
]}
return
{
"logits"
:
[[
-
1
,
2
],
'float32'
]}
def
build
(
self
,
inputs
,
scope_name
=
""
):
if
self
.
_is_training
:
...
...
@@ -81,3 +86,20 @@ class TaskParadigm(task_paradigm):
else
:
return
{
'logits'
:
logits
}
def
postprocess
(
self
,
rt_outputs
):
if
not
self
.
_is_training
:
logits
=
rt_outputs
[
'logits'
]
preds
=
np
.
argmax
(
logits
,
-
1
)
self
.
_preds
.
extend
(
preds
.
tolist
())
def
epoch_postprocess
(
self
,
post_inputs
):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if
not
self
.
_is_training
:
if
self
.
_pred_output_path
is
None
:
raise
ValueError
(
'argument pred_output_path not found in config. Please add it into config dict/file.'
)
with
open
(
os
.
path
.
join
(
self
.
_pred_output_path
,
'predictions.json'
),
'w'
)
as
writer
:
for
p
in
self
.
_preds
:
writer
.
write
(
str
(
p
)
+
'
\n
'
)
print
(
'Predictions saved at '
+
os
.
path
.
join
(
self
.
_pred_output_path
,
'predictions.json'
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录