<h1 align="center">ELMO</h1>
## Introduction

ELMo (Embeddings from Language Models) is a deep contextualized word representation that models both complex characteristics of word use (such as syntax and semantics) and how those uses vary across linguistic contexts (i.e. it models polysemy). The PaddlePaddle version of this model supports multi-GPU training, trains roughly twice as fast as mainstream implementations, and has been verified to improve F1 by 0.68% on a Chinese lexical analysis task.

ELMo is trained on a large corpus with a language-model objective to obtain a bidirectional LSTM; the LSTM states serve as word representations, which are then fine-tuned for downstream NLP tasks such as question answering, classification, and named entity recognition.
## Requirements

- Python == 2.7
- PaddlePaddle (latest release)
- numpy == 1.15.1
- six == 1.11.0
- glob
## Pretraining

1. Split the document files into sentences and tokenize each sentence against the vocabulary (see vocabulary_min5k.txt). Split the resulting files into a training set (trainset) and a test set (testset). Each line is one tokenized sentence, for example (a minimal encoding sketch follows the examples below):
```
本 书 介绍 了 中国 经济 发展 的 内外 平衡 问题 、 亚洲 金融 危机 十 周年 回顾 与 反思 、 实践 中 的 城乡 统筹 发展 、 未来 十 年 中国 需要 研究 的 重大 课题 、 科学 发展 与 新型 工业 化 等 方面 。
```
```
吴 敬 琏 曾经 提出 中国 股市 “ 赌场 论 ” , 主张 维护 市场 规则 , 保护 草根 阶层 生计 , 被 誉 为 “ 中国 经济 学界 良心 ” , 是 媒体 和 公众 眼中 的 学术 明星
```
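As a minimal sketch of the tokenization-to-id step, the snippet below encodes one pre-tokenized line with the `Vocabulary` class shipped in this PR; the module name `data` is an assumption, so adjust the import to wherever the class lives in your checkout:

```python
# Minimal sketch: map one pre-tokenized sentence to word ids.
from data import Vocabulary  # assumed module name for the Vocabulary class in this PR

vocab = Vocabulary('vocabulary_min5k.txt', validate_file=True)
sentence = u'本 书 介绍 了 中国 经济 发展 的 内外 平衡 问题'
ids = vocab.encode(sentence)  # prepends <S>, appends </S>, maps OOV words to <UNK>
print(ids)
```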
2. Train the model:
```shell
sh run.sh
```
3. Checkpoints are written to files during training.
## Multi-GPU Training on a Single Machine

The model supports multi-GPU training on a single machine. Specify the GPUs to use by exporting CUDA_VISIBLE_DEVICES in run.sh, for example:
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
```
## How to Fine-tune with ELMo

1. Download the officially released PaddlePaddle ELMo checkpoint:
[Official PaddlePaddle checkpoint download](https://dureader.gz.bcebos.com/elmo/elmo_chinese_checkpoint.tar.gz)
2. Load the ELMo checkpoint in the train part of your task:
```python
src_pretrain_model_path = '490001'  # '490001' is the ELMo checkpoint directory

def if_exist(var):
    path = os.path.join(src_pretrain_model_path, var.name)
    exist = os.path.exists(path)
    if exist:
        print('Load model: %s' % path)
    return exist

fluid.io.load_vars(
    executor=exe,
    dirname=src_pretrain_model_path,
    predicate=if_exist,
    main_program=main_program)
```
3. Copy bilm.py into the directory of your downstream NLP task.
4. Tokenize the input sentences or paragraphs against the ELMo vocabulary (see vocabulary_min5k.txt), convert the tokens to ids, and put the ids into feed_dict.
5. Add the ELMo network to the embedding part of the downstream task network definition (a fuller end-to-end sketch follows the snippet below):
```python
# import the embedding and encoder parts of bilm.py
from bilm import elmo_encoder
from bilm import emb

# word: the id tensor produced from the ELMo-tokenized input
elmo_embedding = emb(word)
elmo_enc = elmo_encoder(elmo_embedding)

# concatenate with the word embedding generated by the NLP task
word_embedding = layers.concat(input=[elmo_enc, word_embedding], axis=1)
```
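Putting steps 4 and 5 together, here is a hedged end-to-end sketch of wiring ELMo into a downstream network; the data layer name `word`, the task-side embedding size of 128, and the vocabulary size 52445 (taken from bilm.py) are illustrative assumptions:

```python
import paddle.fluid as fluid
import paddle.fluid.layers as layers

from bilm import elmo_encoder, emb

# word ids produced by tokenizing with vocabulary_min5k.txt (step 4)
word = layers.data(name='word', shape=[1], dtype='int64', lod_level=1)

# ELMo representation (step 5)
elmo_embedding = emb(word)
elmo_enc = elmo_encoder(elmo_embedding)

# task-side word embedding; 128 is an arbitrary example size
word_embedding = layers.embedding(input=word, size=[52445, 128], dtype='float32')
word_embedding = layers.concat(input=[elmo_enc, word_embedding], axis=1)
```

The concatenated `word_embedding` then feeds the rest of the task network exactly as the original task embedding did.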
## Reference
[Deep contextualized word representations](https://arxiv.org/abs/1802.05365)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--load_dir",
type=str,
default="",
help="Specify the path to load trained models.")
parser.add_argument(
"--load_pretraining_params",
type=str,
default="",
help="Specify the path to load pretrained model parameters, NOT including moment and learning_rate")
parser.add_argument(
"--batch_size",
type=int,
default=128,
help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
"--embed_size",
type=int,
default=512,
help="The dimension of embedding table. (default: %(default)d)")
parser.add_argument(
"--hidden_size",
type=int,
default=4096,
help="The size of rnn hidden unit. (default: %(default)d)")
parser.add_argument(
"--num_layers",
type=int,
default=2,
help="The size of rnn layers. (default: %(default)d)")
parser.add_argument(
"--num_steps",
type=int,
default=20,
help="The size of sequence len. (default: %(default)d)")
parser.add_argument(
"--data_path", type=str, help="all the data for train,valid,test")
parser.add_argument("--vocab_path", type=str, help="vocab file path")
parser.add_argument(
'--use_gpu', type=bool, default=False, help='whether using gpu')
parser.add_argument('--enable_ce', action='store_true')
parser.add_argument('--test_nccl', action='store_true')
parser.add_argument('--optim', default='adagrad', help='optimizer type')
parser.add_argument('--sample_softmax', action='store_true')
parser.add_argument(
"--learning_rate",
type=float,
default=0.2,
help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
"--log_interval",
type=int,
default=100,
help="log the train loss every n batches."
"(default: %(default)d)")
parser.add_argument(
"--save_interval",
type=int,
default=10000,
help="log the train loss every n batches."
"(default: %(default)d)")
parser.add_argument(
"--dev_interval",
type=int,
default=10000,
help="cal dev loss every n batches."
"(default: %(default)d)")
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--max_grad_norm', type=float, default=10.0)
parser.add_argument('--proj_clip', type=float, default=3.0)
parser.add_argument('--cell_clip', type=float, default=3.0)
parser.add_argument('--max_epoch', type=float, default=10)
parser.add_argument('--local', type=bool, default=False)
parser.add_argument('--shuffle', type=bool, default=False)
parser.add_argument('--use_custom_samples', type=bool, default=False)
parser.add_argument('--para_save_dir', type=str, default='model_new')
parser.add_argument('--train_path', type=str, default='')
parser.add_argument('--test_path', type=str, default='')
parser.add_argument('--update_method', type=str, default='nccl2')
parser.add_argument('--random_seed', type=int, default=0)
parser.add_argument('--n_negative_samples_batch', type=int, default=8000)
args = parser.parse_args()
return args
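# Hedged usage sketch (not part of the original file): parse_args() is meant to
# be imported by the training script; running this module directly only prints
# the parsed defaults.
if __name__ == '__main__':
    print(parse_args())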
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is used for fine-tuning
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
# If you use our released weights, do not override these values with command-line args.
cell_clip = 3.0
proj_clip = 3.0
hidden_size = 4096
vocab_size = 52445
embed_size = 512
# According to the original paper, dropout needs to be adjusted for fine-tuning.
modify_dropout = 1
proj_size = 512
num_layers = 2
random_seed = 0
dropout_rate = 0.5
def dropout(input):
return layers.dropout(
input,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
seed=random_seed,
is_test=False)
def lstmp_encoder(input_seq, gate_size, h_0, c_0, para_name):
# A lstm encoder implementation with projection.
# Linear transformation part for input gate, output gate, forget gate
# and cell activation vectors needs to be done outside of dynamic_lstm.
# So the output size is 4 times of gate_size.
input_proj = layers.fc(input=input_seq,
param_attr=fluid.ParamAttr(
name=para_name + '_gate_w', initializer=None),
size=gate_size * 4,
act=None,
bias_attr=False)
hidden, cell = layers.dynamic_lstmp(
input=input_proj,
size=gate_size * 4,
proj_size=proj_size,
h_0=h_0,
c_0=c_0,
use_peepholes=False,
proj_clip=proj_clip,
cell_clip=cell_clip,
proj_activation="identity",
param_attr=fluid.ParamAttr(initializer=None),
bias_attr=fluid.ParamAttr(initializer=None))
return hidden, cell, input_proj
def encoder(x_emb,
init_hidden=None,
init_cell=None,
para_name=''):
rnn_input = x_emb
rnn_outs = []
rnn_outs_ori = []
cells = []
projs = []
for i in range(num_layers):
if init_hidden and init_cell:
h0 = layers.squeeze(
layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
c0 = layers.squeeze(
layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
else:
h0 = c0 = None
rnn_out, cell, input_proj = lstmp_encoder(
rnn_input, hidden_size, h0, c0,
para_name + 'layer{}'.format(i + 1))
rnn_out_ori = rnn_out
if i > 0:
rnn_out = rnn_out + rnn_input
rnn_out.stop_gradient = True
rnn_outs.append(rnn_out)
rnn_outs_ori.append(rnn_out_ori)
# add weight layers for fine-tuning
a1 = layers.create_parameter(
[1], dtype="float32", name="gamma1")
a2 = layers.create_parameter(
[1], dtype="float32", name="gamma2")
rnn_outs[0].stop_gradient = True
rnn_outs[1].stop_gradient = True
num_layer1 = rnn_outs[0] * a1
num_layer2 = rnn_outs[1] * a2
output_layer = num_layer1 * 0.5 + num_layer2 * 0.5
return output_layer, rnn_outs_ori
def emb(x):
x_emb = layers.embedding(
input=x,
size=[vocab_size, embed_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(name='embedding_para'))
return x_emb
def elmo_encoder(x_emb):
x_emb_r = fluid.layers.sequence_reverse(x_emb, name=None)
fw_hiddens, fw_hiddens_ori = encoder(
x_emb,
para_name='fw_')
bw_hiddens, bw_hiddens_ori = encoder(
x_emb_r,
para_name='bw_')
embedding = layers.concat(input=[fw_hiddens, bw_hiddens], axis=1)
# add dropout for fine-tuning
embedding = dropout(embedding)
a = layers.create_parameter(
[1], dtype="float32", name="gamma")
embedding = embedding * a
return embedding
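if __name__ == '__main__':
    # Hedged usage sketch (not part of the original file): build the ELMo
    # embedding and encoder for a batch of word ids passed in as a LoDTensor.
    word = layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
    elmo_emb = emb(word)
    elmo_repr = elmo_encoder(elmo_emb)
    print(elmo_repr.shape)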
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import random
import numpy as np
import io
import six
class Vocabulary(object):
'''
A token vocabulary. Holds a map from token to ids and provides
a method for encoding text to a sequence of ids.
'''
def __init__(self, filename, validate_file=False):
'''
filename = the vocabulary file. It is a flat text file with one
(normalized) token per line. In addition, the file should also
contain the special tokens <S>, </S>, <UNK> (case sensitive).
'''
self._id_to_word = []
self._word_to_id = {}
self._unk = -1
self._bos = -1
self._eos = -1
with io.open(filename, 'r', encoding='utf-8') as f:
idx = 0
for line in f:
word_name = line.strip()
if word_name == '<S>':
self._bos = idx
elif word_name == '</S>':
self._eos = idx
elif word_name == '<UNK>':
self._unk = idx
if word_name == '!!!MAXTERMID':
continue
self._id_to_word.append(word_name)
self._word_to_id[word_name] = idx
idx += 1
# check to ensure file has special tokens
if validate_file:
if self._bos == -1 or self._eos == -1 or self._unk == -1:
raise ValueError("Ensure the vocabulary file has "
"<S>, </S>, <UNK> tokens")
@property
def bos(self):
return self._bos
@property
def eos(self):
return self._eos
@property
def unk(self):
return self._unk
@property
def size(self):
return len(self._id_to_word)
def word_to_id(self, word):
if word in self._word_to_id:
return self._word_to_id[word]
return self.unk
def id_to_word(self, cur_id):
return self._id_to_word[cur_id]
def decode(self, cur_ids):
"""Convert a list of ids to a sentence, with space inserted."""
return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])
def encode(self, sentence, reverse=False, split=True):
"""Convert a sentence to a list of ids, with special tokens added.
Sentence is a single string with tokens separated by whitespace.
If reverse, then the sentence is assumed to be reversed, and
this method will swap the BOS/EOS tokens appropriately."""
if split:
word_ids = [
self.word_to_id(cur_word) for cur_word in sentence.split()
]
else:
word_ids = [self.word_to_id(cur_word) for cur_word in sentence]
if reverse:
return np.array([self.eos] + word_ids + [self.bos], dtype=np.int32)
else:
return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
class UnicodeCharsVocabulary(Vocabulary):
"""Vocabulary containing character-level and word level information.
Has a word vocabulary that is used to lookup word ids and
a character id that is used to map words to arrays of character ids.
The character ids are defined by ord(c) for c in word.encode('utf-8')
This limits the total number of possible char ids to 256.
To this we add 5 additional special ids: begin sentence, end sentence,
begin word, end word and padding.
WARNING: for prediction, we add +1 to the output ids from this
class to create a special padding id (=0). As a result, we suggest
you use the `Batcher`, `TokenBatcher`, and `LMDataset` classes instead
of this lower level class. If you are using this lower level class,
then be sure to add the +1 appropriately, otherwise embeddings computed
from the pre-trained model will be useless.
"""
def __init__(self, filename, max_word_length, **kwargs):
super(UnicodeCharsVocabulary, self).__init__(filename, **kwargs)
self._max_word_length = max_word_length
# char ids 0-255 come from utf-8 encoding bytes
# assign 256-300 to special chars
self.bos_char = 256 # <begin sentence>
self.eos_char = 257 # <end sentence>
self.bow_char = 258 # <begin word>
self.eow_char = 259 # <end word>
self.pad_char = 260 # <padding>
num_words = len(self._id_to_word)
self._word_char_ids = np.zeros(
[num_words, max_word_length], dtype=np.int32)
# the character representation of the begin/end of sentence characters
def _make_bos_eos(c):
r = np.zeros([self.max_word_length], dtype=np.int32)
r[:] = self.pad_char
r[0] = self.bow_char
r[1] = c
r[2] = self.eow_char
return r
self.bos_chars = _make_bos_eos(self.bos_char)
self.eos_chars = _make_bos_eos(self.eos_char)
for i, word in enumerate(self._id_to_word):
self._word_char_ids[i] = self._convert_word_to_char_ids(word)
self._word_char_ids[self.bos] = self.bos_chars
self._word_char_ids[self.eos] = self.eos_chars
@property
def word_char_ids(self):
return self._word_char_ids
@property
def max_word_length(self):
return self._max_word_length
def _convert_word_to_char_ids(self, word):
code = np.zeros([self.max_word_length], dtype=np.int32)
code[:] = self.pad_char
word_encoded = word.encode('utf-8',
'ignore')[:(self.max_word_length - 2)]
code[0] = self.bow_char
for k, chr_id in enumerate(word_encoded, start=1):
code[k] = ord(chr_id)
code[k + 1] = self.eow_char
return code
def word_to_char_ids(self, word):
if word in self._word_to_id:
return self._word_char_ids[self._word_to_id[word]]
else:
return self._convert_word_to_char_ids(word)
def encode_chars(self, sentence, reverse=False, split=True):
'''
Encode the sentence as a white space delimited string of tokens.
'''
if split:
chars_ids = [
self.word_to_char_ids(cur_word)
for cur_word in sentence.split()
]
else:
chars_ids = [
self.word_to_char_ids(cur_word) for cur_word in sentence
]
if reverse:
return np.vstack([self.eos_chars] + chars_ids + [self.bos_chars])
else:
return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])
class Batcher(object):
'''
Batch sentences of tokenized text into character id matrices.
'''
# def __init__(self, lm_vocab_file: str, max_token_length: int):
def __init__(self, lm_vocab_file, max_token_length):
'''
lm_vocab_file = the language model vocabulary file (one line per
token)
max_token_length = the maximum number of characters in each token
'''
max_token_length = int(max_token_length)
self._lm_vocab = UnicodeCharsVocabulary(lm_vocab_file,
max_token_length)
self._max_token_length = max_token_length
# def batch_sentences(self, sentences: List[List[str]]):
def batch_sentences(self, sentences):
'''
Batch the sentences as character ids
Each sentence is a list of tokens without <s> or </s>, e.g.
[['The', 'first', 'sentence', '.'], ['Second', '.']]
'''
n_sentences = len(sentences)
max_length = max(len(sentence) for sentence in sentences) + 2
X_char_ids = np.zeros(
(n_sentences, max_length, self._max_token_length), dtype=np.int64)
for k, sent in enumerate(sentences):
length = len(sent) + 2
char_ids_without_mask = self._lm_vocab.encode_chars(
sent, split=False)
# add one so that 0 is the mask value
X_char_ids[k, :length, :] = char_ids_without_mask + 1
return X_char_ids
class TokenBatcher(object):
'''
Batch sentences of tokenized text into token id matrices.
'''
def __init__(self, lm_vocab_file):
# def __init__(self, lm_vocab_file: str):
'''
lm_vocab_file = the language model vocabulary file (one line per
token)
'''
self._lm_vocab = Vocabulary(lm_vocab_file)
# def batch_sentences(self, sentences: List[List[str]]):
def batch_sentences(self, sentences):
'''
Batch the sentences as token ids
Each sentence is a list of tokens without <s> or </s>, e.g.
[['The', 'first', 'sentence', '.'], ['Second', '.']]
'''
n_sentences = len(sentences)
max_length = max(len(sentence) for sentence in sentences) + 2
X_ids = np.zeros((n_sentences, max_length), dtype=np.int64)
for k, sent in enumerate(sentences):
length = len(sent) + 2
ids_without_mask = self._lm_vocab.encode(sent, split=False)
# add one so that 0 is the mask value
X_ids[k, :length] = ids_without_mask + 1
return X_ids
##### for training
def _get_batch(generator, batch_size, num_steps, max_word_length):
"""Read batches of input."""
cur_stream = [None] * batch_size
no_more_data = False
while True:
inputs = np.zeros([batch_size, num_steps], np.int32)
if max_word_length is not None:
char_inputs = np.zeros([batch_size, num_steps, max_word_length],
np.int32)
else:
char_inputs = None
targets = np.zeros([batch_size, num_steps], np.int32)
for i in range(batch_size):
cur_pos = 0
while cur_pos < num_steps:
if cur_stream[i] is None or len(cur_stream[i][0]) <= 1:
try:
cur_stream[i] = list(next(generator))
except StopIteration:
# No more data, exhaust current streams and quit
no_more_data = True
break
how_many = min(len(cur_stream[i][0]) - 1, num_steps - cur_pos)
next_pos = cur_pos + how_many
inputs[i, cur_pos:next_pos] = cur_stream[i][0][:how_many]
if max_word_length is not None:
char_inputs[i, cur_pos:next_pos] = cur_stream[i][
1][:how_many]
targets[i, cur_pos:next_pos] = cur_stream[i][0][1:how_many + 1]
cur_pos = next_pos
cur_stream[i][0] = cur_stream[i][0][how_many:]
if max_word_length is not None:
cur_stream[i][1] = cur_stream[i][1][how_many:]
if no_more_data:
# There is no more data. Note: this will not return data
# for the incomplete batch
break
X = {
'token_ids': inputs,
'tokens_characters': char_inputs,
'next_token_id': targets
}
yield X
class LMDataset(object):
"""
Hold a language model dataset.
A dataset is a list of tokenized files. Each file contains one sentence
per line. Each sentence is pre-tokenized and white space joined.
"""
def __init__(self,
filepattern,
vocab,
reverse=False,
test=False,
shuffle_on_load=False):
'''
filepattern = a glob string that specifies the list of files.
vocab = an instance of Vocabulary or UnicodeCharsVocabulary
reverse = if True, then iterate over tokens in each sentence in reverse
test = if True, then iterate through all data once then stop.
Otherwise, iterate forever.
shuffle_on_load = if True, then shuffle the sentences after loading.
'''
self._vocab = vocab
self._all_shards = glob.glob(filepattern)
print('Found %d shards at %s' % (len(self._all_shards), filepattern))
if test:
self._all_shards = list(np.random.choice(self._all_shards, size=4))
print('sampled %d shards at %s' % (len(self._all_shards), filepattern))
self._shards_to_choose = []
self._reverse = reverse
self._test = test
self._shuffle_on_load = shuffle_on_load
self._use_char_inputs = hasattr(vocab, 'encode_chars')
self._ids = self._load_random_shard()
def _choose_random_shard(self):
if len(self._shards_to_choose) == 0:
self._shards_to_choose = list(self._all_shards)
random.shuffle(self._shards_to_choose)
shard_name = self._shards_to_choose.pop()
return shard_name
def _load_random_shard(self):
"""Randomly select a file and read it."""
if self._test:
if len(self._all_shards) == 0:
# we've loaded all the data
# this will propagate up to the generator in get_batch
# and stop iterating
raise StopIteration
else:
shard_name = self._all_shards.pop()
else:
# just pick a random shard
shard_name = self._choose_random_shard()
ids = self._load_shard(shard_name)
self._i = 0
self._nids = len(ids)
return ids
def _load_shard(self, shard_name):
"""Read one file and convert to ids.
Args:
shard_name: file path.
Returns:
list of (id, char_id) tuples.
"""
print('Loading data from: %s' % shard_name)
with io.open(shard_name, 'r', encoding='utf-8') as f:
sentences_raw = f.readlines()
if self._reverse:
sentences = []
for sentence in sentences_raw:
splitted = sentence.split()
splitted.reverse()
sentences.append(' '.join(splitted))
else:
sentences = sentences_raw
if self._shuffle_on_load:
print('shuffle sentences')
random.shuffle(sentences)
ids = [
self.vocab.encode(sentence, self._reverse)
for sentence in sentences
]
if self._use_char_inputs:
chars_ids = [
self.vocab.encode_chars(sentence, self._reverse)
for sentence in sentences
]
else:
chars_ids = [None] * len(ids)
print('Loaded %d sentences.' % len(ids))
print('Finished loading')
return list(zip(ids, chars_ids))
def get_sentence(self):
while True:
if self._i == self._nids:
self._ids = self._load_random_shard()
ret = self._ids[self._i]
self._i += 1
yield ret
@property
def max_word_length(self):
if self._use_char_inputs:
return self._vocab.max_word_length
else:
return None
def iter_batches(self, batch_size, num_steps):
for X in _get_batch(self.get_sentence(), batch_size, num_steps,
self.max_word_length):
# token_ids = (batch_size, num_steps)
# char_inputs = (batch_size, num_steps, 50) of character ids
# targets = word ID of next word (batch_size, num_steps)
yield X
@property
def vocab(self):
return self._vocab
class BidirectionalLMDataset(object):
def __init__(self, filepattern, vocab, test=False, shuffle_on_load=False):
'''
bidirectional version of LMDataset
'''
self._data_forward = LMDataset(
filepattern,
vocab,
reverse=False,
test=test,
shuffle_on_load=shuffle_on_load)
self._data_reverse = LMDataset(
filepattern,
vocab,
reverse=True,
test=test,
shuffle_on_load=shuffle_on_load)
def iter_batches(self, batch_size, num_steps):
max_word_length = self._data_forward.max_word_length
for X, Xr in six.moves.zip(
_get_batch(self._data_forward.get_sentence(), batch_size,
num_steps, max_word_length),
_get_batch(self._data_reverse.get_sentence(), batch_size,
num_steps, max_word_length)):
for k, v in Xr.items():
X[k + '_reverse'] = v
yield X
class InvalidNumberOfCharacters(Exception):
pass
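if __name__ == '__main__':
    # Hedged usage sketch (not part of the original file): encode and batch a
    # few pre-tokenized sentences; 'vocabulary_min5k.txt' is assumed to be in
    # the working directory.
    vocab = Vocabulary('vocabulary_min5k.txt', validate_file=True)
    batcher = TokenBatcher('vocabulary_min5k.txt')
    print(vocab.encode(u'中国 经济 发展'))
    print(batcher.batch_sentences([[u'中国', u'经济'], [u'发展']]))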
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
def dropout(input, test_mode, args):
if args.dropout and (not test_mode):
return layers.dropout(
input,
dropout_prob=args.dropout,
dropout_implementation="upscale_in_train",
seed=args.random_seed,
is_test=False)
else:
return input
def lstmp_encoder(input_seq, gate_size, h_0, c_0, para_name, proj_size, test_mode, args):
# A lstm encoder implementation with projection.
# Linear transformation part for input gate, output gate, forget gate
# and cell activation vectors needs to be done outside of dynamic_lstm.
# So the output size is 4 times of gate_size.
input_seq = dropout(input_seq, test_mode, args)
input_proj = layers.fc(input=input_seq,
param_attr=fluid.ParamAttr(
name=para_name + '_gate_w', initializer=None),
size=gate_size * 4,
act=None,
bias_attr=False)
hidden, cell = layers.dynamic_lstmp(
input=input_proj,
size=gate_size * 4,
proj_size=proj_size,
h_0=h_0,
c_0=c_0,
use_peepholes=False,
proj_clip=args.proj_clip,
cell_clip=args.cell_clip,
proj_activation="identity",
param_attr=fluid.ParamAttr(initializer=None),
bias_attr=fluid.ParamAttr(initializer=None))
return hidden, cell, input_proj
def encoder(x,
y,
vocab_size,
emb_size,
init_hidden=None,
init_cell=None,
para_name='',
custom_samples=None,
custom_probabilities=None,
test_mode=False,
args=None):
x_emb = layers.embedding(
input=x,
size=[vocab_size, emb_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(name='embedding_para'))
rnn_input = x_emb
rnn_outs = []
rnn_outs_ori = []
cells = []
projs = []
for i in range(args.num_layers):
rnn_input = dropout(rnn_input, test_mode, args)
if init_hidden and init_cell:
h0 = layers.squeeze(
layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
c0 = layers.squeeze(
layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
else:
h0 = c0 = None
rnn_out, cell, input_proj = lstmp_encoder(
rnn_input, args.hidden_size, h0, c0,
para_name + 'layer{}'.format(i + 1), emb_size, test_mode, args)
rnn_out_ori = rnn_out
if i > 0:
rnn_out = rnn_out + rnn_input
rnn_out = dropout(rnn_out, test_mode, args)
cell = dropout(cell, test_mode, args)
rnn_outs.append(rnn_out)
rnn_outs_ori.append(rnn_out_ori)
rnn_input = rnn_out
cells.append(cell)
projs.append(input_proj)
softmax_weight = layers.create_parameter(
[vocab_size, emb_size], dtype="float32", name="softmax_weight")
softmax_bias = layers.create_parameter(
[vocab_size], dtype="float32", name='softmax_bias')
projection = layers.matmul(rnn_outs[-1], softmax_weight, transpose_y=True)
projection = layers.elementwise_add(projection, softmax_bias)
projection = layers.reshape(projection, shape=[-1, vocab_size])
if args.sample_softmax and (not test_mode):
loss = layers.sampled_softmax_with_cross_entropy(
logits=projection,
label=y,
num_samples=args.n_negative_samples_batch,
seed=args.random_seed)
else:
label = layers.one_hot(input=y, depth=vocab_size)
loss = layers.softmax_with_cross_entropy(
logits=projection, label=label, soft_label=True)
return [x_emb, projection, loss], rnn_outs, rnn_outs_ori, cells, projs
class LanguageModel(object):
def __init__(self, args, vocab_size, test_mode):
self.args = args
self.vocab_size = vocab_size
self.test_mode = test_mode
def build(self):
args = self.args
emb_size = args.embed_size
proj_size = args.embed_size
hidden_size = args.hidden_size
batch_size = args.batch_size
num_layers = args.num_layers
num_steps = args.num_steps
lstm_outputs = []
x_f = layers.data(name="x", shape=[1], dtype='int64', lod_level=1)
y_f = layers.data(name="y", shape=[1], dtype='int64', lod_level=1)
x_b = layers.data(name="x_r", shape=[1], dtype='int64', lod_level=1)
y_b = layers.data(name="y_r", shape=[1], dtype='int64', lod_level=1)
init_hiddens_ = layers.data(
name="init_hiddens", shape=[1], dtype='float32')
init_cells_ = layers.data(
name="init_cells", shape=[1], dtype='float32')
init_hiddens = layers.reshape(
init_hiddens_, shape=[2 * num_layers, -1, proj_size])
init_cells = layers.reshape(
init_cells_, shape=[2 * num_layers, -1, hidden_size])
init_hidden = layers.slice(
init_hiddens, axes=[0], starts=[0], ends=[num_layers])
init_cell = layers.slice(
init_cells, axes=[0], starts=[0], ends=[num_layers])
init_hidden_r = layers.slice(
init_hiddens, axes=[0], starts=[num_layers],
ends=[2 * num_layers])
init_cell_r = layers.slice(
init_cells, axes=[0], starts=[num_layers], ends=[2 * num_layers])
if args.use_custom_samples:
custom_samples = layers.data(
name="custom_samples",
shape=[args.n_negative_samples_batch + 1],
dtype='int64',
lod_level=1)
custom_samples_r = layers.data(
name="custom_samples_r",
shape=[args.n_negative_samples_batch + 1],
dtype='int64',
lod_level=1)
custom_probabilities = layers.data(
name="custom_probabilities",
shape=[args.n_negative_samples_batch + 1],
dtype='float32',
lod_level=1)
else:
custom_samples = None
custom_samples_r = None
custom_probabilities = None
forward, fw_hiddens, fw_hiddens_ori, fw_cells, fw_projs = encoder(
x_f,
y_f,
self.vocab_size,
emb_size,
init_hidden,
init_cell,
para_name='fw_',
custom_samples=custom_samples,
custom_probabilities=custom_probabilities,
test_mode=self.test_mode,
args=args)
backward, bw_hiddens, bw_hiddens_ori, bw_cells, bw_projs = encoder(
x_b,
y_b,
self.vocab_size,
emb_size,
init_hidden_r,
init_cell_r,
para_name='bw_',
custom_samples=custom_samples_r,
custom_probabilities=custom_probabilities,
test_mode=self.test_mode,
args=args)
losses = layers.concat([forward[-1], backward[-1]])
self.loss = layers.reduce_mean(losses)
self.loss.persistable = True
self.grad_vars = [x_f, y_f, x_b, y_b, self.loss]
self.grad_vars_name = ['x', 'y', 'x_r', 'y_r', 'final_loss']
fw_vars_name = ['x_emb', 'proj', 'loss'] + [
'init_hidden', 'init_cell'
] + ['rnn_out', 'rnn_out2', 'cell', 'cell2', 'xproj', 'xproj2']
bw_vars_name = ['x_emb_r', 'proj_r', 'loss_r'] + [
'init_hidden_r', 'init_cell_r'
] + [
'rnn_out_r', 'rnn_out2_r', 'cell_r', 'cell2_r', 'xproj_r',
'xproj2_r'
]
fw_vars = forward + [init_hidden, init_cell
] + fw_hiddens + fw_cells + fw_projs
bw_vars = backward + [init_hidden_r, init_cell_r
] + bw_hiddens + bw_cells + bw_projs
for i in range(len(fw_vars_name)):
self.grad_vars.append(fw_vars[i])
self.grad_vars.append(bw_vars[i])
self.grad_vars_name.append(fw_vars_name[i])
self.grad_vars_name.append(bw_vars_name[i])
if args.use_custom_samples:
self.feed_order = [
'x', 'y', 'x_r', 'y_r', 'custom_samples', 'custom_samples_r',
'custom_probabilities'
]
else:
self.feed_order = ['x', 'y', 'x_r', 'y_r']
self.last_hidden = [
fluid.layers.sequence_last_step(input=x)
for x in fw_hiddens_ori + bw_hiddens_ori
]
self.last_cell = [
fluid.layers.sequence_last_step(input=x)
for x in fw_cells + bw_cells
]
self.last_hidden = layers.concat(self.last_hidden, axis=0)
self.last_hidden.persistable = True
self.last_cell = layers.concat(self.last_cell, axis=0)
self.last_cell.persistable = True
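if __name__ == '__main__':
    # Hedged usage sketch (not part of the original file): build the
    # bidirectional language model graph with the default hyper-parameters.
    # The module name `args` for the argument parser in this PR is an assumption,
    # and 52445 is the vocabulary size used by the released Chinese model.
    from args import parse_args
    train_args = parse_args()
    model = LanguageModel(train_args, vocab_size=52445, test_mode=False)
    model.build()
    print(model.feed_order)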
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import sys
import numpy as np
Py3 = sys.version_info[0] == 3
def listDir(rootDir):
res = []
for filename in os.listdir(rootDir):
pathname = os.path.join(rootDir, filename)
if (os.path.isfile(pathname)):
res.append(pathname)
return res
_unk = -1
_bos = -1
_eos = -1
def _read_words(filename):
data = []
with open(filename, "r") as f:
return f.read().decode("utf-8").replace("\n", "<eos>").split()
def _build_vocab(filename):
data = _read_words(filename)
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
print("vocab word num", len(words))
word_to_id = dict(zip(words, range(len(words))))
return word_to_id
def _load_vocab(filename):
with open(filename, "r") as f:
words = f.read().decode("utf-8").replace("\n", " ").split()
word_to_id = dict(zip(words, range(len(words))))
_bos = word_to_id['<S>']
_eos = word_to_id['</S>']
_unk = word_to_id['<UNK>']
return word_to_id
def _file_to_word_ids(filenames, word_to_id):
for filename in filenames:
data = _read_words(filename)
for id in [word_to_id[word] for word in data if word in word_to_id]:
yield id
def ptb_raw_data(data_path=None, vocab_path=None, args=None):
"""Load raw text data from the directory "data_path".
Reads the tokenized text files, converts tokens to integer ids with the
vocabulary, and builds generators over the ids for mini-batching.
Args:
data_path: directory that holds the "train" and "dev" sub-directories
(ignored for training data when args.train_path is set).
vocab_path: path to the vocabulary file.
Returns:
tuple (train_data, valid_data, test_data, vocabulary_size)
where each data object is a generator of word ids.
"""
if vocab_path:
word_to_id = _load_vocab(vocab_path)
if not args.train_path:
train_path = os.path.join(data_path, "train")
train_data = _file_to_word_ids(listDir(train_path), word_to_id)
else:
train_path = args.train_path
train_data = _file_to_word_ids([train_path], word_to_id)
valid_path = os.path.join(data_path, "dev")
test_path = os.path.join(data_path, "dev")
valid_data = _file_to_word_ids(listDir(valid_path), word_to_id)
test_data = _file_to_word_ids(listDir(test_path), word_to_id)
vocabulary = len(word_to_id)
return train_data, valid_data, test_data, vocabulary
def get_data_iter(raw_data, batch_size, num_steps):
def __impl__():
buf = []
while True:
if len(buf) >= num_steps * batch_size + 1:
x = np.asarray(
buf[:-1], dtype='int64').reshape((batch_size, num_steps))
y = np.asarray(
buf[1:], dtype='int64').reshape((batch_size, num_steps))
yield (x, y)
buf = [buf[-1]]
try:
buf.append(next(raw_data))
except StopIteration:
break
return __impl__
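# Hedged usage sketch (not part of the original file), assuming the baike data
# layout used by run.sh:
#   train, valid, test, vocab_size = ptb_raw_data(
#       data_path='baike', vocab_path='baike/vocabulary_min5k.txt', args=args)
#   for x, y in get_data_iter(train, batch_size=128, num_steps=20)():
#       ...  # feed x and y into the language model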
export CUDA_VISIBLE_DEVICES=0
python train.py \
--train_path='baike/train/sentence_file_*' \
--test_path='baike/dev/sentence_file_*' \
--vocab_path baike/vocabulary_min5k.txt \
--learning_rate 0.2 \
--use_gpu True \
--local True $@