Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
1bb38efb
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
5
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1bb38efb
编写于
2月 04, 2020
作者:
W
wangxiao1021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add multi-task example
上级
4cc989d2
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
52 addition
and
59 deletion
+52
-59
examples/multi-task/process.py
examples/multi-task/process.py
+11
-20
examples/multi-task/run.py
examples/multi-task/run.py
+12
-20
paddlepalm/distribute/reader.py
paddlepalm/distribute/reader.py
+8
-4
paddlepalm/head/mlm.py
paddlepalm/head/mlm.py
+1
-1
paddlepalm/multihead_trainer.py
paddlepalm/multihead_trainer.py
+4
-4
paddlepalm/trainer.py
paddlepalm/trainer.py
+6
-7
paddlepalm/utils/reader_helper.py
paddlepalm/utils/reader_helper.py
+10
-3
未找到文件。
examples/multi-task/process.py
浏览文件 @
1bb38efb
...
...
@@ -4,33 +4,25 @@ import os
import
io
abs_path
=
os
.
path
.
abspath
(
__file__
)
dst_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
abs_path
),
"data/mlm/"
)
dst_dir2
=
os
.
path
.
join
(
os
.
path
.
dirname
(
abs_path
),
"data/match/"
)
dst_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
abs_path
),
"data/match/"
)
if
not
os
.
path
.
exists
(
dst_dir
)
or
not
os
.
path
.
isdir
(
dst_dir
):
os
.
makedirs
(
dst_dir
)
if
not
os
.
path
.
exists
(
dst_dir2
)
or
not
os
.
path
.
isdir
(
dst_dir2
):
os
.
makedirs
(
dst_dir2
)
os
.
mknod
(
"./data/mlm/train.tsv"
)
os
.
mknod
(
"./data/match/train.tsv"
)
with
io
.
open
(
"./data/mrc/train.json"
,
"r"
,
encoding
=
'utf-8'
)
as
f
ile
:
data
=
json
.
load
(
f
ile
)[
"data"
]
with
io
.
open
(
"./data/mrc/train.json"
,
"r"
,
encoding
=
'utf-8'
)
as
f
:
data
=
json
.
load
(
f
)[
"data"
]
i
=
0
with
open
(
"./data/mlm/train.tsv"
,
"w"
)
as
f
:
f
.
write
(
"text_a
\n
"
)
with
open
(
"./data/match/train.tsv"
,
"w"
)
as
f2
:
f2
.
write
(
"text_a
\t
text_b
\t
label
\n
"
)
for
dd
in
data
:
for
d
in
dd
[
"paragraphs"
]:
text_a_mlm
=
d
[
"context"
]
l
=
text_a_mlm
+
"
\n
"
f
.
write
(
l
.
encode
(
"utf-8"
))
with
open
(
"./data/match/train.tsv"
,
"w"
)
as
f2
:
f2
.
write
(
"text_a
\t
text_b
\t
label
\n
"
)
for
dd
in
data
:
for
d
in
dd
[
"paragraphs"
]:
context
=
d
[
"context"
]
for
qa
in
d
[
"qas"
]:
text_a
=
qa
[
"question"
]
answer
=
qa
[
"answers"
][
0
]
text_b
=
answer
[
"text"
]
start_pos
=
answer
[
"answer_start"
]
text_b_neg
=
text_a_mlm
[
0
:
start_pos
]
text_b_neg
=
context
[
0
:
start_pos
]
if
len
(
text_b_neg
)
>
512
:
text_b_neg
=
text_b_neg
[
-
512
:
-
1
]
l1
=
text_a
+
"
\t
"
+
text_b
+
"
\t
1
\n
"
...
...
@@ -40,6 +32,5 @@ with io.open("./data/mrc/train.json", "r", encoding='utf-8') as file:
f2
.
write
(
l2
.
encode
(
"utf-8"
))
i
+=
2
f2
.
close
()
f
.
close
()
file
.
close
()
f2
.
close
()
f
.
close
()
examples/multi-task/run.py
浏览文件 @
1bb38efb
...
...
@@ -7,7 +7,7 @@ from paddlepalm.distribute import gpu_dev_count
if
__name__
==
'__main__'
:
# configs
max_seqlen
=
512
max_seqlen
=
128
batch_size
=
8
num_epochs
=
8
lr
=
3e-5
...
...
@@ -15,7 +15,7 @@ if __name__ == '__main__':
max_query_len
=
64
max_ans_len
=
128
weight_decay
=
0.01
print_steps
=
20
print_steps
=
1
num_classes
=
2
random_seed
=
1
dropout_prob
=
0.1
...
...
@@ -33,43 +33,36 @@ if __name__ == '__main__':
pre_params
=
'./pretrain/ernie-zh-base/params'
config
=
json
.
load
(
open
(
'./pretrain/ernie-zh-base/ernie_config.json'
))
input_dim
=
config
[
'hidden_size'
]
vocab_size
=
config
[
'vocab_size'
]
hidden_act
=
config
[
'hidden_act'
]
# ----------------------- for training -----------------------
# step 1-1: create readers for training
mrc_reader
=
palm
.
reader
.
MRCReader
(
vocab_path
,
max_seqlen
,
max_query_len
,
doc_stride
,
do_lower_case
=
do_lower_case
)
match_reader
=
palm
.
reader
.
MatchReader
(
vocab_path
,
max_seqlen
,
seed
=
random_seed
)
# mlm_reader = palm.reader.MaskLMReader(vocab_path, max_seqlen, seed=random_seed)
# step 1-2: load the training data
mrc_reader
.
load_data
(
train_file
,
file_format
=
'json'
,
num_epochs
=
None
,
batch_size
=
batch_size
)
match_reader
.
load_data
(
train_file_match
,
file_format
=
'tsv'
,
num_epochs
=
None
,
batch_size
=
batch_size
)
# mlm_reader.load_data(train_file_mlm, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)
# step 2: create a backbone of the model to extract text features
ernie
=
palm
.
backbone
.
ERNIE
.
from_config
(
config
)
# step 3: register the backbone in readers
mrc_reader
.
register_with
(
ernie
)
match_reader
.
register_with
(
ernie
)
# mlm_reader.register_with(ernie)
# step 4: create task output heads
mrc_head
=
palm
.
head
.
MRC
(
max_query_len
,
config
[
'hidden_size'
],
do_lower_case
=
do_lower_case
,
max_ans_len
=
max_ans_len
)
match_head
=
palm
.
head
.
Match
(
num_classes
,
input_dim
,
dropout_prob
)
mlm_head
=
palm
.
head
.
MaskLM
(
input_dim
,
hidden_act
,
dropout_prob
)
# step 5-1: create a task trainer
trainer_mrc
=
palm
.
Trainer
(
task_name
,
mix_ratio
=
1.0
)
# trainer_mlm = palm.Trainer("mlm", mix_ratio=0.5)
trainer_match
=
palm
.
Trainer
(
"match"
,
mix_ratio
=
0.5
)
trainer
=
palm
.
MultiHeadTrainer
([
trainer_mrc
,
trainer_match
])
# step 5-2: build forward graph with backbone and task head
loss_var
=
trainer
.
build_forward
(
ernie
,
[
mrc_head
,
match_head
])
# step 6-1*: use warmup
n_steps
=
mrc_reader
.
num_examples
*
num_epochs
//
batch_size
n_steps
=
mrc_reader
.
num_examples
*
2
*
num_epochs
//
batch_size
warmup_steps
=
int
(
0.1
*
n_steps
)
sched
=
palm
.
lr_sched
.
TriangularSchedualer
(
warmup_steps
,
n_steps
)
# step 6-2: create a optimizer
...
...
@@ -79,12 +72,11 @@ if __name__ == '__main__':
# step 7: fit prepared reader and data
trainer
.
fit_readers_with_mixratio
([
mrc_reader
,
match_reader
],
task_name
,
num_epochs
)
# step 8-1*: load pretrained parameters
trainer
.
load_pretrain
(
pre_params
)
# step 8-2*: set saver to save model
# save_steps = n_steps-8
save_steps
=
1520
save_steps
=
n_steps
-
batch_size
trainer
.
set_saver
(
save_path
=
save_path
,
save_steps
=
save_steps
,
save_type
=
save_type
)
# step 8-3: start training
trainer
.
train
(
print_steps
=
print_steps
)
...
...
@@ -106,15 +98,15 @@ if __name__ == '__main__':
mrc_pred_head
=
palm
.
head
.
MRC
(
max_query_len
,
config
[
'hidden_size'
],
do_lower_case
=
do_lower_case
,
max_ans_len
=
max_ans_len
,
phase
=
'predict'
)
# step 5: build forward graph with backbone and task head
trainer
.
build_predict_forward
(
pred_ernie
,
mrc_pred_head
)
trainer
_mrc
.
build_predict_forward
(
pred_ernie
,
mrc_pred_head
)
# step 6: load pretrained model
pred_model_path
=
'./outputs/ckpt.step'
+
str
(
12160
)
pred_ckpt
=
trainer
.
load_ckpt
(
pred_model_path
)
pred_model_path
=
'./outputs/ckpt.step'
+
str
(
save_steps
)
pred_ckpt
=
trainer
_mrc
.
load_ckpt
(
pred_model_path
)
# step 7: fit prepared reader and data
trainer
.
fit_reader
(
predict_mrc_reader
,
phase
=
'predict'
)
trainer
_mrc
.
fit_reader
(
predict_mrc_reader
,
phase
=
'predict'
)
# step 8: predict
print
(
'predicting..'
)
trainer
.
predict
(
print_steps
=
print_steps
,
output_dir
=
"outputs/"
)
trainer
_mrc
.
predict
(
print_steps
=
print_steps
,
output_dir
=
"outputs/"
)
paddlepalm/distribute/reader.py
浏览文件 @
1bb38efb
...
...
@@ -57,9 +57,9 @@ def yield_pieces(data, distribute_strategy, batch_size):
yield
temp
def
data_feeder
(
reader
,
postprocess_fn
=
None
,
prefetch_steps
=
2
):
def
data_feeder
(
reader
,
postprocess_fn
=
None
,
prefetch_steps
=
2
,
phase
=
'train'
,
is_multi
=
False
):
if
postprocess_fn
is
None
:
def
postprocess_fn
(
batch
):
def
postprocess_fn
(
batch
,
id
=-
1
,
phase
=
'train'
,
is_multi
=
False
):
return
batch
def
worker
(
reader
,
dev_count
,
queue
):
...
...
@@ -90,6 +90,10 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
queue
.
task_done
()
if
ret
is
not
None
:
batches
,
num_pad
=
ret
if
dev_count
>
1
and
phase
==
'train'
and
is_multi
:
id
=
batches
[
0
][
'__task_id'
][
0
]
else
:
id
=
-
1
batch_buf
=
[]
flag_buf
=
[]
for
idx
,
batch
in
enumerate
(
batches
):
...
...
@@ -97,8 +101,8 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
flag
=
idx
-
len
(
batches
)
<
-
num_pad
# if num_pad > 0:
# num_pad -= 1
# batch = postprocess_fn(batch, id
)
batch
=
postprocess_fn
(
batch
)
batch
=
postprocess_fn
(
batch
,
id
,
phase
,
is_multi
=
is_multi
)
#
batch = postprocess_fn(batch)
batch_buf
.
append
(
batch
)
flag_buf
.
append
(
flag
)
yield
batch_buf
,
flag_buf
...
...
paddlepalm/head/mlm.py
浏览文件 @
1bb38efb
...
...
@@ -93,7 +93,7 @@ class MaskLM(Head):
param_attr
=
fluid
.
ParamAttr
(
name
=
scope_name
+
'mask_lm_trans_fc.w_0'
,
initializer
=
_param_initializer
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
scope_name
+
'mask_lm_trans_fc.b_0'
))
bias_attr
=
fluid
.
ParamAttr
(
name
=
scope_name
+
'mask_lm_trans_fc.b_0'
))
# transform: layer norm
mask_trans_feat
=
pre_process_layer
(
mask_trans_feat
,
'n'
,
name
=
scope_name
+
'mask_lm_trans'
)
...
...
paddlepalm/multihead_trainer.py
浏览文件 @
1bb38efb
...
...
@@ -201,9 +201,9 @@ class MultiHeadTrainer(Trainer):
feed_batch_process_fn
=
reader_helper
.
create_feed_batch_process_fn
(
net_inputs
)
if
gpu_dev_count
>
1
:
distribute_feeder_fn
=
data_feeder
(
iterator_fn
,
feed_batch_process_fn
)
distribute_feeder_fn
=
data_feeder
(
iterator_fn
,
feed_batch_process_fn
,
phase
=
phase
,
is_multi
=
True
)
else
:
distribute_feeder_fn
=
iterator_fn
distribute_feeder_fn
=
iterator_fn
()
if
phase
==
'train'
:
self
.
_train_reader
=
distribute_feeder_fn
...
...
@@ -277,8 +277,8 @@ class MultiHeadTrainer(Trainer):
def
train_one_step
(
self
,
batch
):
if
dev_count
>
1
:
assert
isinstance
(
batch
,
list
)
task_id
=
batch
[
0
][
'__task_id'
][
0
]
assert
isinstance
(
batch
,
tuple
)
task_id
=
batch
[
0
][
0
][
'__task_id'
][
0
]
else
:
assert
isinstance
(
batch
,
dict
)
task_id
=
batch
[
'__task_id'
][
0
]
...
...
paddlepalm/trainer.py
浏览文件 @
1bb38efb
...
...
@@ -415,7 +415,7 @@ class Trainer(object):
self
.
_raw_iterator_fn
=
iterator_fn
feed_batch_process_fn
=
reader_helper
.
create_feed_batch_process_fn
(
net_inputs
)
if
gpu_dev_count
>
1
:
distribute_feeder_fn
=
data_feeder
(
iterator_fn
,
feed_batch_process_fn
)
distribute_feeder_fn
=
data_feeder
(
iterator_fn
,
feed_batch_process_fn
,
phase
=
phase
)
else
:
distribute_feeder_fn
=
iterator_fn
()
...
...
@@ -718,9 +718,9 @@ class Trainer(object):
feed
,
mask
=
batch
rt_outputs
=
exe
.
run
(
distribute_train_prog
,
feed
=
feed
,
fetch_list
=
fetch_list
)
num_fakes
=
decode_fake
(
len
(
rt_outputs
[
0
]),
mask
,
self
.
_train_batch_size
)
for
_
in
range
(
num_fakes
)
:
for
item
in
rt_outputs
:
item
.
pop
()
if
num_fakes
:
rt_outputs
=
[
i
[:
-
num_fakes
]
for
i
in
rt_outputs
]
else
:
feed
=
self
.
_feed_batch_process_fn
(
batch
)
rt_outputs
=
exe
.
run
(
distribute_train_prog
,
feed
=
feed
,
fetch_list
=
fetch_list
)
...
...
@@ -735,9 +735,8 @@ class Trainer(object):
feed
,
mask
=
batch
rt_outputs
=
self
.
_exe
.
run
(
self
.
_distribute_pred_prog
,
feed
=
feed
,
fetch_list
=
self
.
_pred_fetch_list
)
num_fakes
=
decode_fake
(
len
(
rt_outputs
[
0
]),
mask
,
self
.
_predict_batch_size
)
for
_
in
range
(
num_fakes
):
for
item
in
rt_outputs
:
item
.
pop
()
if
num_fakes
:
rt_outputs
=
[
i
[:
-
num_fakes
]
for
i
in
rt_outputs
]
else
:
feed
=
self
.
_pred_feed_batch_process_fn
(
batch
)
rt_outputs
=
self
.
_exe
.
run
(
self
.
_distribute_pred_prog
,
feed
=
feed
,
fetch_list
=
self
.
_pred_fetch_list
)
...
...
paddlepalm/utils/reader_helper.py
浏览文件 @
1bb38efb
...
...
@@ -21,13 +21,20 @@ import numpy as np
import
paddle
from
paddle
import
fluid
from
paddle.fluid
import
layers
from
paddlepalm.distribute
import
gpu_dev_count
,
cpu_dev_count
dev_count
=
1
if
gpu_dev_count
<=
1
else
gpu_dev_count
def
create_feed_batch_process_fn
(
net_inputs
):
def
feed_batch_process_fn
(
data
):
def
feed_batch_process_fn
(
data
,
id
=-
1
,
phase
=
'train'
,
is_multi
=
False
):
temp
=
{}
for
q
,
var
in
net_inputs
.
items
():
if
dev_count
>
1
and
phase
==
'train'
and
is_multi
:
inputs
=
net_inputs
[
id
]
else
:
inputs
=
net_inputs
for
q
,
var
in
inputs
.
items
():
if
isinstance
(
var
,
str
)
or
isinstance
(
var
,
unicode
):
temp
[
var
]
=
data
[
q
]
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录