PaddlePaddle / PALM
Commit a9c5a794
Authored January 14, 2020 by xixiaoyao
add multihead
Parent: aa7f8ed7
Showing 4 changed files with 222 additions and 29 deletions (+222 -29)
demo/demo2/data/cls4mrqa/dev.tsv (+62 -0)
demo/demo2/data/cls4mrqa/train.tsv (+62 -0)
demo/demo2/run.py (+97 -28)
demo/demo2/run.sh (+1 -1)
demo/demo2/data/cls4mrqa/dev.tsv
0 → 100644
This diff is collapsed.
demo/demo2/data/cls4mrqa/train.tsv
0 → 100644
This diff is collapsed.
demo/demo2/run.py
# coding=utf-8
import paddlepalm as palm
import json

if __name__ == '__main__':

    max_seqlen = 512
    batch_size = 32
    batch_size = 4
    num_epochs = 2
    lr = 1e-3
    vocab_path = './pretrain/ernie/vocab.txt'
    match_reader = palm.reader.match(train_file, vocab, \
                          max_seqlen, file_format='csv', tokenizer='wordpiece', \
                          lang='en', shuffle_train=True)
    mrc_reader = palm.reader.mrc(train_file, phase='train')
    mlm_reader = palm.reader.mlm(train_file, phase='train')
    palm.reader.
    train_file = './data/cls4mrqa/train.tsv'
    predict_file = './data/cls4mrqa/dev.tsv'
    match = palm.tasktype.cls(num_classes=4)
    mrc = palm.tasktype.match(learning_strategy='pairwise')
    mlm = palm.tasktype.mlm()
    mlm.print()
    config = json.load(open('./pretrain/ernie/ernie_config.json'))
    # ernie = palm.backbone.ERNIE(...)
    ernie = palm.backbone.ERNIE.from_config(config)
    bb_flags = palm.load_json('./pretrain/ernie/ernie_config.json')
    bb = palm.backbone.ernie(bb_flags['xx'], xxx)
    bb.print()
    # cls_reader2 = palm.reader.cls(train_file_topic, vocab_path, batch_size, max_seqlen)
    # cls_reader3 = palm.reader.cls(train_file_subj, vocab_path, batch_size, max_seqlen)
    # topic_trainer = palm.Trainer('topic_cls', cls_reader2, cls)
    # subj_trainer = palm.Trainer('subj_cls', cls_reader3, cls)
    match4mrqa = palm.Task('match4mrqa', match_reader, match_tt)
    mrc4mrqa = palm.Task('match4mrqa', match_reader, match_tt)
    # Create the reader for this classification task; its parameters control the dataset input format, number of files, preprocessing rules, and so on
    cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen)
    cls_reader2 = palm.reader.ClassifyReader(vocab_path, max_seqlen)
    predict_cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, phase='predict')
    print(cls_reader.outputs_attr)
    print(predict_cls_reader.outputs_attr)
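    # Different backbones require different features from a task reader. For a classification task the basic input features are token_ids and label_ids, but BERT additionally requires position, segment, input_mask and similar features extracted from the input. After register_with, the reader automatically adds the fields the backbone requires
    cls_reader.register_with(ernie)
    print(cls_reader.outputs_attr)
    print(predict_cls_reader.outputs_attr)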
    # match4mrqa.reuse_with(mrc4mrqa)
    print("preparing data...")
    print(cls_reader.num_examples)
    cls_reader.load_data(train_file, batch_size)
    cls_reader2.load_data(train_file, batch_size)
    print(cls_reader.num_examples)
    print('done!')
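    # Create the task heads, e.g. classification, matching, machine reading comprehension. Each task head has its own required/optional parameters. Note that task heads are decoupled from readers: a head is valid as long as the reader can supply the dataset-side fields it depends on
    cls_head = palm.head.Classify(4, 1024, 0.1)
    cls_head2 = palm.head.Classify(4, 1024, 0.1)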
    controller = palm.Controller([mrqa, match4mrqa, mlm4mrqa])
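    # Create a trainer from the reader and task head. A trainer represents one training task: it maintains the training progress and the task's key information internally, performs validity checks, and controls the rules for saving and loading the task's model
    trainer = palm.Trainer('cls')
    trainer2 = palm.Trainer('senti_cls')
    mh_trainer = palm.MultiHeadTrainer([trainer, trainer2])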
    loss = controller.build_forward(bb, mask_task=[])
    # match4mrqa.reuse_head_with(mrc4mrqa)
    n_steps = controller.estimate_train_steps(basetask=mrqa, num_epochs=2, batch_size=8, dev_count=4)
    adam = palm.optimizer.Adam(loss)
    sched = palm.schedualer.LinearWarmup(learning_rate, max_train_steps=n_steps, warmup_steps=0.1*n_steps)
    controller.build_backward(optimizer=adam, schedualer=sched, weight_decay=0.001, use_ema=True, ema_decay=0.999)
    # data_vars = cls_reader.build()
    # output_vars = ernie.build(data_vars)
    # cls_head.build({'backbone': output_vars, 'reader': data_vars})
    controller.random_init_params()
    controller.load_pretrain('../../pretrain_model/ernie/params')
    controller.train()
    loss_var = mh_trainer.build_forward(ernie, [cls_head, cls_head2])
    # controller.build_forward()
    # Error! a head/backbone can be only build once! Try NOT to call build_forward method for any Trainer!
    # n_steps = cls_reader.num_examples * num_epochs // batch_size
    # warmup_steps = int(0.1 * n_steps)
    # print(warmup_steps)
    # sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
    sched = None
    adam = palm.optimizer.Adam(loss_var, lr, sched)
    mh_trainer.build_backward(optimizer=adam, weight_decay=0.001)
    mh_trainer.random_init_params()
    mh_trainer.load_pretrain('pretrain/ernie/params')
    # trainer.train(iterator_fn, print_steps=1, save_steps=5, save_path='outputs', save_type='ckpt,predict')
    mh_trainer.fit_readers_with_mixratio([cls_reader, cls_reader2], 'cls', 2)
    mh_trainer.train(print_steps=1)
    # trainer.save()
    # print('prepare to predict...')
    # pred_ernie = palm.backbone.ERNIE.from_config(config, phase='pred')
    # cls_pred_head = palm.head.Classify(4, 1024, phase='pred')
    # trainer.build_predict_forward(pred_ernie, cls_pred_head)
    # predict_cls_reader.load_data(predict_file, 8)
    # print(predict_cls_reader.num_examples)
    # predict_cls_reader.register_with(pred_ernie)
    # trainer.fit_reader(predict_cls_reader, phase='predict')
    # print('predicting..')
    # trainer.predict(print_steps=20)
    # controller = palm.Controller([mrqa, match4mrqa, mlm4mrqa])
    # loss = controller.build_forward(bb, mask_task=[])
    # n_steps = controller.estimate_train_steps(basetask=mrqa, num_epochs=2, batch_size=8, dev_count=4)
    # adam = palm.optimizer.Adam(loss)
    # sched = palm.schedualer.LinearWarmup(learning_rate, max_train_steps=n_steps, warmup_steps=0.1*n_steps)
    #
    # controller.build_backward(optimizer=adam, schedualer=sched, weight_decay=0.001, use_ema=True, ema_decay=0.999)
    # controller.random_init_params()
    # controller.load_pretrain('../../pretrain_model/ernie/params')
    # controller.train()
...
...
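The comments in run.py say that a reader gains extra output fields once it is registered with a backbone, and that a task head is valid as long as the reader supplies every field the head depends on. The sketch below illustrates that contract with plain Python sets. It does not use paddlepalm: `ToyBackbone`, `ToyReader`, `head_is_compatible`, and the exact field names are hypothetical stand-ins chosen for illustration, not the library's implementation.

```python
# Hypothetical stand-ins for illustration only; NOT paddlepalm's real classes.


class ToyBackbone:
    """Declares which input fields it needs from a reader."""

    def __init__(self, required_fields):
        self.required_fields = set(required_fields)


class ToyReader:
    """Starts with the task's basic fields and grows when registered."""

    def __init__(self, base_fields):
        self.outputs = set(base_fields)

    def register_with(self, backbone):
        # Add whatever the backbone asks for on top of the task's own fields.
        self.outputs |= backbone.required_fields


def head_is_compatible(head_required_fields, reader):
    """A head is valid if the reader supplies every field it depends on."""
    return set(head_required_fields) <= reader.outputs


# Field names below are assumed for the example.
reader = ToyReader(["token_ids", "label_ids"])
backbone = ToyBackbone(["token_ids", "position_ids", "segment_ids", "input_mask"])

print(sorted(reader.outputs))          # fields before registration
reader.register_with(backbone)
print(sorted(reader.outputs))          # fields after registration
print(head_is_compatible(["label_ids"], reader))
```

Modelling the contract as set inclusion keeps heads and readers decoupled: neither side needs to know the other's type, only the field names exchanged.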
demo/demo2/run.sh
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=3
python run.py
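The commented-out lines in run.py estimate the number of training steps as `num_examples * num_epochs // batch_size` and take 10% of them as warmup before building a `TriangularSchedualer`. The sketch below reproduces that arithmetic in plain Python. The schedule shape (linear ramp to the base rate, then linear decay to zero) is an assumption about what "triangular" means here, the function names are hypothetical, and the example count of 1000 is made up for illustration.

```python
def estimate_train_steps(num_examples, num_epochs, batch_size):
    """Step count as written in the commented-out line of run.py."""
    return num_examples * num_epochs // batch_size


def triangular_lr(step, base_lr, warmup_steps, n_steps):
    """Assumed triangular schedule: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        # Ramp from 0 up to base_lr over the warmup phase.
        return base_lr * step / warmup_steps
    # Decay from base_lr down to 0 over the remaining steps.
    remaining = max(n_steps - step, 0)
    return base_lr * remaining / max(n_steps - warmup_steps, 1)


# num_examples is a made-up value; the other settings mirror run.py.
n_steps = estimate_train_steps(num_examples=1000, num_epochs=2, batch_size=4)
warmup_steps = int(0.1 * n_steps)

for step in (0, warmup_steps // 2, warmup_steps, n_steps):
    print(step, triangular_lr(step, 1e-3, warmup_steps, n_steps))
```

With `sched = None`, as in the committed code, the optimizer would presumably use the constant `lr` instead of a schedule like this one.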