未验证 提交 0f2dfd4b 编写于 作者: X xuezhong 提交者: GitHub

Merge pull request #1991 from JesseyXujin/elmo_test

ELMO
<h1 align="center">ELMO</h1>
## 介绍
ELMO(Embeddings from Language Models)是一种新型深度语境化词表征,可对词进行复杂特征(如句法和语义)和词在语言语境中的变化进行建模(即对多义词进行建模)。PaddlePaddle版本该模型支持多卡训练,训练速度比主流实现快约1倍, 验证在中文词法分析任务上f1值提升0.68%。
ELMO在大语料上以language model为训练目标,训练出bidirectional LSTM模型,利用LSTM产生词语的表征, 对下游NLP任务(如问答、分类、命名实体识别等)进行微调。
## 基本配置及第三方安装包
Python==2.7
PaddlePaddle lastest版本
numpy ==1.15.1
six==1.11.0
glob
## 预训练模型
1. 把文档文件切分成句子,并基于词表(参考vocabulary_min5k.txt)对句子进行切词。把文件切分成训练集trainset和测试集testset。
```
本 书 介绍 了 中国 经济 发展 的 内外 平衡 问题 、 亚洲 金融 危机 十 周年 回顾 与 反思 、 实践 中 的 城乡 统筹 发展 、 未来 十 年 中国 需要 研究 的 重大 课题 、 科学 发展 与 新型 工业 化 等 方面 。
```
```
吴 敬 琏 曾经 提出 中国 股市 “ 赌场 论 ” , 主张 维护 市场 规则 , 保护 草根 阶层 生计 , 被 誉 为 “ 中国 经济 学界 良心 ” , 是 媒体 和 公众 眼中 的 学术 明星
```
2. 训练模型
```shell
sh run.sh
```
3. 把checkpoint结果写入文件中。
## 单机多卡训练
模型支持单机多卡训练,需要在run.sh里export CUDA_VISIBLE_DEVICES设置指定卡,如下所示:
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
```
## 如何利用ELMO做微调
1. 下载ELMO Paddle官方发布Checkout文件
[PaddlePaddle官方发布Checkout文件下载地址](https://dureader.gz.bcebos.com/elmo/elmo_chinese_checkpoint.tar.gz)
2. 在train部分中加载ELMO checkpoint文件
```shell
src_pretrain_model_path = '490001' #490001为ELMO checkpoint文件
def if_exist(var):
path = os.path.join(src_pretrain_model_path, var.name)
exist = os.path.exists(path)
if exist:
print('Load model: %s' % path)
return exist
fluid.io.load_vars(executor=exe, dirname=src_pretrain_model_path, predicate=if_exist, main_program=main_program)
```
3. 在下游NLP任务文件夹中加入bilm.py文件
4. 基于elmo词表(参考vocabulary_min5k.txt)对输入的句子或段落进行切词,并把切词的词转化为id,放入feed_dict中。
5. 在下游任务网络定义中embedding部分加入ELMO网络的定义
```shell
#引入 bilm.py embedding部分和encoder部分
from bilm import elmo_encoder
from bilm import emb
#word为输入elmo部分切词后的字典
elmo_embedding = emb(word)
elmo_enc= elmo_encoder(elmo_embedding)
#与NLP任务中生成词向量word_embedding做连接操作
word_embedding=layers.concat(input=[elmo_enc, word_embedding], axis=1)
```
## 参考论文
[Deep contextualized word representations](https://arxiv.org/abs/1802.05365)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--load_dir",
type=str,
default="",
help="Specify the path to load trained models.")
parser.add_argument(
"--load_pretraining_params",
type=str,
default="",
help="Specify the path to load pretrained model parameters, NOT including moment and learning_rate")
parser.add_argument(
"--batch_size",
type=int,
default=128,
help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
"--embed_size",
type=int,
default=512,
help="The dimension of embedding table. (default: %(default)d)")
parser.add_argument(
"--hidden_size",
type=int,
default=4096,
help="The size of rnn hidden unit. (default: %(default)d)")
parser.add_argument(
"--num_layers",
type=int,
default=2,
help="The size of rnn layers. (default: %(default)d)")
parser.add_argument(
"--num_steps",
type=int,
default=20,
help="The size of sequence len. (default: %(default)d)")
parser.add_argument(
"--data_path", type=str, help="all the data for train,valid,test")
parser.add_argument("--vocab_path", type=str, help="vocab file path")
parser.add_argument(
'--use_gpu', type=bool, default=False, help='whether using gpu')
parser.add_argument('--enable_ce', action='store_true')
parser.add_argument('--test_nccl', action='store_true')
parser.add_argument('--optim', default='adagrad', help='optimizer type')
parser.add_argument('--sample_softmax', action='store_true')
parser.add_argument(
"--learning_rate",
type=float,
default=0.2,
help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
"--log_interval",
type=int,
default=100,
help="log the train loss every n batches."
"(default: %(default)d)")
parser.add_argument(
"--save_interval",
type=int,
default=10000,
help="log the train loss every n batches."
"(default: %(default)d)")
parser.add_argument(
"--dev_interval",
type=int,
default=10000,
help="cal dev loss every n batches."
"(default: %(default)d)")
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--max_grad_norm', type=float, default=10.0)
parser.add_argument('--proj_clip', type=float, default=3.0)
parser.add_argument('--cell_clip', type=float, default=3.0)
parser.add_argument('--max_epoch', type=float, default=10)
parser.add_argument('--local', type=bool, default=False)
parser.add_argument('--shuffle', type=bool, default=False)
parser.add_argument('--use_custom_samples', type=bool, default=False)
parser.add_argument('--para_save_dir', type=str, default='model_new')
parser.add_argument('--train_path', type=str, default='')
parser.add_argument('--test_path', type=str, default='')
parser.add_argument('--update_method', type=str, default='nccl2')
parser.add_argument('--random_seed', type=int, default=0)
parser.add_argument('--n_negative_samples_batch', type=int, default=8000)
args = parser.parse_args()
return args
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is used to finetone
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
# if you use our release weight layers,do not use the args.
cell_clip = 3.0
proj_clip = 3.0
hidden_size = 4096
vocab_size = 52445
embed_size = 512
# according to orginal paper, dropout need to be modifyed on finetone
modify_dropout = 1
proj_size = 512
num_layers = 2
random_seed = 0
dropout_rate = 0.5
def dropout(input):
return layers.dropout(
input,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
seed=random_seed,
is_test=False)
def lstmp_encoder(input_seq, gate_size, h_0, c_0, para_name):
# A lstm encoder implementation with projection.
# Linear transformation part for input gate, output gate, forget gate
# and cell activation vectors need be done outside of dynamic_lstm.
# So the output size is 4 times of gate_size.
input_proj = layers.fc(input=input_seq,
param_attr=fluid.ParamAttr(
name=para_name + '_gate_w', initializer=init),
size=gate_size * 4,
act=None,
bias_attr=False)
hidden, cell = layers.dynamic_lstmp(
input=input_proj,
size=gate_size * 4,
proj_size=proj_size,
h_0=h_0,
c_0=c_0,
use_peepholes=False,
proj_clip=proj_clip,
cell_clip=cell_clip,
proj_activation="identity",
param_attr=fluid.ParamAttr(initializer=None),
bias_attr=fluid.ParamAttr(initializer=None))
return hidden, cell, input_proj
def encoder(x_emb,
init_hidden=None,
init_cell=None,
para_name=''):
rnn_input = x_emb
rnn_outs = []
rnn_outs_ori = []
cells = []
projs = []
for i in range(num_layers):
if init_hidden and init_cell:
h0 = layers.squeeze(
layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
c0 = layers.squeeze(
layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
else:
h0 = c0 = None
rnn_out, cell, input_proj = lstmp_encoder(
rnn_input, hidden_size, h0, c0,
para_name + 'layer{}'.format(i + 1))
rnn_out_ori = rnn_out
if i > 0:
rnn_out = rnn_out + rnn_input
rnn_out.stop_gradient = True
rnn_outs.append(rnn_out)
rnn_outs_ori.append(rnn_out_ori)
# add weight layers for finetone
a1 = layers.create_parameter(
[1], dtype="float32", name="gamma1")
a2 = layers.create_parameter(
[1], dtype="float32", name="gamma2")
rnn_outs[0].stop_gradient = True
rnn_outs[1].stop_gradient = True
num_layer1 = rnn_outs[0] * a1
num_layer2 = rnn_outs[1] * a2
output_layer = num_layer1 * 0.5 + num_layer2 * 0.5
return output_layer, rnn_outs_ori
def emb(x):
x_emb = layers.embedding(
input=x,
size=[vocab_size, embed_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(name='embedding_para'))
return x_emb
def elmo_encoder(x_emb):
x_emb_r = fluid.layers.sequence_reverse(x_emb, name=None)
fw_hiddens, fw_hiddens_ori = encoder(
x_emb,
para_name='fw_')
bw_hiddens, bw_hiddens_ori = encoder(
x_emb_r,
para_name='bw_')
embedding = layers.concat(input=[fw_hiddens, bw_hiddens], axis=1)
# add dropout on finetone
embedding = dropout(embedding)
a = layers.create_parameter(
[1], dtype="float32", name="gamma")
embedding = embedding * a
return embedding
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import random
import numpy as np
import io
import six
class Vocabulary(object):
'''
A token vocabulary. Holds a map from token to ids and provides
a method for encoding text to a sequence of ids.
'''
def __init__(self, filename, validate_file=False):
'''
filename = the vocabulary file. It is a flat text file with one
(normalized) token per line. In addition, the file should also
contain the special tokens <S>, </S>, <UNK> (case sensitive).
'''
self._id_to_word = []
self._word_to_id = {}
self._unk = -1
self._bos = -1
self._eos = -1
with io.open(filename, 'r', encoding='utf-8') as f:
idx = 0
for line in f:
word_name = line.strip()
if word_name == '<S>':
self._bos = idx
elif word_name == '</S>':
self._eos = idx
elif word_name == '<UNK>':
self._unk = idx
if word_name == '!!!MAXTERMID':
continue
self._id_to_word.append(word_name)
self._word_to_id[word_name] = idx
idx += 1
# check to ensure file has special tokens
if validate_file:
if self._bos == -1 or self._eos == -1 or self._unk == -1:
raise ValueError("Ensure the vocabulary file has "
"<S>, </S>, <UNK> tokens")
@property
def bos(self):
return self._bos
@property
def eos(self):
return self._eos
@property
def unk(self):
return self._unk
@property
def size(self):
return len(self._id_to_word)
def word_to_id(self, word):
if word in self._word_to_id:
return self._word_to_id[word]
return self.unk
def id_to_word(self, cur_id):
return self._id_to_word[cur_id]
def decode(self, cur_ids):
"""Convert a list of ids to a sentence, with space inserted."""
return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])
def encode(self, sentence, reverse=False, split=True):
"""Convert a sentence to a list of ids, with special tokens added.
Sentence is a single string with tokens separated by whitespace.
If reverse, then the sentence is assumed to be reversed, and
this method will swap the BOS/EOS tokens appropriately."""
if split:
word_ids = [
self.word_to_id(cur_word) for cur_word in sentence.split()
]
else:
word_ids = [self.word_to_id(cur_word) for cur_word in sentence]
if reverse:
return np.array([self.eos] + word_ids + [self.bos], dtype=np.int32)
else:
return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
class UnicodeCharsVocabulary(Vocabulary):
"""Vocabulary containing character-level and word level information.
Has a word vocabulary that is used to lookup word ids and
a character id that is used to map words to arrays of character ids.
The character ids are defined by ord(c) for c in word.encode('utf-8')
This limits the total number of possible char ids to 256.
To this we add 5 additional special ids: begin sentence, end sentence,
begin word, end word and padding.
WARNING: for prediction, we add +1 to the output ids from this
class to create a special padding id (=0). As a result, we suggest
you use the `Batcher`, `TokenBatcher`, and `LMDataset` classes instead
of this lower level class. If you are using this lower level class,
then be sure to add the +1 appropriately, otherwise embeddings computed
from the pre-trained model will be useless.
"""
def __init__(self, filename, max_word_length, **kwargs):
super(UnicodeCharsVocabulary, self).__init__(filename, **kwargs)
self._max_word_length = max_word_length
# char ids 0-255 come from utf-8 encoding bytes
# assign 256-300 to special chars
self.bos_char = 256 # <begin sentence>
self.eos_char = 257 # <end sentence>
self.bow_char = 258 # <begin word>
self.eow_char = 259 # <end word>
self.pad_char = 260 # <padding>
num_words = len(self._id_to_word)
self._word_char_ids = np.zeros(
[num_words, max_word_length], dtype=np.int32)
# the charcter representation of the begin/end of sentence characters
def _make_bos_eos(c):
r = np.zeros([self.max_word_length], dtype=np.int32)
r[:] = self.pad_char
r[0] = self.bow_char
r[1] = c
r[2] = self.eow_char
return r
self.bos_chars = _make_bos_eos(self.bos_char)
self.eos_chars = _make_bos_eos(self.eos_char)
for i, word in enumerate(self._id_to_word):
self._word_char_ids[i] = self._convert_word_to_char_ids(word)
self._word_char_ids[self.bos] = self.bos_chars
self._word_char_ids[self.eos] = self.eos_chars
@property
def word_char_ids(self):
return self._word_char_ids
@property
def max_word_length(self):
return self._max_word_length
def _convert_word_to_char_ids(self, word):
code = np.zeros([self.max_word_length], dtype=np.int32)
code[:] = self.pad_char
word_encoded = word.encode('utf-8',
'ignore')[:(self.max_word_length - 2)]
code[0] = self.bow_char
for k, chr_id in enumerate(word_encoded, start=1):
code[k] = ord(chr_id)
code[k + 1] = self.eow_char
return code
def word_to_char_ids(self, word):
if word in self._word_to_id:
return self._word_char_ids[self._word_to_id[word]]
else:
return self._convert_word_to_char_ids(word)
def encode_chars(self, sentence, reverse=False, split=True):
'''
Encode the sentence as a white space delimited string of tokens.
'''
if split:
chars_ids = [
self.word_to_char_ids(cur_word)
for cur_word in sentence.split()
]
else:
chars_ids = [
self.word_to_char_ids(cur_word) for cur_word in sentence
]
if reverse:
return np.vstack([self.eos_chars] + chars_ids + [self.bos_chars])
else:
return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])
class Batcher(object):
'''
Batch sentences of tokenized text into character id matrices.
'''
# def __init__(self, lm_vocab_file: str, max_token_length: int):
def __init__(self, lm_vocab_file, max_token_length):
'''
lm_vocab_file = the language model vocabulary file (one line per
token)
max_token_length = the maximum number of characters in each token
'''
max_token_length = int(max_token_length)
self._lm_vocab = UnicodeCharsVocabulary(lm_vocab_file,
max_token_length)
self._max_token_length = max_token_length
# def batch_sentences(self, sentences: List[List[str]]):
def batch_sentences(self, sentences):
'''
Batch the sentences as character ids
Each sentence is a list of tokens without <s> or </s>, e.g.
[['The', 'first', 'sentence', '.'], ['Second', '.']]
'''
n_sentences = len(sentences)
max_length = max(len(sentence) for sentence in sentences) + 2
X_char_ids = np.zeros(
(n_sentences, max_length, self._max_token_length), dtype=np.int64)
for k, sent in enumerate(sentences):
length = len(sent) + 2
char_ids_without_mask = self._lm_vocab.encode_chars(
sent, split=False)
# add one so that 0 is the mask value
X_char_ids[k, :length, :] = char_ids_without_mask + 1
return X_char_ids
class TokenBatcher(object):
'''
Batch sentences of tokenized text into token id matrices.
'''
def __init__(self, lm_vocab_file):
# def __init__(self, lm_vocab_file: str):
'''
lm_vocab_file = the language model vocabulary file (one line per
token)
'''
self._lm_vocab = Vocabulary(lm_vocab_file)
# def batch_sentences(self, sentences: List[List[str]]):
def batch_sentences(self, sentences):
'''
Batch the sentences as character ids
Each sentence is a list of tokens without <s> or </s>, e.g.
[['The', 'first', 'sentence', '.'], ['Second', '.']]
'''
n_sentences = len(sentences)
max_length = max(len(sentence) for sentence in sentences) + 2
X_ids = np.zeros((n_sentences, max_length), dtype=np.int64)
for k, sent in enumerate(sentences):
length = len(sent) + 2
ids_without_mask = self._lm_vocab.encode(sent, split=False)
# add one so that 0 is the mask value
X_ids[k, :length] = ids_without_mask + 1
return X_ids
##### for training
def _get_batch(generator, batch_size, num_steps, max_word_length):
"""Read batches of input."""
cur_stream = [None] * batch_size
no_more_data = False
while True:
inputs = np.zeros([batch_size, num_steps], np.int32)
if max_word_length is not None:
char_inputs = np.zeros([batch_size, num_steps, max_word_length],
np.int32)
else:
char_inputs = None
targets = np.zeros([batch_size, num_steps], np.int32)
for i in range(batch_size):
cur_pos = 0
while cur_pos < num_steps:
if cur_stream[i] is None or len(cur_stream[i][0]) <= 1:
try:
cur_stream[i] = list(next(generator))
except StopIteration:
# No more data, exhaust current streams and quit
no_more_data = True
break
how_many = min(len(cur_stream[i][0]) - 1, num_steps - cur_pos)
next_pos = cur_pos + how_many
inputs[i, cur_pos:next_pos] = cur_stream[i][0][:how_many]
if max_word_length is not None:
char_inputs[i, cur_pos:next_pos] = cur_stream[i][
1][:how_many]
targets[i, cur_pos:next_pos] = cur_stream[i][0][1:how_many + 1]
cur_pos = next_pos
cur_stream[i][0] = cur_stream[i][0][how_many:]
if max_word_length is not None:
cur_stream[i][1] = cur_stream[i][1][how_many:]
if no_more_data:
# There is no more data. Note: this will not return data
# for the incomplete batch
break
X = {
'token_ids': inputs,
'tokens_characters': char_inputs,
'next_token_id': targets
}
yield X
class LMDataset(object):
"""
Hold a language model dataset.
A dataset is a list of tokenized files. Each file contains one sentence
per line. Each sentence is pre-tokenized and white space joined.
"""
def __init__(self,
filepattern,
vocab,
reverse=False,
test=False,
shuffle_on_load=False):
'''
filepattern = a glob string that specifies the list of files.
vocab = an instance of Vocabulary or UnicodeCharsVocabulary
reverse = if True, then iterate over tokens in each sentence in reverse
test = if True, then iterate through all data once then stop.
Otherwise, iterate forever.
shuffle_on_load = if True, then shuffle the sentences after loading.
'''
self._vocab = vocab
self._all_shards = glob.glob(filepattern)
print('Found %d shards at %s' % (len(self._all_shards), filepattern))
if test:
self._all_shards = list(np.random.choice(self._all_shards, size=4))
print('sampled %d shards at %s' % (len(self._all_shards), filepattern))
self._shards_to_choose = []
self._reverse = reverse
self._test = test
self._shuffle_on_load = shuffle_on_load
self._use_char_inputs = hasattr(vocab, 'encode_chars')
self._ids = self._load_random_shard()
def _choose_random_shard(self):
if len(self._shards_to_choose) == 0:
self._shards_to_choose = list(self._all_shards)
random.shuffle(self._shards_to_choose)
shard_name = self._shards_to_choose.pop()
return shard_name
def _load_random_shard(self):
"""Randomly select a file and read it."""
if self._test:
if len(self._all_shards) == 0:
# we've loaded all the data
# this will propogate up to the generator in get_batch
# and stop iterating
raise StopIteration
else:
shard_name = self._all_shards.pop()
else:
# just pick a random shard
shard_name = self._choose_random_shard()
ids = self._load_shard(shard_name)
self._i = 0
self._nids = len(ids)
return ids
def _load_shard(self, shard_name):
"""Read one file and convert to ids.
Args:
shard_name: file path.
Returns:
list of (id, char_id) tuples.
"""
print('Loading data from: %s' % shard_name)
with io.open(shard_name, 'r', encoding='utf-8') as f:
sentences_raw = f.readlines()
if self._reverse:
sentences = []
for sentence in sentences_raw:
splitted = sentence.split()
splitted.reverse()
sentences.append(' '.join(splitted))
else:
sentences = sentences_raw
if self._shuffle_on_load:
print('shuffle sentences')
random.shuffle(sentences)
ids = [
self.vocab.encode(sentence, self._reverse)
for sentence in sentences
]
if self._use_char_inputs:
chars_ids = [
self.vocab.encode_chars(sentence, self._reverse)
for sentence in sentences
]
else:
chars_ids = [None] * len(ids)
print('Loaded %d sentences.' % len(ids))
print('Finished loading')
return list(zip(ids, chars_ids))
def get_sentence(self):
while True:
if self._i == self._nids:
self._ids = self._load_random_shard()
ret = self._ids[self._i]
self._i += 1
yield ret
@property
def max_word_length(self):
if self._use_char_inputs:
return self._vocab.max_word_length
else:
return None
def iter_batches(self, batch_size, num_steps):
for X in _get_batch(self.get_sentence(), batch_size, num_steps,
self.max_word_length):
# token_ids = (batch_size, num_steps)
# char_inputs = (batch_size, num_steps, 50) of character ids
# targets = word ID of next word (batch_size, num_steps)
yield X
@property
def vocab(self):
return self._vocab
class BidirectionalLMDataset(object):
def __init__(self, filepattern, vocab, test=False, shuffle_on_load=False):
'''
bidirectional version of LMDataset
'''
self._data_forward = LMDataset(
filepattern,
vocab,
reverse=False,
test=test,
shuffle_on_load=shuffle_on_load)
self._data_reverse = LMDataset(
filepattern,
vocab,
reverse=True,
test=test,
shuffle_on_load=shuffle_on_load)
def iter_batches(self, batch_size, num_steps):
max_word_length = self._data_forward.max_word_length
for X, Xr in six.moves.zip(
_get_batch(self._data_forward.get_sentence(), batch_size,
num_steps, max_word_length),
_get_batch(self._data_reverse.get_sentence(), batch_size,
num_steps, max_word_length)):
for k, v in Xr.items():
X[k + '_reverse'] = v
yield X
class InvalidNumberOfCharacters(Exception):
pass
因为 它太大了无法显示 source diff 。你可以改为 查看blob
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
def dropout(input, test_mode, args):
if args.dropout and (not test_mode):
return layers.dropout(
input,
dropout_prob=args.dropout,
dropout_implementation="upscale_in_train",
seed=args.random_seed,
is_test=False)
else:
return input
def lstmp_encoder(input_seq, gate_size, h_0, c_0, para_name, proj_size, test_mode, args):
# A lstm encoder implementation with projection.
# Linear transformation part for input gate, output gate, forget gate
# and cell activation vectors need be done outside of dynamic_lstm.
# So the output size is 4 times of gate_size.
input_seq = dropout(input_seq, test_mode, args)
input_proj = layers.fc(input=input_seq,
param_attr=fluid.ParamAttr(
name=para_name + '_gate_w', initializer=None),
size=gate_size * 4,
act=None,
bias_attr=False)
hidden, cell = layers.dynamic_lstmp(
input=input_proj,
size=gate_size * 4,
proj_size=proj_size,
h_0=h_0,
c_0=c_0,
use_peepholes=False,
proj_clip=args.proj_clip,
cell_clip=args.cell_clip,
proj_activation="identity",
param_attr=fluid.ParamAttr(initializer=None),
bias_attr=fluid.ParamAttr(initializer=None))
return hidden, cell, input_proj
def encoder(x,
y,
vocab_size,
emb_size,
init_hidden=None,
init_cell=None,
para_name='',
custom_samples=None,
custom_probabilities=None,
test_mode=False,
args=None):
x_emb = layers.embedding(
input=x,
size=[vocab_size, emb_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(name='embedding_para'))
rnn_input = x_emb
rnn_outs = []
rnn_outs_ori = []
cells = []
projs = []
for i in range(args.num_layers):
rnn_input = dropout(rnn_input, test_mode, args)
if init_hidden and init_cell:
h0 = layers.squeeze(
layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
c0 = layers.squeeze(
layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1]),
axes=[0])
else:
h0 = c0 = None
rnn_out, cell, input_proj = lstmp_encoder(
rnn_input, args.hidden_size, h0, c0,
para_name + 'layer{}'.format(i + 1), emb_size, test_mode, args)
rnn_out_ori = rnn_out
if i > 0:
rnn_out = rnn_out + rnn_input
rnn_out = dropout(rnn_out, test_mode, args)
cell = dropout(cell, test_mode, args)
rnn_outs.append(rnn_out)
rnn_outs_ori.append(rnn_out_ori)
rnn_input = rnn_out
cells.append(cell)
projs.append(input_proj)
softmax_weight = layers.create_parameter(
[vocab_size, emb_size], dtype="float32", name="softmax_weight")
softmax_bias = layers.create_parameter(
[vocab_size], dtype="float32", name='softmax_bias')
projection = layers.matmul(rnn_outs[-1], softmax_weight, transpose_y=True)
projection = layers.elementwise_add(projection, softmax_bias)
projection = layers.reshape(projection, shape=[-1, vocab_size])
if args.sample_softmax and (not test_mode):
loss = layers.sampled_softmax_with_cross_entropy(
logits=projection,
label=y,
num_samples=args.n_negative_samples_batch,
seed=args.random_seed)
else:
label = layers.one_hot(input=y, depth=vocab_size)
loss = layers.softmax_with_cross_entropy(
logits=projection, label=label, soft_label=True)
return [x_emb, projection, loss], rnn_outs, rnn_outs_ori, cells, projs
class LanguageModel(object):
def __init__(self, args, vocab_size, test_mode):
self.args = args
self.vocab_size = vocab_size
self.test_mode = test_mode
def build(self):
args = self.args
emb_size = args.embed_size
proj_size = args.embed_size
hidden_size = args.hidden_size
batch_size = args.batch_size
num_layers = args.num_layers
num_steps = args.num_steps
lstm_outputs = []
x_f = layers.data(name="x", shape=[1], dtype='int64', lod_level=1)
y_f = layers.data(name="y", shape=[1], dtype='int64', lod_level=1)
x_b = layers.data(name="x_r", shape=[1], dtype='int64', lod_level=1)
y_b = layers.data(name="y_r", shape=[1], dtype='int64', lod_level=1)
init_hiddens_ = layers.data(
name="init_hiddens", shape=[1], dtype='float32')
init_cells_ = layers.data(
name="init_cells", shape=[1], dtype='float32')
init_hiddens = layers.reshape(
init_hiddens_, shape=[2 * num_layers, -1, proj_size])
init_cells = layers.reshape(
init_cells_, shape=[2 * num_layers, -1, hidden_size])
init_hidden = layers.slice(
init_hiddens, axes=[0], starts=[0], ends=[num_layers])
init_cell = layers.slice(
init_cells, axes=[0], starts=[0], ends=[num_layers])
init_hidden_r = layers.slice(
init_hiddens, axes=[0], starts=[num_layers],
ends=[2 * num_layers])
init_cell_r = layers.slice(
init_cells, axes=[0], starts=[num_layers], ends=[2 * num_layers])
if args.use_custom_samples:
custom_samples = layers.data(
name="custom_samples",
shape=[args.n_negative_samples_batch + 1],
dtype='int64',
lod_level=1)
custom_samples_r = layers.data(
name="custom_samples_r",
shape=[args.n_negative_samples_batch + 1],
dtype='int64',
lod_level=1)
custom_probabilities = layers.data(
name="custom_probabilities",
shape=[args.n_negative_samples_batch + 1],
dtype='float32',
lod_level=1)
else:
custom_samples = None
custom_samples_r = None
custom_probabilities = None
forward, fw_hiddens, fw_hiddens_ori, fw_cells, fw_projs = encoder(
x_f,
y_f,
self.vocab_size,
emb_size,
init_hidden,
init_cell,
para_name='fw_',
custom_samples=custom_samples,
custom_probabilities=custom_probabilities,
test_mode=self.test_mode,
args=args)
backward, bw_hiddens, bw_hiddens_ori, bw_cells, bw_projs = encoder(
x_b,
y_b,
self.vocab_size,
emb_size,
init_hidden_r,
init_cell_r,
para_name='bw_',
custom_samples=custom_samples_r,
custom_probabilities=custom_probabilities,
test_mode=self.test_mode,
args=args)
losses = layers.concat([forward[-1], backward[-1]])
self.loss = layers.reduce_mean(losses)
self.loss.persistable = True
self.grad_vars = [x_f, y_f, x_b, y_b, self.loss]
self.grad_vars_name = ['x', 'y', 'x_r', 'y_r', 'final_loss']
fw_vars_name = ['x_emb', 'proj', 'loss'] + [
'init_hidden', 'init_cell'
] + ['rnn_out', 'rnn_out2', 'cell', 'cell2', 'xproj', 'xproj2']
bw_vars_name = ['x_emb_r', 'proj_r', 'loss_r'] + [
'init_hidden_r', 'init_cell_r'
] + [
'rnn_out_r', 'rnn_out2_r', 'cell_r', 'cell2_r', 'xproj_r',
'xproj2_r'
]
fw_vars = forward + [init_hidden, init_cell
] + fw_hiddens + fw_cells + fw_projs
bw_vars = backward + [init_hidden_r, init_cell_r
] + bw_hiddens + bw_cells + bw_projs
for i in range(len(fw_vars_name)):
self.grad_vars.append(fw_vars[i])
self.grad_vars.append(bw_vars[i])
self.grad_vars_name.append(fw_vars_name[i])
self.grad_vars_name.append(bw_vars_name[i])
if args.use_custom_samples:
self.feed_order = [
'x', 'y', 'x_r', 'y_r', 'custom_samples', 'custom_samples_r',
'custom_probabilities'
]
else:
self.feed_order = ['x', 'y', 'x_r', 'y_r']
self.last_hidden = [
fluid.layers.sequence_last_step(input=x)
for x in fw_hiddens_ori + bw_hiddens_ori
]
self.last_cell = [
fluid.layers.sequence_last_step(input=x)
for x in fw_cells + bw_cells
]
self.last_hidden = layers.concat(self.last_hidden, axis=0)
self.last_hidden.persistable = True
self.last_cell = layers.concat(self.last_cell, axis=0)
self.last_cell.persistable = True
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import sys
import numpy as np
Py3 = sys.version_info[0] == 3
def listDir(rootDir):
res = []
for filename in os.listdir(rootDir):
pathname = os.path.join(rootDir, filename)
if (os.path.isfile(pathname)):
res.append(pathname)
return res
_unk = -1
_bos = -1
_eos = -1
def _read_words(filename):
data = []
with open(filename, "r") as f:
return f.read().decode("utf-8").replace("\n", "<eos>").split()
def _build_vocab(filename):
data = _read_words(filename)
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
print("vocab word num", len(words))
word_to_id = dict(zip(words, range(len(words))))
return word_to_id
def _load_vocab(filename):
with open(filename, "r") as f:
words = f.read().decode("utf-8").replace("\n", " ").split()
word_to_id = dict(zip(words, range(len(words))))
_unk = word_to_id['<S>']
_eos = word_to_id['</S>']
_unk = word_to_id['<UNK>']
return word_to_id
def _file_to_word_ids(filenames, word_to_id):
for filename in filenames:
data = _read_words(filename)
for id in [word_to_id[word] for word in data if word in word_to_id]:
yield id
def ptb_raw_data(data_path=None, vocab_path=None, args=None):
"""Load PTB raw data from data directory "data_path".
Reads PTB text files, converts strings to integer ids,
and performs mini-batching of the inputs.
The PTB dataset comes from Tomas Mikolov's webpage:
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
Args:
data_path: string path to the directory where simple-examples.tgz has
been extracted.
Returns:
tuple (train_data, valid_data, test_data, vocabulary)
where each of the data objects can be passed to PTBIterator.
"""
if vocab_path:
word_to_id = _load_vocab(vocab_path)
if not args.train_path:
train_path = os.path.join(data_path, "train")
train_data = _file_to_word_ids(listDir(train_path), word_to_id)
else:
train_path = args.train_path
train_data = _file_to_word_ids([train_path], word_to_id)
valid_path = os.path.join(data_path, "dev")
test_path = os.path.join(data_path, "dev")
valid_data = _file_to_word_ids(listDir(valid_path), word_to_id)
test_data = _file_to_word_ids(listDir(test_path), word_to_id)
vocabulary = len(word_to_id)
return train_data, valid_data, test_data, vocabulary
def get_data_iter(raw_data, batch_size, num_steps):
def __impl__():
buf = []
while True:
if len(buf) >= num_steps * batch_size + 1:
x = np.asarray(
buf[:-1], dtype='int64').reshape((batch_size, num_steps))
y = np.asarray(
buf[1:], dtype='int64').reshape((batch_size, num_steps))
yield (x, y)
buf = [buf[-1]]
try:
buf.append(raw_data.next())
except StopIteration:
break
return __impl__
export CUDA_VISIBLE_DEVICES=0
python train.py \
--train_path='baike/train/sentence_file_*' \
--test_path='baike/dev/sentence_file_*' \
--vocab_path baike/vocabulary_min5k.txt \
--learning_rate 0.2 \
--use_gpu True \
--local True $@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import numpy as np
import time
import os
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
from paddle.fluid.executor import Executor
import data
from args import *
import lm_model
import logging
logging.basicConfig()
import pickle
def prepare_batch_input(batch, args):
x = batch['token_ids']
x_r = batch['token_ids_reverse']
y = batch['next_token_id']
y_r = batch['next_token_id_reverse']
inst = []
for i in range(len(x)):
if args.use_custom_samples:
custom_samples_array = np.zeros(
(args.num_steps, args.n_negative_samples_batch + 1),
dtype='int64')
custom_samples_array_r = np.zeros(
(args.num_steps, args.n_negative_samples_batch + 1),
dtype='int64')
custom_probabilities_array = np.zeros(
(args.num_steps, args.n_negative_samples_batch + 1),
dtype='float32')
for j in range(args.num_steps):
for k in range(args.n_negative_samples_batch + 1):
custom_samples_array[j][k] = k
custom_samples_array_r[j][k] = k
custom_probabilities_array[j][k] = 1.0
custom_samples_array[j][0] = y[i][j]
custom_samples_array_r[j][0] = y_r[i][j]
inst.append([
x[i], y[i], x_r[i], y_r[i], custom_samples_array,
custom_samples_array_r, custom_probabilities_array
])
else:
inst.append([x[i], y[i], x_r[i], y_r[i]])
return inst
def batch_reader(batch_list, args):
res = []
for batch in batch_list:
res.append(prepare_batch_input(batch, args))
return res
def read_multiple(reader, batch_size, count, clip_last=True):
"""
Stack data from reader for multi-devices.
"""
def __impl__():
# one time read batch_size * count data for rnn
for data in reader():
inst_num_per_part = batch_size
split_data = {}
len_check = True
for k in data.keys():
if data[k] is not None:
if len(data[k]) != batch_size * count:
len_check = False
print("data check error!!, data=" + data[k] + ", k=" + k)
break
if len_check:
res = []
for i in range(count):
split_data = {}
for k in data.keys():
if data[k] is not None:
split_data[k] = data[k][inst_num_per_part * i:inst_num_per_part * (i + 1)]
res.append(split_data)
yield res
return __impl__
def LodTensor_Array(lod_tensor):
lod = lod_tensor.lod()
array = np.array(lod_tensor)
new_array = []
for i in range(len(lod[0]) - 1):
new_array.append(array[lod[0][i]:lod[0][i + 1]])
return new_array
def get_current_model_para(train_prog, train_exe):
param_list = train_prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list]
vals = {}
for p_name in param_name_list:
p_array = np.array(fluid.global_scope().find_var(p_name).get_tensor())
vals[p_name] = p_array
return vals
def save_para_npz(train_prog, train_exe):
logger.info("begin to save model to model_base")
param_list = train_prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list]
vals = {}
for p_name in param_name_list:
p_array = np.array(fluid.global_scope().find_var(p_name).get_tensor())
vals[p_name] = p_array
emb = vals["embedding_para"]
logger.info("begin to save model to model_base")
np.savez("mode_base", **vals)
def prepare_input(batch, epoch_id=0, with_lr=True):
x, y = batch
inst = []
for i in range(len(x)):
inst.append([x[i], y[i]])
return inst
def eval(vocab, infer_progs, dev_count, logger, args):
infer_prog, infer_startup_prog, infer_model = infer_progs
feed_order = infer_model.feed_order
loss = infer_model.loss
# prepare device
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
exe = Executor(place)
if not args.use_gpu:
place = fluid.CPUPlace()
import multiprocessing
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
else:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
total_loss = 0.0
total_cnt = 0
n_batch_cnt = 0
n_batch_loss = 0.0
val_feed_list = [
infer_prog.global_block().var(var_name) for var_name in feed_order
]
val_feeder = fluid.DataFeeder(val_feed_list, place)
dev_data = data.BidirectionalLMDataset(
args.test_path, vocab, test=True, shuffle_on_load=False)
dev_data_iter = lambda: dev_data.iter_batches(args.batch_size * dev_count, args.num_steps)
dev_reader = read_multiple(dev_data_iter, args.batch_size, dev_count)
last_hidden_values = np.zeros(
(dev_count, args.num_layers * 2 * args.batch_size * args.embed_size),
dtype='float32')
last_cell_values = np.zeros(
(dev_count, args.num_layers * 2 * args.batch_size * args.hidden_size),
dtype='float32')
for batch_id, batch_list in enumerate(dev_reader(), 1):
feed_data = batch_reader(batch_list, args)
feed = list(val_feeder.feed_parallel(feed_data, dev_count))
for i in range(dev_count):
init_hidden_tensor = fluid.core.LoDTensor()
if args.use_gpu:
placex = fluid.CUDAPlace(i)
else:
placex = fluid.CPUPlace()
init_hidden_tensor.set(last_hidden_values[i], placex)
init_cell_tensor = fluid.core.LoDTensor()
init_cell_tensor.set(last_cell_values[i], placex)
feed[i]['init_hiddens'] = init_hidden_tensor
feed[i]['init_cells'] = init_cell_tensor
last_hidden_values = []
last_cell_values = []
for i in range(dev_count):
val_fetch_outs = exe.run(
program=infer_prog,
feed=feed[i],
fetch_list=[
infer_model.loss.name, infer_model.last_hidden.name,
infer_model.last_cell.name
],
return_numpy=False)
last_hidden_values.append(np.array(val_fetch_outs[1]))
last_cell_values.append(np.array(val_fetch_outs[2]))
total_loss += np.array(val_fetch_outs[0]).sum()
n_batch_cnt += len(np.array(val_fetch_outs[0]))
total_cnt += len(np.array(val_fetch_outs[0]))
n_batch_loss += np.array(val_fetch_outs[0]).sum()
last_hidden_values = np.array(last_hidden_values).reshape((
dev_count, args.num_layers * 2 * args.batch_size * args.embed_size))
last_cell_values = np.array(last_cell_values).reshape(
(dev_count,
args.num_layers * 2 * args.batch_size * args.hidden_size))
log_every_n_batch = args.log_interval
if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
logger.info('Average dev loss from batch {} to {} is {}'.format(
batch_id - log_every_n_batch + 1, batch_id, "%.10f" % (
n_batch_loss / n_batch_cnt)))
n_batch_loss = 0.0
n_batch_cnt = 0
batch_offset = 0
ppl = np.exp(total_loss / total_cnt)
return ppl
def train():
args = parse_args()
if args.random_seed == 0:
args.random_seed = None
print("random seed is None")
if args.enable_ce:
random.seed(args.random_seed)
np.random.seed(args.random_seed)
logger = logging.getLogger("lm")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.info('Running with args : {}'.format(args))
logger.info('Running paddle : {}'.format(paddle.version.commit))
hidden_size = args.hidden_size
batch_size = args.batch_size
data_path = args.data_path
logger.info("begin to load vocab")
vocab = data.Vocabulary(args.vocab_path, validate_file=True)
vocab_size = vocab.size
logger.info("finished load vocab")
logger.info('build the model...')
# build model
train_prog = fluid.Program()
train_startup_prog = fluid.Program()
if args.enable_ce:
train_prog.random_seed = args.random_seed
train_startup_prog.random_seed = args.random_seed
# build infer model
infer_prog = fluid.Program()
infer_startup_prog = fluid.Program()
with fluid.program_guard(infer_prog, infer_startup_prog):
with fluid.unique_name.guard():
# Infer process
infer_model = lm_model.LanguageModel(
args, vocab_size, test_mode=True)
infer_model.build()
infer_progs = infer_prog, infer_startup_prog, infer_model
with fluid.program_guard(train_prog, train_startup_prog):
with fluid.unique_name.guard():
# Training process
train_model = lm_model.LanguageModel(
args, vocab_size, test_mode=False)
train_model.build()
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=args.max_grad_norm))
# build optimizer
if args.optim == 'adagrad':
optimizer = fluid.optimizer.Adagrad(
learning_rate=args.learning_rate,
epsilon=0.0,
initial_accumulator_value=1.0)
elif args.optim == 'sgd':
optimizer = fluid.optimizer.SGD(
learning_rate=args.learning_rate)
elif args.optim == 'adam':
optimizer = fluid.optimizer.Adam(
learning_rate=args.learning_rate)
elif args.optim == 'rprop':
optimizer = fluid.optimizer.RMSPropOptimizer(
learning_rate=args.learning_rate)
else:
logger.error('Unsupported optimizer: {}'.format(args.optim))
exit(-1)
optimizer.minimize(train_model.loss * args.num_steps)
# initialize parameters
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
exe = Executor(place)
train_progs = train_prog, train_startup_prog, train_model
if args.local:
logger.info("local start_up:")
train_loop(args, logger, vocab, train_progs, infer_progs, optimizer)
else:
if args.update_method == "nccl2":
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
if args.test_nccl:
worker_endpoints_env = os.getenv("PADDLE_WORK_ENDPOINTS")
worker_endpoints = worker_endpoints_env.split(',')
trainers_num = len(worker_endpoints)
current_endpoint = worker_endpoints[trainer_id]
else:
port = os.getenv("PADDLE_PORT")
worker_ips = os.getenv("PADDLE_TRAINERS")
worker_endpoints = []
for ip in worker_ips.split(","):
worker_endpoints.append(':'.join([ip, port]))
worker_endpoints_env = ','.join(worker_endpoints)
trainers_num = len(worker_endpoints)
current_endpoint = os.getenv("POD_IP") + ":" + port
if trainer_id == 0:
logger.info("train_id == 0, sleep 60s")
time.sleep(60)
logger.info("trainers_num:{}".format(trainers_num))
logger.info("worker_endpoints:{}".format(worker_endpoints))
logger.info("current_endpoint:{}".format(current_endpoint))
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=worker_endpoints_env,
current_endpoint=current_endpoint,
program=train_prog,
startup_program=train_startup_prog)
train_progs = train_prog, train_startup_prog, train_model
train_loop(args, logger, vocab, train_progs, infer_progs, optimizer,
trainers_num, trainer_id, worker_endpoints)
else:
port = os.getenv("PADDLE_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVERS")
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist)
trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
current_endpoint = os.getenv("POD_IP") + ":" + port
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
logger.info("pserver_endpoints:{}".format(pserver_endpoints))
logger.info("current_endpoint:{}".format(current_endpoint))
logger.info("trainer_id:{}".format(trainer_id))
logger.info("pserver_ips:{}".format(pserver_ips))
logger.info("port:{}".format(port))
t = fluid.DistributeTranspiler()
t.transpile(
trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
program=train_prog,
startup_program=startup_prog)
if training_role == "PSERVER":
logger.info("distributed: pserver started")
current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
"PADDLE_PORT")
if not current_endpoint:
logger.critical("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint,
pserver_prog)
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
logger.info("distributed: trainer started")
trainer_prog = t.get_trainer_program()
train_loop(args, logger, vocab, train_progs, infer_progs,
optimizer)
else:
logger.critical(
"environment var TRAINER_ROLE should be TRAINER os PSERVER")
exit(1)
def init_pretraining_params(exe,
pretraining_params_path,
main_program):
assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path
def existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(os.path.join(pretraining_params_path, var.name))
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=existed_params)
print("Load pretraining parameters from {}.".format(
pretraining_params_path))
def train_loop(args,
logger,
vocab,
train_progs,
infer_progs,
optimizer,
nccl2_num_trainers=1,
nccl2_trainer_id=0,
worker_endpoints=None):
train_prog, train_startup_prog, train_model = train_progs
infer_prog, infer_startup_prog, infer_model = infer_progs
# prepare device
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
exe = Executor(place)
if not args.use_gpu:
place = fluid.CPUPlace()
import multiprocessing
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
else:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
if args.load_dir:
logger.info('load pretrained checkpoints from {}'.format(args.load_dir))
fluid.io.load_persistables(exe, args.load_dir, main_program=train_prog)
elif args.load_pretraining_params:
logger.info('load pretrained params from {}'.format(args.load_pretraining_params))
exe.run(train_startup_prog)
init_pretraining_params(exe, args.load_pretraining_params, main_program=train_prog)
else:
exe.run(train_startup_prog)
# prepare data
feed_list = [
train_prog.global_block().var(var_name)
for var_name in train_model.feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
logger.info('Training the model...')
exe_strategy = fluid.parallel_executor.ExecutionStrategy()
parallel_executor = fluid.ParallelExecutor(
loss_name=train_model.loss.name,
main_program=train_prog,
use_cuda=bool(args.use_gpu),
exec_strategy=exe_strategy,
num_trainers=nccl2_num_trainers,
trainer_id=nccl2_trainer_id)
logger.info("begin to load data")
train_data = data.BidirectionalLMDataset(
args.train_path,
vocab,
test=(not args.shuffle),
shuffle_on_load=args.shuffle)
logger.info("finished load vocab")
# get train epoch size
log_interval = args.log_interval
total_time = 0.0
batch_size = args.batch_size
hidden_size = args.hidden_size
custom_samples_array = np.zeros(
(batch_size, args.num_steps, args.n_negative_samples_batch + 1),
dtype='int64')
custom_probabilities_array = np.zeros(
(batch_size, args.num_steps, args.n_negative_samples_batch + 1),
dtype='float32')
for i in range(batch_size):
for j in range(0, args.num_steps):
for k in range(0, args.n_negative_samples_batch + 1):
custom_samples_array[i][j][k] = k
custom_probabilities_array[i][j][k] = 1.0
for epoch_id in range(args.max_epoch):
start_time = time.time()
logger.info("epoch id {}".format(epoch_id))
train_data_iter = lambda: train_data.iter_batches(batch_size * dev_count, args.num_steps)
train_reader = read_multiple(train_data_iter, batch_size, dev_count)
total_num = 0
n_batch_loss = 0.0
n_batch_cnt = 0
last_hidden_values = np.zeros(
(dev_count, args.num_layers * 2 * batch_size * args.embed_size),
dtype='float32')
last_cell_values = np.zeros(
(dev_count, args.num_layers * 2 * batch_size * hidden_size),
dtype='float32')
begin_time = time.time()
for batch_id, batch_list in enumerate(train_reader(), 1):
feed_data = batch_reader(batch_list, args)
feed = list(feeder.feed_parallel(feed_data, dev_count))
for i in range(dev_count):
init_hidden_tensor = fluid.core.LoDTensor()
if args.use_gpu:
placex = fluid.CUDAPlace(i)
else:
placex = fluid.CPUPlace()
init_hidden_tensor.set(last_hidden_values[i], placex)
init_cell_tensor = fluid.core.LoDTensor()
init_cell_tensor.set(last_cell_values[i], placex)
feed[i]['init_hiddens'] = init_hidden_tensor
feed[i]['init_cells'] = init_cell_tensor
fetch_outs = parallel_executor.run(
feed=feed,
fetch_list=[
train_model.loss.name, train_model.last_hidden.name,
train_model.last_cell.name
],
return_numpy=False)
cost_train = np.array(fetch_outs[0]).mean()
last_hidden_values = np.array(fetch_outs[1])
last_hidden_values = last_hidden_values.reshape(
(dev_count, args.num_layers * 2 * batch_size * args.embed_size))
last_cell_values = np.array(fetch_outs[2])
last_cell_values = last_cell_values.reshape((
dev_count, args.num_layers * 2 * batch_size * args.hidden_size))
total_num += args.batch_size * dev_count
n_batch_loss += np.array(fetch_outs[0]).sum()
n_batch_cnt += len(np.array(fetch_outs[0]))
if batch_id > 0 and batch_id % log_interval == 0:
smoothed_ppl = np.exp(n_batch_loss / n_batch_cnt)
ppl = np.exp(
np.array(fetch_outs[0]).sum() /
len(np.array(fetch_outs[0])))
used_time = time.time() - begin_time
speed = log_interval / used_time
logger.info(
"[train] epoch:{}, step:{}, loss:{:.3f}, ppl:{:.3f}, smoothed_ppl:{:.3f}, speed:{:.3f}".
format(epoch_id, batch_id, n_batch_loss / n_batch_cnt, ppl,
smoothed_ppl, speed))
n_batch_loss = 0.0
n_batch_cnt = 0
begin_time = time.time()
if batch_id > 0 and batch_id % args.dev_interval == 0:
valid_ppl = eval(vocab, infer_progs, dev_count, logger, args)
logger.info("valid ppl {}".format(valid_ppl))
if batch_id > 0 and batch_id % args.save_interval == 0:
model_path = os.path.join(args.para_save_dir,
str(batch_id + epoch_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(
executor=exe, dirname=model_path, main_program=train_prog)
end_time = time.time()
total_time += end_time - start_time
logger.info("train ppl {}".format(ppl))
if epoch_id == args.max_epoch - 1 and args.enable_ce:
logger.info("lstm_language_model_duration\t%s" %
(total_time / args.max_epoch))
logger.info("lstm_language_model_loss\t%s" % ppl[0])
model_path = os.path.join(args.para_save_dir, str(epoch_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(
executor=exe, dirname=model_path, main_program=train_prog)
valid_ppl = eval(vocab, infer_progs, dev_count, logger, args)
logger.info("valid ppl {}".format(valid_ppl))
test_ppl = eval(vocab, infer_progs, dev_count, logger, args)
logger.info("test ppl {}".format(test_ppl))
if __name__ == '__main__':
train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册