提交 1e4405c8 编写于 作者: X xuezhong

release ELMo

上级 b77d0354
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
import os
cell_clip = 3.0
proj_clip = 3.0
hidden_size = 4096
vocab_size = 52445
emb_size = 512
def dropout(input):
dropout_rate = 0.5
return layers.dropout(
input,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
def lstmp_encoder(input_seq, gate_size, h_0, c_0, para_name, proj_size, args):
# A lstm encoder implementation with projection.
# Linear transformation part for input gate, output gate, forget gate
# and cell activation vectors need be done outside of dynamic_lstm.
# So the output size is 4 times of gate_size.
init = None
init_b = None
input_proj = layers.fc(input=input_seq,
param_attr=fluid.ParamAttr(
name=para_name + '_gate_w', initializer=init),
size=gate_size * 4,
act=None,
bias_attr=False)
hidden, cell = layers.dynamic_lstmp(
input=input_proj,
size=gate_size * 4,
proj_size=proj_size,
h_0=h_0,
c_0=c_0,
use_peepholes=False,
proj_clip=proj_clip,
cell_clip=cell_clip,
proj_activation="identity",
param_attr=fluid.ParamAttr(initializer=init),
bias_attr=fluid.ParamAttr(initializer=init_b))
return hidden, cell, input_proj
def encoder_wrapper(x_emb,
vocab_size,
emb_size,
init_hidden=None,
init_cell=None,
para_name='',
args=None):
rnn_input = x_emb
rnn_outs = []
rnn_outs_ori = []
cells = []
projs = []
num_layers = 2
for i in range(num_layers):
if init_hidden and init_cell:
h0 = layers.squeeze(
layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
c0 = layers.squeeze(
layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
else:
h0 = c0 = None
rnn_out, cell, input_proj = lstmp_encoder(
rnn_input, hidden_size, h0, c0, para_name + 'layer{}'.format(i + 1),
emb_size, args)
rnn_out_ori = rnn_out
if i > 0:
rnn_out = rnn_out + rnn_input
rnn_out.stop_gradient = True
rnn_outs.append(rnn_out)
rnn_outs_ori.append(rnn_out_ori)
return rnn_outs, rnn_outs_ori
def weight_layers(lm_embeddings, name="", l2_coef=0.0):
'''
Weight the layers of a biLM with trainable scalar weights to
compute ELMo representations.
Input:
lm_embeddings(list): representations of 2 layers from biLM.
name = a string prefix used for the trainable variable names
l2_coef: the l2 regularization coefficient $\lambda$.
Pass None or 0.0 for no regularization.
Output:
weighted_lm_layers: weighted embeddings form biLM
'''
n_lm_layers = len(lm_embeddings)
W = layers.create_parameter(
[n_lm_layers, ],
dtype="float32",
name=name + "ELMo_w",
attr=fluid.ParamAttr(
name=name + "ELMo_w",
initializer=fluid.initializer.Constant(0.0),
regularizer=fluid.regularizer.L2Decay(l2_coef)))
normed_weights = layers.softmax(W + 1.0 / n_lm_layers)
splited_normed_weights = layers.split(normed_weights, n_lm_layers, dim=0)
# compute the weighted, normalized LM activations
pieces = []
for w, t in zip(splited_normed_weights, lm_embeddings):
pieces.append(t * w)
sum_pieces = layers.sums(pieces)
# scale the weighted sum by gamma
gamma = layers.create_parameter(
[1],
dtype="float32",
name=name + "ELMo_gamma",
attr=fluid.ParamAttr(
name=name + "ELMo_gamma",
initializer=fluid.initializer.Constant(1.0)))
weighted_lm_layers = sum_pieces * gamma
return weighted_lm_layers
def elmo_encoder(word_ids, elmo_l2_coef):
x_emb = layers.embedding(
input=word_ids,
size=[vocab_size, emb_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(name='embedding_para'))
x_emb_r = fluid.layers.sequence_reverse(x_emb, name=None)
fw_hiddens, fw_hiddens_ori = encoder_wrapper(
x_emb, vocab_size, emb_size, para_name='fw_', args=None)
bw_hiddens, bw_hiddens_ori = encoder_wrapper(
x_emb_r, vocab_size, emb_size, para_name='bw_', args=None)
num_layers = len(fw_hiddens_ori)
token_embeddings = layers.concat(input=[x_emb, x_emb], axis=1)
token_embeddings.stop_gradient = True
concate_embeddings = [token_embeddings]
for index in range(num_layers):
embedding = layers.concat(
input=[fw_hiddens_ori[index], bw_hiddens_ori[index]], axis=1)
embedding = dropout(embedding)
embedding.stop_gradient = True
concate_embeddings.append(embedding)
weighted_emb = weight_layers(concate_embeddings, l2_coef=elmo_l2_coef)
return weighted_emb
def init_pretraining_params(exe, pretraining_params_path, main_program):
assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path
def if_exist(var):
path = os.path.join(pretraining_params_path, var.name)
exist = os.path.exists(path)
if exist:
print('Load model: %s' % path)
return exist
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=if_exist)
print("Load pretraining parameters from {}.".format(
pretraining_params_path))
 
、 ,
。 .
— -
~ ~
‖ |
… .
‘ '
’ '
“ "
” "
〔 (
〕 )
〈 <
〉 >
「 '
」 '
『 "
』 "
〖 [
〗 ]
【 [
】 ]
∶ :
$ $
! !
" "
# #
% %
& &
' '
( (
) )
* *
+ +
, ,
- -
. .
/ /
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
: :
; ;
< <
= =
> >
? ?
@ @
A a
B b
C c
D d
E e
F f
G g
H h
I i
J j
K k
L l
M m
N n
O o
P p
Q q
R r
S s
T t
U u
V v
W w
X x
Y y
Z z
[ [
\ \
] ]
^ ^
_ _
` `
a a
b b
c c
d d
e e
f f
g g
h h
i i
j j
k k
l l
m m
n n
o o
p p
q q
r r
s s
t t
u u
v v
w w
x x
y y
z z
{ {
| |
} }
 ̄ ~
〝 "
〞 "
﹐ ,
﹑ ,
﹒ .
﹔ ;
﹕ :
﹖ ?
﹗ !
﹙ (
﹚ )
﹛ {
﹜ {
﹝ [
﹞ ]
﹟ #
﹠ &
﹡ *
﹢ +
﹣ -
﹤ <
﹥ >
﹦ =
﹨ \
﹩ $
﹪ %
﹫ @
,
A a
B b
C c
D d
E e
F f
G g
H h
I i
J j
K k
L l
M m
N n
O o
P p
Q q
R r
S s
T t
U u
V v
W w
X x
Y y
Z z
此差异已折叠。
a
a-B
a-I
ad
ad-B
ad-I
an
an-B
an-I
c
d
d-B
d-I
f
f-B
f-I
m
m-B
m-I
n
n-B
n-I
nr
nr-B
nr-I
ns
ns-B
ns-I
nt
nt-B
nt-I
nw
nw-B
nw-I
nz
nz-B
nz-I
p
q
r
r-B
r-I
s
s-B
s-I
t
t-B
t-I
u
v
v-B
v-I
vd
vd-B
vd-I
vn
vn-B
vn-I
w
xc
\ No newline at end of file
此差异已折叠。
"""
The function lex_net(args) define the lexical analysis network structure
"""
import sys
import os
import math
import paddle.fluid as fluid
from paddle.fluid.initializer import NormalInitializer
import paddle.fluid.layers as layers
from bilm import elmo_encoder
import ipdb
def lex_net(args, word_dict_len, label_dict_len):
"""
define the lexical analysis network structure
"""
word_emb_dim = args.word_emb_dim
grnn_hidden_dim = args.grnn_hidden_dim
emb_lr = args.emb_learning_rate
crf_lr = args.crf_learning_rate
bigru_num = args.bigru_num
init_bound = 0.1
IS_SPARSE = True
def _bigru_layer(input_feature):
"""
define the bidirectional gru layer
"""
pre_gru = fluid.layers.fc(
input=input_feature,
size=grnn_hidden_dim * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
gru = fluid.layers.dynamic_gru(
input=pre_gru,
size=grnn_hidden_dim,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
pre_gru_r = fluid.layers.fc(
input=input_feature,
size=grnn_hidden_dim * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
gru_r = fluid.layers.dynamic_gru(
input=pre_gru_r,
size=grnn_hidden_dim,
is_reverse=True,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
bi_merge = fluid.layers.concat(input=[gru, gru_r], axis=1)
return bi_merge
def _net_conf(word_ids, target):
"""
Configure the network
"""
word_embedding = fluid.layers.embedding(
input=word,
size=[word_dict_len, word_emb_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr(
learning_rate=emb_lr,
name="word_emb",
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound)))
# add elmo embedding
elmo_emb = elmo_encoder(word_ids, args.elmo_l2_coef)
input_feature = layers.concat(input=[elmo_emb, word_embedding], axis=1)
for i in range(bigru_num):
bigru_output = _bigru_layer(input_feature)
input_feature = bigru_output
emission = fluid.layers.fc(
size=label_dict_len,
input=bigru_output,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-init_bound, high=init_bound),
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-4)))
crf_cost = fluid.layers.linear_chain_crf(
input=emission,
label=target,
param_attr=fluid.ParamAttr(
name='crfw', learning_rate=crf_lr))
crf_decode = fluid.layers.crf_decoding(
input=emission, param_attr=fluid.ParamAttr(name='crfw'))
avg_cost = fluid.layers.mean(x=crf_cost)
return avg_cost, crf_decode
word = fluid.layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
target = fluid.layers.data(
name="target", shape=[1], dtype='int64', lod_level=1)
avg_cost, crf_decode = _net_conf(word, target)
return avg_cost, crf_decode, word, target
#coding: utf-8
"""
The file_reader converts raw corpus to input.
"""
import os
import __future__
import io
def file_reader(file_dir,
word2id_dict,
label2id_dict,
word_replace_dict,
filename_feature=""):
"""
define the reader to read files in file_dir
"""
word_dict_len = max(map(int, word2id_dict.values())) + 1
label_dict_len = max(map(int, label2id_dict.values())) + 1
def reader():
"""
the data generator
"""
index = 0
for root, dirs, files in os.walk(file_dir):
for filename in files:
for line in io.open(
os.path.join(root, filename), 'r', encoding='utf8'):
index += 1
bad_line = False
line = line.strip("\n")
if len(line) == 0:
continue
seg_tag = line.rfind("\t")
word_part = line[0:seg_tag].strip().split(' ')
label_part = line[seg_tag + 1:]
word_idx = []
words = word_part
for word in words:
if word in word_replace_dict:
word = word_replace_dict[word]
if word in word2id_dict:
word_idx.append(int(word2id_dict[word]))
else:
word_idx.append(int(word2id_dict["<UNK>"]))
target_idx = []
labels = label_part.strip().split(" ")
for label in labels:
if label in label2id_dict:
target_idx.append(int(label2id_dict[label]))
else:
target_idx.append(int(label2id_dict["O"]))
if len(word_idx) != len(target_idx):
print(line)
continue
yield word_idx, target_idx
return reader
def test_reader(file_dir,
word2id_dict,
label2id_dict,
word_replace_dict,
filename_feature=""):
"""
define the reader to read test files in file_dir
"""
word_dict_len = max(map(int, word2id_dict.values())) + 1
label_dict_len = max(map(int, label2id_dict.values())) + 1
def reader():
"""
the data generator
"""
index = 0
for root, dirs, files in os.walk(file_dir):
for filename in files:
if not filename.startswith(filename_feature):
continue
for line in io.open(
os.path.join(root, filename), 'r', encoding='utf8'):
index += 1
bad_line = False
line = line.strip("\n")
if len(line) == 0:
continue
seg_tag = line.rfind("\t")
if seg_tag == -1:
seg_tag = len(line)
word_part = line[0:seg_tag]
label_part = line[seg_tag + 1:]
word_idx = []
words = word_part
for word in words:
if ord(word) < 0x20:
word = ' '
if word in word_replace_dict:
word = word_replace_dict[word]
if word in word2id_dict:
word_idx.append(int(word2id_dict[word]))
else:
word_idx.append(int(word2id_dict["OOV"]))
yield word_idx, words
return reader
def load_reverse_dict(dict_path):
"""
Load a dict. The first column is the key and the second column is the value.
"""
result_dict = {}
# TODO 字和词模型
for idx, line in enumerate(io.open(dict_path, "r", encoding='utf8')):
terms = line.strip("\n")
result_dict[terms] = idx
return result_dict
def load_dict(dict_path):
"""
Load a dict. The first column is the value and the second column is the key.
"""
result_dict = {}
for idx, line in enumerate(io.open(dict_path, "r", encoding='utf8')):
terms = line.strip("\n")
result_dict[idx] = terms
return result_dict
export FLAGS_fraction_of_gpu_memory_to_use=0.5
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fast_eager_deletion_mode=1
export CUDA_VISIBLE_DEVICES=0
python train.py \
--traindata_dir data/train \
--model_save_dir model \
--use_gpu 1 \
--corpus_type_list train \
--corpus_proportion_list 1 \
--num_iterations 200000 \
--pretrain_elmo_model_path ${ELMo_MODEL_PATH} \
--testdata_dir data/dev $@ \
"""
This file is used to train the model.
"""
import os
import sys
import math
import time
import random
import argparse
import numpy as np
import paddle
import paddle.fluid as fluid
import reader
from network import lex_net
from bilm import init_pretraining_params
def parse_args():
"""
Parsing the input parameters.
"""
parser = argparse.ArgumentParser("Training for lexical analyzer.")
parser.add_argument(
"--traindata_dir",
type=str,
default="data/train_data",
help="The folder where the training data is located.")
parser.add_argument(
"--testdata_dir",
type=str,
default="data/test_data",
help="The folder where the training data is located.")
parser.add_argument(
"--model_save_dir",
type=str,
default="./models",
help="The model will be saved in this path.")
parser.add_argument(
"--save_model_per_batchs",
type=int,
default=1000,
help="Save the model once per xxxx batch of training")
parser.add_argument(
"--eval_window",
type=int,
default=20,
help="Training will be suspended when the evaluation indicators on the validation set" \
" no longer increase. The eval_window specifies the scope of the evaluation.")
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="The number of sequences contained in a mini-batch, or the maximum" \
"number of tokens (include paddings) contained in a mini-batch.")
parser.add_argument(
"--corpus_type_list",
type=str,
default=["human", "feed", "query", "title", "news"],
nargs='+',
help="The pattern list of different types of corpus used in training.")
parser.add_argument(
"--corpus_proportion_list",
type=float,
default=[0.2, 0.2, 0.2, 0.2, 0.2],
nargs='+',
help="The proportion list of different types of corpus used in training."
)
parser.add_argument(
"--use_gpu",
type=int,
default=False,
help="Whether or not to use GPU. 0-->CPU 1-->GPU")
parser.add_argument(
"--traindata_shuffle_buffer",
type=int,
default=200000,
help="The buffer size used in shuffle the training data.")
parser.add_argument(
"--word_emb_dim",
type=int,
default=128,
help="The dimension in which a word is embedded.")
parser.add_argument(
"--grnn_hidden_dim",
type=int,
default=256,
help="The number of hidden nodes in the GRNN layer.")
parser.add_argument(
"--bigru_num",
type=int,
default=2,
help="The number of bi_gru layers in the network.")
parser.add_argument(
"--base_learning_rate",
type=float,
default=1e-3,
help="The basic learning rate that affects the entire network.")
parser.add_argument(
"--emb_learning_rate",
type=float,
default=5,
help="The real learning rate of the embedding layer will be" \
" (emb_learning_rate * base_learning_rate)."
)
parser.add_argument(
"--crf_learning_rate",
type=float,
default=0.2,
help="The real learning rate of the embedding layer will be" \
" (crf_learning_rate * base_learning_rate)."
)
parser.add_argument(
"--word_dict_path",
type=str,
default="../data/vocabulary_min5k.txt",
help="The path of the word dictionary.")
parser.add_argument(
"--label_dict_path",
type=str,
default="data/tag.dic",
help="The path of the label dictionary.")
parser.add_argument(
"--word_rep_dict_path",
type=str,
default="conf/q2b.dic",
help="The path of the word replacement Dictionary.")
parser.add_argument(
"--num_iterations",
type=int,
default=40000,
help="The maximum number of iterations. If set to 0 (default), do not limit the number."
)
#add elmo args
parser.add_argument(
"--elmo_l2_coef",
type=float,
default=0.001,
help="Weight decay. (default: %(default)f)")
parser.add_argument(
"--elmo_dict_dir",
default='data/vocabulary_min5k.txt',
help="If set, load elmo dict.")
parser.add_argument(
'--pretrain_elmo_model_path',
default="data/baike_elmo_checkpoint",
help="If set, load elmo checkpoint.")
args = parser.parse_args()
if len(args.corpus_proportion_list) != len(args.corpus_type_list):
sys.stderr.write(
"The length of corpus_proportion_list should be equal to the length of corpus_type_list.\n"
)
exit(-1)
return args
def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def to_lodtensor(data, place):
"""
Convert data in list into lodtensor.
"""
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def test(exe, chunk_evaluator, save_dirname, test_data, place):
"""
Test the network in training.
"""
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
chunk_evaluator.reset()
for data in test_data():
word = to_lodtensor(list(map(lambda x: x[0], data)), place)
target = to_lodtensor(list(map(lambda x: x[1], data)), place)
result_list = exe.run(inference_program,
feed={"word": word,
"target": target},
fetch_list=fetch_targets)
number_infer = np.array(result_list[0])
number_label = np.array(result_list[1])
number_correct = np.array(result_list[2])
chunk_evaluator.update(
int(number_infer[0]),
int(number_label[0]), int(number_correct[0]))
return chunk_evaluator.eval()
def train(args):
"""
Train the network.
"""
if not os.path.exists(args.model_save_dir):
os.mkdir(args.model_save_dir)
word2id_dict = reader.load_reverse_dict(args.word_dict_path)
label2id_dict = reader.load_reverse_dict(args.label_dict_path)
word_rep_dict = reader.load_dict(args.word_rep_dict_path)
word_dict_len = max(map(int, word2id_dict.values())) + 1
label_dict_len = max(map(int, label2id_dict.values())) + 1
avg_cost, crf_decode, word, target = lex_net(args, word_dict_len,
label_dict_len)
adam_optimizer = fluid.optimizer.Adam(learning_rate=args.base_learning_rate)
adam_optimizer.minimize(avg_cost)
(precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = fluid.layers.chunk_eval(
input=crf_decode,
label=target,
chunk_scheme="IOB",
num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
chunk_evaluator = fluid.metrics.ChunkEvaluator()
chunk_evaluator.reset()
train_reader_list = []
corpus_num = len(args.corpus_type_list)
for i in range(corpus_num):
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.file_reader(args.traindata_dir, word2id_dict,
label2id_dict, word_rep_dict,
args.corpus_type_list[i]),
buf_size=args.traindata_shuffle_buffer),
batch_size=int(args.batch_size * args.corpus_proportion_list[i]))
train_reader_list.append(train_reader)
test_reader = paddle.batch(
reader.file_reader(args.testdata_dir, word2id_dict, label2id_dict,
word_rep_dict),
batch_size=args.batch_size)
train_reader_itr_list = []
for train_reader in train_reader_list:
cur_reader_itr = train_reader()
train_reader_itr_list.append(cur_reader_itr)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=[word, target], place=place)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# load pretrained ELMo model
init_pretraining_params(exe, args.pretrain_elmo_model_path,
fluid.default_main_program())
batch_id = 0
start_time = time.time()
eval_list = []
iter = 0
while True:
full_batch = []
cur_batch = []
for i in range(corpus_num):
reader_itr = train_reader_itr_list[i]
try:
cur_batch = next(reader_itr)
except StopIteration:
print(args.corpus_type_list[i] +
" corpus finish a pass of training")
new_reader = train_reader_list[i]
train_reader_itr_list[i] = new_reader()
cur_batch = next(train_reader_itr_list[i])
full_batch += cur_batch
random.shuffle(full_batch)
cost_var, nums_infer, nums_label, nums_correct = exe.run(
fluid.default_main_program(),
fetch_list=[
avg_cost, num_infer_chunks, num_label_chunks, num_correct_chunks
],
feed=feeder.feed(full_batch))
print("batch_id:" + str(batch_id) + ", avg_cost:" + str(cost_var[0]))
chunk_evaluator.update(nums_infer, nums_label, nums_correct)
batch_id += 1
if (batch_id % args.save_model_per_batchs == 1):
save_exe = fluid.Executor(place)
save_dirname = os.path.join(args.model_save_dir,
"params_batch_%d" % batch_id)
temp_save_model = os.path.join(args.model_save_dir,
"temp_model_for_test")
fluid.io.save_inference_model(
temp_save_model, ['word', 'target'],
[num_infer_chunks, num_label_chunks, num_correct_chunks],
save_exe)
precision, recall, f1_score = chunk_evaluator.eval()
print("[train] batch_id:" + str(batch_id) + ", precision:" + str(
precision) + ", recall:" + str(recall) + ", f1:" + str(
f1_score))
chunk_evaluator.reset()
p, r, f1 = test(exe, chunk_evaluator, temp_save_model, test_reader,
place)
chunk_evaluator.reset()
print("[test] batch_id:" + str(batch_id) + ", precision:" + str(p) +
", recall:" + str(r) + ", f1:" + str(f1))
end_time = time.time()
print("cur_batch_id:" + str(batch_id) + ", last " + str(
args.save_model_per_batchs) + " batchs, time_cost:" + str(
end_time - start_time))
start_time = time.time()
if len(eval_list) < 2 * args.eval_window:
eval_list.append(f1)
else:
eval_list.pop(0)
eval_list.append(f1)
last_avg_f1 = sum(eval_list[
0:args.eval_window]) / args.eval_window
cur_avg_f1 = sum(eval_list[args.eval_window:2 *
args.eval_window]) / args.eval_window
if cur_avg_f1 <= last_avg_f1:
return
else:
print("keep training!")
iter += 1
if (iter == args.num_iterations):
return
if __name__ == "__main__":
args = parse_args()
print_arguments(args)
train(args)
# ELMo
### 介绍
ELMo(Embeddings from Language Models) 是重要的通用语义表示模型之一,以双向 LSTM 为网路基本组件,以 Language Model 为训练目标,通过预训练得到通用的语义表示,将通用的语义表示作为 Feature 迁移到下游 NLP 任务中,会显著提升下游任务的模型性能。本项目是 ELMo 在 Paddle Fluid 上的开源实现, 基于百科类数据训练并发布了预训练模型。
### 发布要点:
1) 基于百科类数据训练的 [ELMo 中文预训练模型](https://dureader.gz.bcebos.com/elmo/baike_elmo_checkpoint.tar.gz);
2) 完整支持 ELMo 模型训练及表示迁移, 包括:
- 支持 ELMo 多卡训练,训练速度比主流实现快约1倍
- 以 LAC 任务为示例提供 ELMo 语义表示迁移到下游 NLP 任务的示例
3)我们在阅读理解任务和 LAC 任务上评估了 ELMo 预训练模型带给下游任务的性能提升:
- LAC 加入 ELMo 后 F1 可以提升 **1.1%**
- 阅读理解任务加入 ELMo 后 Rouge-L 提升 **1%**
| Task | 评估指标 | Baseline | +ELMo |
| :------| :------: | :------: |:------: |
| [LAC](https://github.com/baidu/lac) | F1 | 87.3% | **88.4%** |
| [阅读理解](github.com/PaddlePaddle/models/tree/develop/PaddleNLP/machine_reading_comprehension) | Rouge-L | 39.4% | **40.4%** |
**Note**:
- LAC 任务是基于 20w 训练数据训练的词模型
- 阅读理解任务是基于 [DuReader](https://github.com/baidu/DuReader) 数据集训练的词模型
### 安装
本项目依赖于 Paddle Fluid **1.4.0**,请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装。
### 预训练
#### 数据预处理
将文档按照句号、问号、感叹切分成句子,然后对句子进行切词。预处理后的数据文件中每行为一个分词后的句子。我们给出了示例训练数据 [`data/train`](data/train) 和测试数据 [`data/dev`](data/dev),数据示例如下:
```
本 书 介绍 了 中国 经济 发展 的 内外 平衡 问题 、 亚洲 金融 危机 十 周年 回顾 与 反思 、 实践 中 的 城乡 统筹 发展 、 未来 十 年 中国 需要 研究 的 重大 课题 、 科学 发展 与 新型 工业 化 等 方面 。
吴 敬 琏 曾经 提出 中国 股市 “ 赌场 论 ” , 主张 维护 市场 规则 , 保护 草根 阶层 生计 , 被 誉 为 “ 中国 经济 学界 良心 ” , 是 媒体 和 公众 眼中 的 学术 明星
```
#### 模型训练
利用提供的示例训练数据和测试数据,我们来说明如何进行单机多卡预训练。关于预训练的启动方式,可以查看脚本 `run.sh` ,该脚本已经默认以示例数据作为输入。在开始预训练之前,需要把 CUDA、cuDNN、NCCL2 等动态库路径加入到环境变量 `LD_LIBRARY_PATH` 之中,然后按如下方式即可开始单机多卡预训练
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
sh run.sh
```
训练过程中,默认每间隔 10000 steps 将模型参数写入到 checkpoints 路径下,可以通过 `--save_interval ${N}` 自定义保存模型的间隔 steps。
### ELMo 预训练模型如何迁移到下游 NLP 任务
我们在 [bilm.py](./LAC_demo/bilm.py) 中提供了 `elmo_encoder` 接口获取 ELMo 预训练模型的语义表示, 便于用户将 ELMo 语义表示快速迁移到下游任务;以 [LAC](https://github.com/baidu/lac) 任务为示例, 将 ELMo 预训练模型的语义表示迁移到 LAC 任务的主要步骤如下:
1) 搭建 LAC 网络结构,并加载 ELMo 预训练模型参数; 我们在 [bilm.py](./LAC_demo/bilm.py) 中提供了加载预训练模型的接口函数 init_pretraining_params
```
#step1: create_lac_model()
#step2: load pretrained ELMo model
from bilm import init_pretraining_params
init_pretraining_params(exe, args.pretrain_elmo_model_path,
fluid.default_main_program())
```
2) 基于 [ELMo 字典](data/vocabulary_min5k.txt) 将输入数据转化为 word_ids,利用 elmo_encoder 接口获取 ELMo embedding
```
from bilm import elmo_encoder
elmo_embedding = elmo_encoder(word_ids)
```
3) ELMo embedding 与 LAC 原有 word_embedding 拼接得到最终的 embedding
```
word_embedding=fluid.layers.concat(input=[elmo_embedding, word_embedding], axis=1)
```
### 参考论文
[Deep contextualized word representations](https://arxiv.org/abs/1802.05365)
### Contributors
本项目由百度深度学习技术平台部 PaddlePaddle 团队([@xuezhong](https://github.com/xuezhong) [@JesseyXujin](https://github.com/JesseyXujin))和百度自然语言处理部语义计算团队([@nbcc](https://github.com/nbcc) [@tianxin1860](https://github.com/tianxin1860))合作完成。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--load_dir",
type=str,
default="",
help="Specify the path to load trained models.")
parser.add_argument(
"--load_pretraining_params",
type=str,
default="",
help="Specify the path to load pretrained model parameters, NOT including moment and learning_rate"
)
parser.add_argument(
"--batch_size",
type=int,
default=128,
help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
"--embed_size",
type=int,
default=512,
help="The dimension of embedding table. (default: %(default)d)")
parser.add_argument(
"--hidden_size",
type=int,
default=4096,
help="The size of rnn hidden unit. (default: %(default)d)")
parser.add_argument(
"--num_layers",
type=int,
default=2,
help="The size of rnn layers. (default: %(default)d)")
parser.add_argument(
"--num_steps",
type=int,
default=20,
help="The size of sequence len. (default: %(default)d)")
parser.add_argument(
"--all_train_tokens",
type=int,
default=35479,
help="The size of all training tokens")
parser.add_argument(
"--data_path", type=str, help="all the data for train,valid,test")
parser.add_argument("--vocab_path", type=str, help="vocab file path")
parser.add_argument(
'--use_gpu', type=bool, default=False, help='whether using gpu')
parser.add_argument('--enable_ce', action='store_true')
parser.add_argument('--test_nccl', action='store_true')
parser.add_argument('--optim', default='adagrad', help='optimizer type')
parser.add_argument('--sample_softmax', action='store_true')
parser.add_argument(
"--learning_rate",
type=float,
default=0.2,
help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
"--log_interval",
type=int,
default=100,
help="log the train loss every n batches."
"(default: %(default)d)")
parser.add_argument(
"--save_interval",
type=int,
default=10000,
help="log the train loss every n batches."
"(default: %(default)d)")
parser.add_argument(
"--dev_interval",
type=int,
default=10000,
help="cal dev loss every n batches."
"(default: %(default)d)")
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--max_grad_norm', type=float, default=10.0)
parser.add_argument('--proj_clip', type=float, default=3.0)
parser.add_argument('--cell_clip', type=float, default=3.0)
parser.add_argument('--max_epoch', type=float, default=10)
parser.add_argument('--local', type=bool, default=False)
parser.add_argument('--shuffle', type=bool, default=False)
parser.add_argument('--use_custom_samples', type=bool, default=False)
parser.add_argument('--para_save_dir', type=str, default='checkpoints')
parser.add_argument('--train_path', type=str, default='')
parser.add_argument('--test_path', type=str, default='')
parser.add_argument('--update_method', type=str, default='nccl2')
parser.add_argument('--random_seed', type=int, default=0)
parser.add_argument('--n_negative_samples_batch', type=int, default=8000)
args = parser.parse_args()
return args
# originally based on https://github.com/tensorflow/models/tree/master/lm_1b
import glob
import random
import numpy as np
import io
import six
class Vocabulary(object):
'''
A token vocabulary. Holds a map from token to ids and provides
a method for encoding text to a sequence of ids.
'''
def __init__(self, filename, validate_file=False):
'''
filename = the vocabulary file. It is a flat text file with one
(normalized) token per line. In addition, the file should also
contain the special tokens <S>, </S>, <UNK> (case sensitive).
'''
self._id_to_word = []
self._word_to_id = {}
self._unk = -1
self._bos = -1
self._eos = -1
with io.open(filename, 'r', encoding='utf-8') as f:
idx = 0
for line in f:
word_name = line.strip()
if word_name == '<S>':
self._bos = idx
elif word_name == '</S>':
self._eos = idx
elif word_name == '<UNK>':
self._unk = idx
if word_name == '!!!MAXTERMID':
continue
self._id_to_word.append(word_name)
self._word_to_id[word_name] = idx
idx += 1
# check to ensure file has special tokens
if validate_file:
if self._bos == -1 or self._eos == -1 or self._unk == -1:
raise ValueError("Ensure the vocabulary file has "
"<S>, </S>, <UNK> tokens")
@property
def bos(self):
return self._bos
@property
def eos(self):
return self._eos
@property
def unk(self):
return self._unk
@property
def size(self):
return len(self._id_to_word)
def word_to_id(self, word):
if word in self._word_to_id:
return self._word_to_id[word]
return self.unk
def id_to_word(self, cur_id):
return self._id_to_word[cur_id]
def decode(self, cur_ids):
"""Convert a list of ids to a sentence, with space inserted."""
return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])
def encode(self, sentence, reverse=False, split=True):
"""Convert a sentence to a list of ids, with special tokens added.
Sentence is a single string with tokens separated by whitespace.
If reverse, then the sentence is assumed to be reversed, and
this method will swap the BOS/EOS tokens appropriately."""
if split:
word_ids = [
self.word_to_id(cur_word) for cur_word in sentence.split()
]
else:
word_ids = [self.word_to_id(cur_word) for cur_word in sentence]
if reverse:
return np.array([self.eos] + word_ids + [self.bos], dtype=np.int32)
else:
return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
class UnicodeCharsVocabulary(Vocabulary):
"""Vocabulary containing character-level and word level information.
Has a word vocabulary that is used to lookup word ids and
a character id that is used to map words to arrays of character ids.
The character ids are defined by ord(c) for c in word.encode('utf-8')
This limits the total number of possible char ids to 256.
To this we add 5 additional special ids: begin sentence, end sentence,
begin word, end word and padding.
WARNING: for prediction, we add +1 to the output ids from this
class to create a special padding id (=0). As a result, we suggest
you use the `Batcher`, `TokenBatcher`, and `LMDataset` classes instead
of this lower level class. If you are using this lower level class,
then be sure to add the +1 appropriately, otherwise embeddings computed
from the pre-trained model will be useless.
"""
def __init__(self, filename, max_word_length, **kwargs):
super(UnicodeCharsVocabulary, self).__init__(filename, **kwargs)
self._max_word_length = max_word_length
# char ids 0-255 come from utf-8 encoding bytes
# assign 256-300 to special chars
self.bos_char = 256 # <begin sentence>
self.eos_char = 257 # <end sentence>
self.bow_char = 258 # <begin word>
self.eow_char = 259 # <end word>
self.pad_char = 260 # <padding>
num_words = len(self._id_to_word)
self._word_char_ids = np.zeros(
[num_words, max_word_length], dtype=np.int32)
# the charcter representation of the begin/end of sentence characters
def _make_bos_eos(c):
r = np.zeros([self.max_word_length], dtype=np.int32)
r[:] = self.pad_char
r[0] = self.bow_char
r[1] = c
r[2] = self.eow_char
return r
self.bos_chars = _make_bos_eos(self.bos_char)
self.eos_chars = _make_bos_eos(self.eos_char)
for i, word in enumerate(self._id_to_word):
self._word_char_ids[i] = self._convert_word_to_char_ids(word)
self._word_char_ids[self.bos] = self.bos_chars
self._word_char_ids[self.eos] = self.eos_chars
@property
def word_char_ids(self):
return self._word_char_ids
@property
def max_word_length(self):
return self._max_word_length
def _convert_word_to_char_ids(self, word):
code = np.zeros([self.max_word_length], dtype=np.int32)
code[:] = self.pad_char
word_encoded = word.encode('utf-8',
'ignore')[:(self.max_word_length - 2)]
code[0] = self.bow_char
for k, chr_id in enumerate(word_encoded, start=1):
code[k] = ord(chr_id)
code[k + 1] = self.eow_char
return code
def word_to_char_ids(self, word):
if word in self._word_to_id:
return self._word_char_ids[self._word_to_id[word]]
else:
return self._convert_word_to_char_ids(word)
def encode_chars(self, sentence, reverse=False, split=True):
'''
Encode the sentence as a white space delimited string of tokens.
'''
if split:
chars_ids = [
self.word_to_char_ids(cur_word) for cur_word in sentence.split()
]
else:
chars_ids = [
self.word_to_char_ids(cur_word) for cur_word in sentence
]
if reverse:
return np.vstack([self.eos_chars] + chars_ids + [self.bos_chars])
else:
return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])
class Batcher(object):
'''
Batch sentences of tokenized text into character id matrices.
'''
# def __init__(self, lm_vocab_file: str, max_token_length: int):
def __init__(self, lm_vocab_file, max_token_length):
'''
lm_vocab_file = the language model vocabulary file (one line per
token)
max_token_length = the maximum number of characters in each token
'''
max_token_length = int(max_token_length)
self._lm_vocab = UnicodeCharsVocabulary(lm_vocab_file, max_token_length)
self._max_token_length = max_token_length
# def batch_sentences(self, sentences: List[List[str]]):
def batch_sentences(self, sentences):
'''
Batch the sentences as character ids
Each sentence is a list of tokens without <s> or </s>, e.g.
[['The', 'first', 'sentence', '.'], ['Second', '.']]
'''
n_sentences = len(sentences)
max_length = max(len(sentence) for sentence in sentences) + 2
X_char_ids = np.zeros(
(n_sentences, max_length, self._max_token_length), dtype=np.int64)
for k, sent in enumerate(sentences):
length = len(sent) + 2
char_ids_without_mask = self._lm_vocab.encode_chars(
sent, split=False)
# add one so that 0 is the mask value
X_char_ids[k, :length, :] = char_ids_without_mask + 1
return X_char_ids
class TokenBatcher(object):
'''
Batch sentences of tokenized text into token id matrices.
'''
def __init__(self, lm_vocab_file):
# def __init__(self, lm_vocab_file: str):
'''
lm_vocab_file = the language model vocabulary file (one line per
token)
'''
self._lm_vocab = Vocabulary(lm_vocab_file)
# def batch_sentences(self, sentences: List[List[str]]):
def batch_sentences(self, sentences):
'''
Batch the sentences as character ids
Each sentence is a list of tokens without <s> or </s>, e.g.
[['The', 'first', 'sentence', '.'], ['Second', '.']]
'''
n_sentences = len(sentences)
max_length = max(len(sentence) for sentence in sentences) + 2
X_ids = np.zeros((n_sentences, max_length), dtype=np.int64)
for k, sent in enumerate(sentences):
length = len(sent) + 2
ids_without_mask = self._lm_vocab.encode(sent, split=False)
# add one so that 0 is the mask value
X_ids[k, :length] = ids_without_mask + 1
return X_ids
##### for training
def _get_batch(generator, batch_size, num_steps, max_word_length):
"""Read batches of input."""
cur_stream = [None] * batch_size
no_more_data = False
while True:
inputs = np.zeros([batch_size, num_steps], np.int32)
if max_word_length is not None:
char_inputs = np.zeros([batch_size, num_steps, max_word_length],
np.int32)
else:
char_inputs = None
targets = np.zeros([batch_size, num_steps], np.int32)
for i in range(batch_size):
cur_pos = 0
while cur_pos < num_steps:
if cur_stream[i] is None or len(cur_stream[i][0]) <= 1:
try:
cur_stream[i] = list(next(generator))
except StopIteration:
# No more data, exhaust current streams and quit
no_more_data = True
break
how_many = min(len(cur_stream[i][0]) - 1, num_steps - cur_pos)
next_pos = cur_pos + how_many
inputs[i, cur_pos:next_pos] = cur_stream[i][0][:how_many]
if max_word_length is not None:
char_inputs[i, cur_pos:next_pos] = cur_stream[i][
1][:how_many]
targets[i, cur_pos:next_pos] = cur_stream[i][0][1:how_many + 1]
cur_pos = next_pos
cur_stream[i][0] = cur_stream[i][0][how_many:]
if max_word_length is not None:
cur_stream[i][1] = cur_stream[i][1][how_many:]
if no_more_data:
# There is no more data. Note: this will not return data
# for the incomplete batch
break
X = {
'token_ids': inputs,
'tokens_characters': char_inputs,
'next_token_id': targets
}
yield X
class LMDataset(object):
"""
Hold a language model dataset.
A dataset is a list of tokenized files. Each file contains one sentence
per line. Each sentence is pre-tokenized and white space joined.
"""
def __init__(self,
filepattern,
vocab,
reverse=False,
test=False,
shuffle_on_load=False):
'''
filepattern = a glob string that specifies the list of files.
vocab = an instance of Vocabulary or UnicodeCharsVocabulary
reverse = if True, then iterate over tokens in each sentence in reverse
test = if True, then iterate through all data once then stop.
Otherwise, iterate forever.
shuffle_on_load = if True, then shuffle the sentences after loading.
'''
self._vocab = vocab
self._all_shards = glob.glob(filepattern)
print('Found %d shards at %s' % (len(self._all_shards), filepattern))
if test:
self._all_shards = list(np.random.choice(self._all_shards, size=4))
print('sampled %d shards at %s' %
(len(self._all_shards), filepattern))
self._shards_to_choose = []
self._reverse = reverse
self._test = test
self._shuffle_on_load = shuffle_on_load
self._use_char_inputs = hasattr(vocab, 'encode_chars')
self._ids = self._load_random_shard()
def _choose_random_shard(self):
if len(self._shards_to_choose) == 0:
self._shards_to_choose = list(self._all_shards)
random.shuffle(self._shards_to_choose)
shard_name = self._shards_to_choose.pop()
return shard_name
def _load_random_shard(self):
"""Randomly select a file and read it."""
if self._test:
if len(self._all_shards) == 0:
# we've loaded all the data
# this will propogate up to the generator in get_batch
# and stop iterating
raise StopIteration
else:
shard_name = self._all_shards.pop()
else:
# just pick a random shard
shard_name = self._choose_random_shard()
ids = self._load_shard(shard_name)
self._i = 0
self._nids = len(ids)
return ids
def _load_shard(self, shard_name):
"""Read one file and convert to ids.
Args:
shard_name: file path.
Returns:
list of (id, char_id) tuples.
"""
print('Loading data from: %s' % shard_name)
with io.open(shard_name, 'r', encoding='utf-8') as f:
sentences_raw = f.readlines()
if self._reverse:
sentences = []
for sentence in sentences_raw:
splitted = sentence.split()
splitted.reverse()
sentences.append(' '.join(splitted))
else:
sentences = sentences_raw
if self._shuffle_on_load:
print('shuffle sentences')
random.shuffle(sentences)
ids = [
self.vocab.encode(sentence, self._reverse) for sentence in sentences
]
if self._use_char_inputs:
chars_ids = [
self.vocab.encode_chars(sentence, self._reverse)
for sentence in sentences
]
else:
chars_ids = [None] * len(ids)
print('Loaded %d sentences.' % len(ids))
print('Finished loading')
return list(zip(ids, chars_ids))
def get_sentence(self):
while True:
if self._i == self._nids:
self._ids = self._load_random_shard()
ret = self._ids[self._i]
self._i += 1
yield ret
@property
def max_word_length(self):
if self._use_char_inputs:
return self._vocab.max_word_length
else:
return None
def iter_batches(self, batch_size, num_steps):
for X in _get_batch(self.get_sentence(), batch_size, num_steps,
self.max_word_length):
# token_ids = (batch_size, num_steps)
# char_inputs = (batch_size, num_steps, 50) of character ids
# targets = word ID of next word (batch_size, num_steps)
yield X
@property
def vocab(self):
return self._vocab
class BidirectionalLMDataset(object):
def __init__(self, filepattern, vocab, test=False, shuffle_on_load=False):
'''
bidirectional version of LMDataset
'''
self._data_forward = LMDataset(
filepattern,
vocab,
reverse=False,
test=test,
shuffle_on_load=shuffle_on_load)
self._data_reverse = LMDataset(
filepattern,
vocab,
reverse=True,
test=test,
shuffle_on_load=shuffle_on_load)
def iter_batches(self, batch_size, num_steps):
max_word_length = self._data_forward.max_word_length
for X, Xr in six.moves.zip(
_get_batch(self._data_forward.get_sentence(), batch_size,
num_steps, max_word_length),
_get_batch(self._data_reverse.get_sentence(), batch_size,
num_steps, max_word_length)):
for k, v in Xr.items():
X[k + '_reverse'] = v
yield X
class InvalidNumberOfCharacters(Exception):
pass
此差异已折叠。
此差异已折叠。
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
def dropout(input, test_mode, args):
if args.dropout and (not test_mode):
return layers.dropout(
input,
dropout_prob=args.dropout,
dropout_implementation="upscale_in_train",
seed=args.random_seed,
is_test=False)
else:
return input
def lstmp_encoder(input_seq, gate_size, h_0, c_0, para_name, proj_size,
test_mode, args):
# A lstm encoder implementation with projection.
# Linear transformation part for input gate, output gate, forget gate
# and cell activation vectors need be done outside of dynamic_lstm.
# So the output size is 4 times of gate_size.
input_seq = dropout(input_seq, test_mode, args)
input_proj = layers.fc(input=input_seq,
param_attr=fluid.ParamAttr(
name=para_name + '_gate_w', initializer=None),
size=gate_size * 4,
act=None,
bias_attr=False)
hidden, cell = layers.dynamic_lstmp(
input=input_proj,
size=gate_size * 4,
proj_size=proj_size,
h_0=h_0,
c_0=c_0,
use_peepholes=False,
proj_clip=args.proj_clip,
cell_clip=args.cell_clip,
proj_activation="identity",
param_attr=fluid.ParamAttr(initializer=None),
bias_attr=fluid.ParamAttr(initializer=None))
return hidden, cell, input_proj
def encoder(x,
y,
vocab_size,
emb_size,
init_hidden=None,
init_cell=None,
para_name='',
custom_samples=None,
custom_probabilities=None,
test_mode=False,
args=None):
x_emb = layers.embedding(
input=x,
size=[vocab_size, emb_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(name='embedding_para'))
rnn_input = x_emb
rnn_outs = []
rnn_outs_ori = []
cells = []
projs = []
for i in range(args.num_layers):
rnn_input = dropout(rnn_input, test_mode, args)
if init_hidden and init_cell:
h0 = layers.squeeze(
layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
c0 = layers.squeeze(
layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
else:
h0 = c0 = None
rnn_out, cell, input_proj = lstmp_encoder(
rnn_input, args.hidden_size, h0, c0,
para_name + 'layer{}'.format(i + 1), emb_size, test_mode, args)
rnn_out_ori = rnn_out
if i > 0:
rnn_out = rnn_out + rnn_input
rnn_out = dropout(rnn_out, test_mode, args)
cell = dropout(cell, test_mode, args)
rnn_outs.append(rnn_out)
rnn_outs_ori.append(rnn_out_ori)
rnn_input = rnn_out
cells.append(cell)
projs.append(input_proj)
softmax_weight = layers.create_parameter(
[vocab_size, emb_size], dtype="float32", name="softmax_weight")
softmax_bias = layers.create_parameter(
[vocab_size], dtype="float32", name='softmax_bias')
projection = layers.matmul(rnn_outs[-1], softmax_weight, transpose_y=True)
projection = layers.elementwise_add(projection, softmax_bias)
projection = layers.reshape(projection, shape=[-1, vocab_size])
if args.sample_softmax and (not test_mode):
loss = layers.sampled_softmax_with_cross_entropy(
logits=projection,
label=y,
num_samples=args.n_negative_samples_batch,
seed=args.random_seed)
else:
label = layers.one_hot(input=y, depth=vocab_size)
loss = layers.softmax_with_cross_entropy(
logits=projection, label=label, soft_label=True)
return [x_emb, projection, loss], rnn_outs, rnn_outs_ori, cells, projs
class LanguageModel(object):
def __init__(self, args, vocab_size, test_mode):
self.args = args
self.vocab_size = vocab_size
self.test_mode = test_mode
def build(self):
args = self.args
emb_size = args.embed_size
proj_size = args.embed_size
hidden_size = args.hidden_size
batch_size = args.batch_size
num_layers = args.num_layers
num_steps = args.num_steps
lstm_outputs = []
x_f = layers.data(name="x", shape=[1], dtype='int64', lod_level=1)
y_f = layers.data(name="y", shape=[1], dtype='int64', lod_level=1)
x_b = layers.data(name="x_r", shape=[1], dtype='int64', lod_level=1)
y_b = layers.data(name="y_r", shape=[1], dtype='int64', lod_level=1)
init_hiddens_ = layers.data(
name="init_hiddens", shape=[1], dtype='float32')
init_cells_ = layers.data(name="init_cells", shape=[1], dtype='float32')
init_hiddens = layers.reshape(
init_hiddens_, shape=[2 * num_layers, -1, proj_size])
init_cells = layers.reshape(
init_cells_, shape=[2 * num_layers, -1, hidden_size])
init_hidden = layers.slice(
init_hiddens, axes=[0], starts=[0], ends=[num_layers])
init_cell = layers.slice(
init_cells, axes=[0], starts=[0], ends=[num_layers])
init_hidden_r = layers.slice(
init_hiddens, axes=[0], starts=[num_layers], ends=[2 * num_layers])
init_cell_r = layers.slice(
init_cells, axes=[0], starts=[num_layers], ends=[2 * num_layers])
if args.use_custom_samples:
custom_samples = layers.data(
name="custom_samples",
shape=[args.n_negative_samples_batch + 1],
dtype='int64',
lod_level=1)
custom_samples_r = layers.data(
name="custom_samples_r",
shape=[args.n_negative_samples_batch + 1],
dtype='int64',
lod_level=1)
custom_probabilities = layers.data(
name="custom_probabilities",
shape=[args.n_negative_samples_batch + 1],
dtype='float32',
lod_level=1)
else:
custom_samples = None
custom_samples_r = None
custom_probabilities = None
forward, fw_hiddens, fw_hiddens_ori, fw_cells, fw_projs = encoder(
x_f,
y_f,
self.vocab_size,
emb_size,
init_hidden,
init_cell,
para_name='fw_',
custom_samples=custom_samples,
custom_probabilities=custom_probabilities,
test_mode=self.test_mode,
args=args)
backward, bw_hiddens, bw_hiddens_ori, bw_cells, bw_projs = encoder(
x_b,
y_b,
self.vocab_size,
emb_size,
init_hidden_r,
init_cell_r,
para_name='bw_',
custom_samples=custom_samples_r,
custom_probabilities=custom_probabilities,
test_mode=self.test_mode,
args=args)
losses = layers.concat([forward[-1], backward[-1]])
self.loss = layers.reduce_mean(losses)
self.loss.persistable = True
self.grad_vars = [x_f, y_f, x_b, y_b, self.loss]
self.grad_vars_name = ['x', 'y', 'x_r', 'y_r', 'final_loss']
fw_vars_name = ['x_emb', 'proj', 'loss'] + [
'init_hidden', 'init_cell'
] + ['rnn_out', 'rnn_out2', 'cell', 'cell2', 'xproj', 'xproj2']
bw_vars_name = ['x_emb_r', 'proj_r', 'loss_r'] + [
'init_hidden_r', 'init_cell_r'
] + [
'rnn_out_r', 'rnn_out2_r', 'cell_r', 'cell2_r', 'xproj_r',
'xproj2_r'
]
fw_vars = forward + [init_hidden, init_cell
] + fw_hiddens + fw_cells + fw_projs
bw_vars = backward + [init_hidden_r, init_cell_r
] + bw_hiddens + bw_cells + bw_projs
for i in range(len(fw_vars_name)):
self.grad_vars.append(fw_vars[i])
self.grad_vars.append(bw_vars[i])
self.grad_vars_name.append(fw_vars_name[i])
self.grad_vars_name.append(bw_vars_name[i])
if args.use_custom_samples:
self.feed_order = [
'x', 'y', 'x_r', 'y_r', 'custom_samples', 'custom_samples_r',
'custom_probabilities'
]
else:
self.feed_order = ['x', 'y', 'x_r', 'y_r']
self.last_hidden = [
fluid.layers.sequence_last_step(input=x)
for x in fw_hiddens_ori + bw_hiddens_ori
]
self.last_cell = [
fluid.layers.sequence_last_step(input=x)
for x in fw_cells + bw_cells
]
self.last_hidden = layers.concat(self.last_hidden, axis=0)
self.last_hidden.persistable = True
self.last_cell = layers.concat(self.last_cell, axis=0)
self.last_cell.persistable = True
export CUDA_VISIBLE_DEVICES=0
python train.py \
--train_path='data/train/sentence_file_*' \
--test_path='data/dev/sentence_file_*' \
--vocab_path data/vocabulary_min5k.txt \
--learning_rate 0.2 \
--use_gpu True \
--all_train_tokens 35479 \
--local True $@
此差异已折叠。
......@@ -4,6 +4,7 @@
- [BERT](./BERT): Bidirectional Encoder Representation from Transformers
- [ERNIE](./ERNIE): Enhanced Representation from kNowledge IntEgration
- [ELMo](./ELMo): Embeddings from Language Models
And more is on the way.
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册