Unverified commit 7d7dadd1, authored by Meiyim, committed by GitHub

fix #470 (#471)

* Update README.md

* fix grad_clip

* fix distill

* up distill
Parent 241c0282
@@ -82,8 +82,8 @@ def train(model, train_dataset, dev_dataset, dev_examples, dev_features, tokeniz…
     model = D.parallel.DataParallel(model, ctx)
     max_steps = len(train_features) * args.epoch // args.bsz
-    opt = AdamW(learning_rate=args.lr, parameter_list=model.parameters(), weight_decay=args.wd)
     g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
+    opt = AdamW(learning_rate=args.lr, parameter_list=model.parameters(), weight_decay=args.wd, grad_clip=g_clip)
     train_dataset = train_dataset \
             .repeat() \
@@ -97,7 +97,7 @@ def train(model, train_dataset, dev_dataset, dev_examples, dev_features, tokeniz…
             scaled_loss = model.scale_loss(loss)
             scaled_loss.backward()
             model.apply_collective_grads()
-            opt.minimize(scaled_loss, grad_clip=g_clip)
+            opt.minimize(scaled_loss)
             model.clear_gradients()
             if D.parallel.Env().dev_id == 0 and step % 10 == 0:
                 log.debug('[step %d] train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
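The change repeated across every file in this commit is the same: gradient clipping moves from the `opt.minimize(...)` call into the `AdamW` constructor. Below is a minimal sketch of the new pattern, assuming a Paddle 1.x dygraph build recent enough to accept an optimizer-level `grad_clip` and the `AdamW`/`LinearDecay` wrappers from `ernie.optimization`; the tiny `Linear` model and the random batch are placeholders used only to make the snippet self-contained.

```python
import numpy as np
import paddle.fluid as F
import paddle.fluid.dygraph as D
from ernie.optimization import AdamW, LinearDecay

place = F.CUDAPlace(0) if F.is_compiled_with_cuda() else F.CPUPlace()
with D.guard(place):
    model = D.Linear(16, 2)  # placeholder stand-in for the real ERNIE / BoW model
    # clipping is configured once, on the optimizer (new style)
    g_clip = F.clip.GradientClipByGlobalNorm(1.0)
    opt = AdamW(
        learning_rate=LinearDecay(5e-5, 100, 1000),  # (lr, warmup steps, total steps), as in the hunks
        parameter_list=model.parameters(),
        weight_decay=0.01,
        grad_clip=g_clip)

    for step in range(10):
        x = D.to_variable(np.random.rand(4, 16).astype('float32'))  # placeholder batch
        loss = F.layers.reduce_mean(model(x))
        loss.backward()
        opt.minimize(loss)        # note: no grad_clip argument here any more
        model.clear_gradients()
```

Configuring clipping once on the optimizer keeps every `opt.minimize(...)` call site unchanged when the clipping policy changes.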
@@ -26,7 +26,6 @@ from functools import reduce, partial
 import numpy as np
 import multiprocessing
 import pickle
-import jieba
 import logging
 from sklearn.metrics import f1_score
@@ -95,12 +95,14 @@ if __name__ == '__main__':
     dev_ds.data_shapes = shapes
     dev_ds.data_types = types
+    g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
     opt = AdamW(learning_rate=LinearDecay(
             args.lr,
             int(args.warmup_proportion * args.max_steps), args.max_steps),
         parameter_list=model.parameters(),
-        weight_decay=args.wd)
-    g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
+        weight_decay=args.wd,
+        grad_clip=g_clip)
     for epoch in range(args.epoch):
         for step, d in enumerate(tqdm(train_ds.start(place), desc='training')):
             ids, sids, label = d
@@ -108,7 +110,7 @@ if __name__ == '__main__':
             loss.backward()
             if step % 10 == 0:
                 log.debug('train loss %.5f lr %.3e' % (loss.numpy(), opt.current_step_lr()))
-            opt.minimize(loss, grad_clip=g_clip)
+            opt.minimize(loss)
             model.clear_gradients()
             if step % 100 == 0:
                 acc = []
@@ -5,7 +5,6 @@
 * [Results](#效果验证)
 * [Case #1: the user provides unlabeled data](#case1)
 * [Case #2: the user does not provide unlabeled data](#case2)
-* [FAQ](#faq)
 # ERNIE Slim data distillation
 Behind ERNIE's strong semantic understanding stands equally strong compute: training and serving a model of this scale is expensive. Many industrial applications have demanding performance requirements and cannot use the model in practice unless it is compressed effectively.
@@ -37,7 +36,7 @@
 # Tutorial
-Using the three augmentation strategies above, we built an augmented version of the chnsenticorp data, 10x the size of the original training set (96,000 lines); it can be downloaded [here](https://ernie.bj.bcebos.com/distill_data.tar.gz). Then run the script below to start distillation.
+Using the three augmentation strategies above, we built an augmented version of the chnsenticorp data, 10x the size of the original training set (96,000 lines); it can be downloaded [here](https://ernie-github.cdn.bcebos.com/data-chnsenticorp-distill.tar.gz). Then run the script below to start distillation.
 ```shell
 python ./distill/distill.py
@@ -64,8 +63,3 @@
 |non-ERNIE baseline (LSTM)|91.2%|
 |**+ data distillation**|93.9%|
-# FAQ
-### FAQ 1: distilling while serving predictions fails with `Client call failed`
-The error printed in the terminal is the client-side log; the server-side log appears before it. This usually means the server exceeded GPU memory. In that case, pass `--server_batch_size` in the student finetuning script to explicitly control the batch size of requests sent to the server.
@@ -30,12 +30,13 @@ from ernie.optimization import AdamW, LinearDecay
 # This example uses the chnsenticorp Chinese sentiment classification task as a demo;
 # the unlabeled data needed for distillation was prepared in advance via data augmentation.
 #
-# Please download the data from "" and place it under ./chnsenticorp-data/
+# Download the data and place it under ./chnsenticorp-data/
 # The data has 3 columns: raw text; space-separated tokens; sentiment label.
 # The first column is the input to ERNIE; the second column is the input to the BoW model.
 # The precomputed BoW vocabulary is in ./chnsenticorp-data/vocab.bow.txt
 # hyperparameters for finetuning the teacher model
+DATA_DIR='./chnsenticorp-data/'
 SEQLEN=256
 BATCH=32
 EPOCH=10
@@ -43,7 +44,7 @@ LR=5e-5
 tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
-student_vocab = {i.strip(): l for l, i in enumerate(open('./chnsenticorp-data/vocab.bow.txt').readlines())}
+student_vocab = {i.strip(): l for l, i in enumerate(open(os.path.join(DATA_DIR, 'vocab.bow.txt')).readlines())}
 def space_tokenizer(i):
     return i.decode('utf8').split()
@@ -63,11 +64,17 @@ def map_fn(seg_a, seg_a_student, label):
     return seg_a_student, sentence, segments, label
-train_ds = feature_column.build_dataset('train', data_dir='./chnsenticorp-data/train/', shuffle=True, repeat=False, use_gz=False) .map(map_fn) .padded_batch(BATCH,)
+train_ds = feature_column.build_dataset('train', data_dir=os.path.join(DATA_DIR, 'train/'), shuffle=True, repeat=False, use_gz=False) \
+        .map(map_fn) \
+        .padded_batch(BATCH)
-train_ds_unlabel = feature_column.build_dataset('train-da', data_dir='./chnsenticorp-data/train-data-augmented/', shuffle=True, repeat=False, use_gz=False) .map(map_fn) .padded_batch(BATCH,)
+train_ds_unlabel = feature_column.build_dataset('train-da', data_dir=os.path.join(DATA_DIR, 'train-data-augmented/'), shuffle=True, repeat=False, use_gz=False) \
+        .map(map_fn) \
+        .padded_batch(BATCH)
-dev_ds = feature_column.build_dataset('dev', data_dir='./chnsenticorp-data/dev/', shuffle=False, repeat=False, use_gz=False) .map(map_fn) .padded_batch(BATCH,)
+dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(DATA_DIR, 'dev/'), shuffle=False, repeat=False, use_gz=False) \
+        .map(map_fn) \
+        .padded_batch(BATCH,)
 shapes = ([-1,SEQLEN],[-1,SEQLEN], [-1, SEQLEN], [-1])
 types = ('int64', 'int64', 'int64', 'int64')
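The three dataset definitions in the hunk above repeat the same `build_dataset(...).map(map_fn).padded_batch(...)` chain, only varying the name, subdirectory, and shuffle flag. A small hypothetical helper (the name `make_ds` is not in the repository; it relies on `feature_column`, `map_fn`, `DATA_DIR`, and `BATCH` as defined in the script above) shows how the `DATA_DIR` parametrization introduced by this commit could be factored out:

```python
import os

def make_ds(name, subdir, shuffle):
    # hypothetical helper capturing the repeated dataset-building chain above
    return feature_column.build_dataset(
            name,
            data_dir=os.path.join(DATA_DIR, subdir),
            shuffle=shuffle,
            repeat=False,
            use_gz=False) \
        .map(map_fn) \
        .padded_batch(BATCH)

train_ds = make_ds('train', 'train/', shuffle=True)
train_ds_unlabel = make_ds('train-da', 'train-data-augmented/', shuffle=True)
dev_ds = make_ds('dev', 'dev/', shuffle=False)
```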
@@ -99,15 +106,15 @@ def evaluate_teacher(model, dataset):
 teacher_model = ErnieModelForSequenceClassification.from_pretrained('ernie-1.0', num_labels=2)
 teacher_model.train()
 if not os.path.exists('./teacher_model.pdparams'):
-    opt = AdamW(learning_rate=LinearDecay(LR, 9600*EPOCH*0.1/BATCH, 9600*EPOCH/BATCH), parameter_list=teacher_model.parameters(), weight_decay=0.01)
     g_clip = F.clip.GradientClipByGlobalNorm(1.0)
+    opt = AdamW(learning_rate=LinearDecay(LR, 9600*EPOCH*0.1/BATCH, 9600*EPOCH/BATCH), parameter_list=teacher_model.parameters(), weight_decay=0.01, grad_clip=g_clip)
     for epoch in range(EPOCH):
         for step, (ids_student, ids, sids, labels) in enumerate(train_ds.start(place)):
             loss, logits = teacher_model(ids, labels=labels)
             loss.backward()
             if step % 10 == 0:
                 print('[step %03d] teacher train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
-            opt.minimize(loss, grad_clip=g_clip)
+            opt.minimize(loss)
             teacher_model.clear_gradients()
             if step % 100 == 0:
                 f1 = evaluate_teacher(teacher_model, dev_ds)
@@ -199,32 +206,34 @@ def KL(pred, target):
 teacher_model.eval()
 model = BOW()
-opt = AdamW(learning_rate=LR, parameter_list=model.parameters(), weight_decay=0.01)
 g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
+opt = AdamW(learning_rate=LR, parameter_list=model.parameters(), weight_decay=0.01, grad_clip=g_clip)
 model.train()
 for epoch in range(EPOCH):
-    for step, (ids_student, ids, sids, _ ) in enumerate(train_ds.start(place)):
+    for step, (ids_student, ids, sids, label) in enumerate(train_ds.start(place)):
         _, logits_t = teacher_model(ids, sids)  # logits output by the teacher model
         logits_t.stop_gradient=True
         _, logits_s = model(ids_student)  # logits output by the student model
-        loss = KL(logits_s, logits_t)  # KL divergence measures the distance between the two distributions
+        loss_ce, _ = model(ids_student, labels=label)
+        loss_kd = KL(logits_s, logits_t)  # KL divergence measures the distance between the two distributions
+        loss = loss_ce + loss_kd
         loss.backward()
         if step % 10 == 0:
-            print('[step %03d] 无监督 train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
-        opt.minimize(loss, grad_clip=g_clip)
+            print('[step %03d] distill train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
+        opt.minimize(loss)
         model.clear_gradients()
 f1 = evaluate_student(model, dev_ds)
-print('f1 %.5f' % f1)
+print('student f1 %.5f' % f1)
+# finish with one more round of hard-label training to consolidate the result
 for step, (ids_student, ids, sids, label) in enumerate(train_ds.start(place)):
     loss, _ = model(ids_student, labels=label)
     loss.backward()
     if step % 10 == 0:
-        print('[step %03d] 监督 train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
-    opt.minimize(loss, grad_clip=g_clip)
+        print('[step %03d] train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
+    opt.minimize(loss)
     model.clear_gradients()
 f1 = evaluate_student(model, dev_ds)
-print('f1 %.5f' % f1)
+print('final f1 %.5f' % f1)
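The distillation fix in the last hunk combines a hard-label cross-entropy term (`loss_ce`) with a soft-label term `KL(logits_s, logits_t)`. The body of `KL(pred, target)` is outside this diff; the sketch below is one common way to write such a soft-label loss with `paddle.fluid.layers`, given only as an illustration of the idea rather than the repository's exact implementation (the temperature and epsilon parameters are assumptions):

```python
import paddle.fluid.layers as L

def kl_soft_label_loss(logits_s, logits_t, temperature=1.0, eps=1e-8):
    # soften both logit vectors into probability distributions
    p_t = L.softmax(logits_t / temperature)   # teacher: soft targets
    p_s = L.softmax(logits_s / temperature)   # student: predictions
    # KL(teacher || student), summed over classes and averaged over the batch
    kl = L.reduce_sum(p_t * (L.log(p_t + eps) - L.log(p_s + eps)), dim=-1)
    return L.reduce_mean(kl)

# used in place of KL in the loop above:
#   loss = loss_ce + kl_soft_label_loss(logits_s, logits_t)
```

Because `logits_t.stop_gradient` is set to True in the loop, only the student receives gradients from this term.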