Unverified · Commit 7fae3401 · authored by LiuChiachi · committed by GitHub

Update seq2seq example (#5016)

* update seq2seq, using paddlenlp

* Using new paddlenlp API

* update seq2seq README

* wrap dev ds

* delete useless comments

* update predict.py

* using paddlenlp.bleu

* remove shard

* update README, using bleu perl

* delete cand

* Remove tokens that make sentences longer than max_len

* remove pdb

* remove useless code.

* update url and dataset name of vae dataset(ptb and yahoo)

* update seq2seq and vae, data and README
Parent 33d65b31
@@ -19,79 +19,55 @@ Sequence to Sequence (Seq2Seq), using the encoder-decoder (Encoder-Decoder)
This directory contains a classic Seq2Seq example: machine translation with an attention-based translation model. A Seq2Seq translation model mimics how humans translate: first parse the source sentence and understand its meaning, then write the target-language sentence according to that meaning. For more on the principles and mathematical formulation of machine translation, we recommend the [machine translation case study](https://www.paddlepaddle.org.cn/documentation/docs/zh/user_guides/nlp_case/machine_translation/README.cn.html) on the PaddlePaddle website.

Running the example model in this directory requires PaddlePaddle 2.0-rc. If your installed PaddlePaddle is older than this, please follow the [installation guide](https://www.paddlepaddle.org.cn/#quick-start) to update it.

## Model Overview

In this model, the encoder is a multi-layer LSTM-based RNN encoder, and the decoder is an RNN decoder with an attention mechanism; at prediction time, beam search is used to generate the target sentence.
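The core of the attention decoder can be summarized in a few lines: at every decoding step the decoder state is scored against all encoder outputs, the scores are normalized with softmax, and the weighted sum of encoder outputs becomes the context vector. Below is a minimal NumPy sketch of this kind of dot-product attention; the shapes and names are illustrative only and are not the exact scoring function used by this example.

```python
import numpy as np

def dot_product_attention(decoder_state, encoder_outputs):
    """decoder_state: [hidden], encoder_outputs: [src_len, hidden]."""
    # Score each source position against the current decoder state.
    scores = encoder_outputs @ decoder_state            # [src_len]
    # Normalize the scores into attention weights (softmax).
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                            # [src_len]
    # The context vector is the weighted sum of encoder outputs.
    context = weights @ encoder_outputs                 # [hidden]
    return context, weights

context, weights = dot_product_attention(
    np.random.rand(512), np.random.rand(20, 512))
print(context.shape, weights.shape)  # (512,) (20,)
```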
## Code Download

Clone the repository locally and set the `PYTHONPATH` environment variable:
```shell
git clone http://gitlab.baidu.com/PaddleSL/PaddleNLP
cd PaddleNLP
export PYTHONPATH=$PYTHONPATH:`pwd`
cd examples/machine_translation/seq2seq
```
## Data Introduction

This tutorial uses the English-to-Vietnamese data from the [IWSLT'15 English-Vietnamese data](https://nlp.stanford.edu/projects/nmt/) dataset as the training corpus, the tst2012 data as the development set, and the tst2013 data as the test set.

### Obtaining the Data

If no path is provided when the dataset is initialized, it is downloaded automatically to `paddlenlp.utils.env.DATA_HOME` under `/machine_translation/IWSLT15/`. On Linux, for example, the default location is `/root/.paddlenlp/datasets/machine_translation/IWSLT15`.
```
python download.py
```
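As a quick sanity check that the data is in place, the dataset and vocabularies can be loaded through the same `paddlenlp.datasets.IWSLT15` API that the updated `data.py` uses below. A minimal sketch, assuming the paddlenlp version targeted by this example:

```python
from paddlenlp.datasets import IWSLT15

# Converts raw text pairs into token id sequences.
trans_func_tuple = IWSLT15.get_default_transform_func()
# Triggers the automatic download on first use.
train_ds = IWSLT15.get_datasets(
    mode=["train"], transform_func=[trans_func_tuple])
src_vocab, tgt_vocab = IWSLT15.get_vocab()
print(len(src_vocab), len(tgt_vocab), len(train_ds))
```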
## Model Training

Run the following command to train the attention-based Seq2Seq machine translation model:
```sh
python train.py \
    --num_layers 2 \
    --hidden_size 512 \
    --batch_size 128 \
    --dropout 0.2 \
    --init_scale 0.1 \
    --max_grad_norm 5.0 \
    --use_gpu True \
    --model_path ./attention_models
```
See `args.py` for a detailed description of each argument. The training program saves a checkpoint at the end of every epoch.

**NOTE:** To resume training, `init_from_ckpt` only needs the checkpoint name prefix, without the file extension. For example, `--init_from_ckpt=attention_models/5` is enough: the program automatically loads the model parameters from `attention_models/5.pdparams` and the optimizer state from `attention_models/5.pdopt`.
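The checkpoint naming above follows the `paddle.Model` convention: `save()` writes `<prefix>.pdparams` and `<prefix>.pdopt`, and `load()` takes the same prefix. A minimal sketch; the layer and the prefix here are illustrative, not the example's actual model:

```python
import paddle

model = paddle.Model(paddle.nn.Linear(4, 2))
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()))
# Writes attention_models/5.pdparams and attention_models/5.pdopt.
model.save("attention_models/5")
# Later: restore parameters and optimizer state from the same prefix.
model.load("attention_models/5")
```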
## Model Prediction

After training, the saved model (specified by `--init_from_ckpt`) can be used to decode the test set with beam search. The reference translations are specified by `--infer_target_file`; on Linux the default path is `/root/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi`. If you are on Windows, change the path below accordingly. The prediction command is:
```sh
python predict.py \
    --num_layers 2 \
    --hidden_size 512 \
    --batch_size 128 \
    --dropout 0.2 \
    --init_scale 0.1 \
    --max_grad_norm 5.0 \
    --init_from_ckpt attention_models/9 \
    --infer_target_file /root/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi \
    --infer_output_file infer_output.txt \
    --beam_size 10 \
    --use_gpu True
```
See `args.py` for details of each argument, and note that the model hyperparameters used for prediction must match those used for training.

## Evaluation

Use the [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) tool to evaluate the translation quality of the model output, as follows:
```sh
perl mosesdecoder/scripts/generic/multi-bleu.perl data/en-vi/tst2013.vi < infer_output.txt
```
Predicting with the model saved at the 10th epoch and beam_size=10 gives the following result:
```
tst2013 BLEU: 24.40
```
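Besides the perl script, BLEU can also be computed in Python. Below is a minimal sketch with `paddlenlp.metrics.BLEU`, assuming its `add_inst`/`score` interface; the tokenized hypothesis and reference here are toy examples, not real model output:

```python
from paddlenlp.metrics import BLEU

bleu = BLEU()  # 4-gram BLEU by default
hypothesis = "the cat sat on the mat".split()
references = ["the cat is sitting on the mat".split()]
# Accumulate one (candidate, reference list) pair, then compute the score.
bleu.add_inst(hypothesis, references)
print(bleu.score())
```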
@@ -17,16 +17,6 @@ import argparse
def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--train_data_prefix", type=str, help="file prefix for train data")
parser.add_argument(
"--eval_data_prefix", type=str, help="file prefix for eval data")
parser.add_argument(
"--test_data_prefix", type=str, help="file prefix for test data")
parser.add_argument(
"--vocab_prefix", type=str, help="file prefix for vocab")
parser.add_argument("--src_lang", type=str, help="source language suffix")
parser.add_argument("--trg_lang", type=str, help="target language suffix")
    parser.add_argument(
        "--optimizer",
@@ -45,13 +35,12 @@ def parse_args():
        type=int,
        default=1,
        help="layers number of encoder and decoder")
    parser.add_argument(
        "--hidden_size",
        type=int,
        default=100,
        help="hidden size of encoder and decoder")
parser.add_argument("--src_vocab_size", type=int, help="source vocab size")
parser.add_argument("--trg_vocab_size", type=int, help="target vocab size")
    parser.add_argument(
        "--batch_size", type=int, help="batch size of each step")
@@ -64,13 +53,16 @@ def parse_args():
        type=int,
        default=50,
        help="max length for source and target sentence")
    parser.add_argument(
        "--dropout", type=float, default=0.0, help="drop probability")
    parser.add_argument(
        "--init_scale",
        type=float,
        default=0.0,
        help="init scale for parameter")
    parser.add_argument(
        "--max_grad_norm",
        type=float,
@@ -90,15 +82,13 @@ def parse_args():
        help="model path for model to save")
    parser.add_argument(
"--reload_model", type=str, help="reload model to inference") "--infer_target_file", type=str, help="target file name for inference")
parser.add_argument(
"--infer_file", type=str, help="file name for inference")
    parser.add_argument(
        "--infer_output_file",
        type=str,
        default='infer_output',
        help="file name for inference output")
    parser.add_argument(
        "--beam_size", type=int, default=10, help="beam size for beam search")
@@ -108,16 +98,6 @@ def parse_args():
        default=False,
        help='Whether using gpu [True|False]')
parser.add_argument(
"--profile", action='store_true', help="Whether enable the profile.")
# NOTE: profiler args, used for benchmark
parser.add_argument(
"--profiler_path",
type=str,
default='./seq2seq.profile',
help="the profiler output file path. (used for benchmark)")
    parser.add_argument(
        "--init_from_ckpt",
        type=str,
......
@@ -12,402 +12,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import io
import itertools
import os
from functools import partial

import numpy as np
import paddle
def create_data_loader(args, device, for_train=True):
data_loaders = [None, None]
data_prefixes = [args.train_data_prefix, args.eval_data_prefix
] if args.eval_data_prefix else [args.train_data_prefix]
for i, data_prefix in enumerate(data_prefixes):
dataset = Seq2SeqDataset(
fpattern=data_prefix + "." + args.src_lang,
trg_fpattern=data_prefix + "." + args.trg_lang,
src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
trg_vocab_fpath=args.vocab_prefix + "." + args.trg_lang,
token_delimiter=None,
start_mark="<s>",
end_mark="</s>",
unk_mark="<unk>",
max_length=args.max_len if i == 0 else None,
truncate=True,
trg_add_bos_eos=True)
(args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id,
unk_id) = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=False,
batch_size=args.batch_size,
pool_size=args.batch_size * 20,
sort_type=SortType.POOL,
shuffle_batch=True,
min_length=1,
shuffle=True,
distribute_mode=True if i == 0 else False)
data_loader = paddle.io.DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
collate_fn=partial(
prepare_train_input,
bos_id=bos_id,
eos_id=eos_id,
pad_id=eos_id),
num_workers=0,
return_list=True)
data_loaders[i] = data_loader
return data_loaders, eos_id
def prepare_train_input(insts, bos_id, eos_id, pad_id):
src, src_length = pad_batch_data([inst[0] for inst in insts], pad_id=pad_id)
trg, trg_length = pad_batch_data([inst[1] for inst in insts], pad_id=pad_id)
return src, src_length, trg[:, :-1], trg[:, 1:, np.newaxis]
def prepare_infer_input(insts, bos_id, eos_id, pad_id):
src, src_length = pad_batch_data(insts, pad_id=pad_id)
return src, src_length
def pad_batch_data(insts, pad_id):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
inst_lens = np.array([len(inst) for inst in insts], dtype="int64")
max_len = np.max(inst_lens)
inst_data = np.array(
[inst + [pad_id] * (max_len - len(inst)) for inst in insts],
dtype="int64")
return inst_data, inst_lens
class SortType(object):
GLOBAL = 'global'
POOL = 'pool'
NONE = "none"
class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
self._vocab = vocab
self._beg = beg
self._end = end
self._unk = unk
self._delimiter = delimiter
self._add_beg = add_beg
self._add_end = add_end
def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter)
] + ([self._end] if self._add_end else [])
class ComposedConverter(object):
def __init__(self, converters):
self._converters = converters
def __call__(self, fields):
return [
converter(field)
for field, converter in zip(fields, self._converters)
]
import paddle
from paddlenlp.data import Vocab, Pad
from paddlenlp.data import SamplerHelper
from paddlenlp.datasets import IWSLT15

trans_func_tuple = IWSLT15.get_default_transform_func()


def create_train_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    train_ds, dev_ds = IWSLT15.get_datasets(
        mode=["train", "dev"],
        transform_func=[trans_func_tuple, trans_func_tuple])

    key = (lambda x, data_source: len(data_source[x][0]))
    cut_fn = lambda data: (data[0][:max_len], data[1][:max_len])

    train_ds = train_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    dev_ds = dev_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
    dev_batch_sampler = SamplerHelper(dev_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))

    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))

    return train_loader, dev_loader, len(src_vocab), len(tgt_vocab), pad_id


def create_infer_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    trans_func_tuple = IWSLT15.get_default_transform_func()
    test_ds = IWSLT15.get_datasets(
        mode=["test"], transform_func=[trans_func_tuple])
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)

    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=partial(
            prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return test_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id


def prepare_infer_input(insts, bos_id, eos_id, pad_id):
    insts = [([bos_id] + inst[0] + [eos_id], [bos_id] + inst[1] + [eos_id])
             for inst in insts]
    src, src_length = Pad(pad_val=pad_id, ret_length=True)(
        [inst[0] for inst in insts])
    return src, src_length


def prepare_train_input(insts, bos_id, eos_id, pad_id):
    # Add eos token id and bos token id.
    insts = [([bos_id] + inst[0] + [eos_id], [bos_id] + inst[1] + [eos_id])
             for inst in insts]
    # Pad sequence using eos id.
    src, src_length = Pad(pad_val=pad_id, ret_length=True)(
        [inst[0] for inst in insts])
    tgt, tgt_length = Pad(pad_val=pad_id, ret_length=True)(
        [inst[1] for inst in insts])
    tgt_mask = (tgt[:, :-1] != pad_id).astype("float32")
    return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis], tgt_mask
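The collate functions above rely on `paddlenlp.data.Pad` to batch variable-length id sequences. A small usage sketch of what the padding step produces, with toy ids and assuming the `pad_val`/`ret_length` behaviour used above:

```python
from paddlenlp.data import Pad

pad = Pad(pad_val=2, ret_length=True)  # here 2 plays the role of the eos/pad id
ids, lengths = pad([[11, 12, 13], [21, 22]])
print(ids)      # [[11 12 13] [21 22  2]]
print(lengths)  # [3 2]
```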
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Script for downloading training data.
'''
import os
import urllib
import sys

if sys.version_info >= (3, 0):
    import urllib.request

import zipfile

URLLIB = urllib
if sys.version_info >= (3, 0):
    URLLIB = urllib.request

remote_path = 'https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi'
base_path = 'data'
trg_path = os.path.join(base_path, 'en-vi')
filenames = [
    'train.en', 'train.vi', 'tst2012.en', 'tst2012.vi', 'tst2013.en',
    'tst2013.vi', 'vocab.en', 'vocab.vi'
]


def main(arguments):
    print("Downloading data......")
    if not os.path.exists(trg_path):
        if not os.path.exists(base_path):
            os.mkdir(base_path)
        os.mkdir(trg_path)
    for filename in filenames:
        url = os.path.join(remote_path, filename)
        trg_file = os.path.join(trg_path, filename)
        URLLIB.urlretrieve(url, trg_file)
    print("Downloaded success......")


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
@@ -13,18 +13,17 @@
# limitations under the License.

import io

import numpy as np
import paddle

from args import parse_args
from seq2seq_attn import Seq2SeqAttnInferModel
from data import create_infer_loader
from paddlenlp.datasets import IWSLT15


def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Post-process the decoded sequence.
    """
@@ -43,35 +42,15 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
def do_predict(args):
    device = paddle.set_device("gpu" if args.use_gpu else "cpu")

    test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader(
        args)
    _, vocab = IWSLT15.get_vocab()
    trg_idx2word = vocab.idx_to_token

    model = paddle.Model(
        Seq2SeqAttnInferModel(
            src_vocab_size,
            tgt_vocab_size,
            args.hidden_size,
            args.hidden_size,
            args.num_layers,
@@ -84,14 +63,14 @@ def do_predict(args):
    model.prepare()

    # Load the trained model
    assert args.init_from_ckpt, (
        "Please set init_from_ckpt to load the infer model.")
    model.load(args.init_from_ckpt)

    with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
        for data in test_loader():
            with paddle.no_grad():
                finished_seq = model.predict_batch(inputs=data)[0]
            finished_seq = finished_seq[:, :, np.newaxis] if len(
                finished_seq.shape) == 2 else finished_seq
            finished_seq = np.transpose(finished_seq, [0, 2, 1])
......
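After `predict_batch`, each beam candidate is a sequence of token ids; `post_process_seq` strips everything after the first eos (and optionally the leading bos) before the ids are mapped back to words through `trg_idx2word`. A toy sketch of that last step; the ids and the vocabulary here are illustrative only:

```python
# Toy example of turning one decoded candidate back into text.
trg_idx2word = {0: "<s>", 1: "</s>", 2: "<unk>", 3: "hello", 4: "world"}
bos_id, eos_id = 0, 1

decoded = [0, 3, 4, 1, 2, 2]           # beam output, padded after </s>
cut = decoded[:decoded.index(eos_id)]  # keep tokens before the first eos
words = [trg_idx2word[idx] for idx in cut if idx != bos_id]
print(" ".join(words))                 # "hello world"
```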
@@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
@@ -23,7 +24,7 @@ class CrossEntropyCriterion(nn.Layer):
    def __init__(self):
        super(CrossEntropyCriterion, self).__init__()

    def forward(self, predict, label, trg_mask):
        cost = F.softmax_with_cross_entropy(
            logits=predict, label=label, soft_label=False)
        cost = paddle.squeeze(cost, axis=[2])
@@ -200,7 +201,6 @@ class Seq2SeqAttnModel(nn.Layer):
            (encoder_final_state[0][i], encoder_final_state[1][i])
            for i in range(self.num_layers)
        ]
        # Construct decoder initial states: use input_feed and the shape is
        # [[h,c] * num_layers, input_feed], consistent with Seq2SeqDecoderCell.states
        decoder_initial_states = [
@@ -215,8 +215,7 @@ class Seq2SeqAttnModel(nn.Layer):
        predict = self.decoder(trg, decoder_initial_states, encoder_output,
                               encoder_padding_mask)

        return predict


class Seq2SeqAttnInferModel(Seq2SeqAttnModel):
......
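With `trg_mask` now produced by the data pipeline (see `prepare_train_input` above) rather than computed inside the model, the criterion can zero out padded target positions before reducing the loss. A NumPy sketch of one common way to apply such a mask, with toy numbers and not necessarily the exact reduction used by this criterion:

```python
import numpy as np

# Per-token cross-entropy losses for a batch of 2 sequences of length 4.
token_loss = np.array([[0.5, 0.4, 0.3, 0.2],
                       [0.6, 0.1, 0.0, 0.0]])
# 1.0 for real target tokens, 0.0 for padding.
trg_mask = np.array([[1.0, 1.0, 1.0, 1.0],
                     [1.0, 1.0, 0.0, 0.0]])

masked = token_loss * trg_mask
# Average over real tokens only; padded positions do not dilute the loss.
loss = masked.sum() / trg_mask.sum()
print(loss)  # 0.35
```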
@@ -12,88 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from args import parse_args
import numpy as np
import paddle
import paddle.nn as nn
from paddlenlp.metrics import Perplexity

from seq2seq_attn import Seq2SeqAttnModel, CrossEntropyCriterion
from data import create_train_loader
class TrainCallback(paddle.callbacks.ProgBarLogger):
def __init__(self, ppl, log_freq, verbose=2):
super(TrainCallback, self).__init__(log_freq, verbose)
self.ppl = ppl
def on_train_begin(self, logs=None):
super(TrainCallback, self).on_train_begin(logs)
self.train_metrics = ["loss", "ppl"]
def on_epoch_begin(self, epoch=None, logs=None):
super(TrainCallback, self).on_epoch_begin(epoch, logs)
self.ppl.reset()
def on_train_batch_end(self, step, logs=None):
logs["ppl"] = self.ppl.cal_acc_ppl(logs["loss"][0], logs["batch_size"])
if step > 0 and step % self.ppl.reset_freq == 0:
self.ppl.reset()
super(TrainCallback, self).on_train_batch_end(step, logs)
def on_eval_begin(self, logs=None):
super(TrainCallback, self).on_eval_begin(logs)
self.eval_metrics = ["ppl"]
self.ppl.reset()
def on_eval_batch_end(self, step, logs=None):
logs["ppl"] = self.ppl.cal_acc_ppl(logs["loss"][0], logs["batch_size"])
super(TrainCallback, self).on_eval_batch_end(step, logs)
class Perplexity(Metric):
def __init__(self, reset_freq=100, name=None):
super(Perplexity, self).__init__()
self._name = name or "Perplexity"
self.reset_freq = reset_freq
self.reset()
def compute(self, pred, seq_length, label):
word_num = paddle.sum(seq_length)
return word_num
def update(self, word_num):
self.word_count += word_num
return word_num
def reset(self):
self.total_loss = 0
self.word_count = 0
def accumulate(self):
return self.word_count
def name(self):
return self._name
def cal_acc_ppl(self, batch_loss, batch_size):
self.total_loss += batch_loss * batch_size
ppl = math.exp(self.total_loss / self.word_count)
return ppl
def do_train(args):
    device = paddle.set_device("gpu" if args.use_gpu else "cpu")

    # Define dataloader
    train_loader, eval_loader, src_vocab_size, tgt_vocab_size, eos_id = create_train_loader(
        args)

    model = paddle.Model(
        Seq2SeqAttnModel(src_vocab_size, tgt_vocab_size, args.hidden_size, args.
                         hidden_size, args.num_layers, args.dropout, eos_id))

    grad_clip = nn.ClipGradByGlobalNorm(args.max_grad_norm)
    optimizer = paddle.optimizer.Adam(
@@ -101,7 +39,7 @@ def do_train(args):
        parameters=model.parameters(),
        grad_clip=grad_clip)

    ppl_metric = Perplexity()
    model.prepare(optimizer, CrossEntropyCriterion(), ppl_metric)
    print(args)
@@ -115,8 +53,7 @@ def do_train(args):
        eval_freq=1,
        save_freq=1,
        save_dir=args.model_path,
        log_freq=args.log_freq)


if __name__ == "__main__":
......
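`paddlenlp.metrics.Perplexity` reports the exponential of the average per-token cross entropy, which is why it pairs naturally with `CrossEntropyCriterion` in `model.prepare`. A tiny sketch of the relationship, with toy numbers:

```python
import math

# Average cross entropy per target token (in nats).
avg_token_loss = 2.0
perplexity = math.exp(avg_token_loss)
print(perplexity)  # ~7.39: as uncertain as a uniform choice over ~7 words
```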
@@ -30,7 +30,8 @@ Sequence to Sequence (Seq2Seq), using the encoder-decoder (Encoder-Decoder)
This tutorial uses the [couplet dataset](https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz) as the training corpus: train_src.tsv and train_tgt.tsv form the training set, dev_src.tsv and dev_tgt.tsv the development set, and test_src.tsv and test_tgt.tsv the test set.

The dataset is downloaded automatically when `CoupletDataset` is initialized. If no path is provided at initialization, on Linux it is downloaded to `/root/.paddlenlp/datasets/machine_translation/CoupletDataset/` by default.

## Model Training
......
@@ -78,9 +78,6 @@ def parse_args():
        default=None,
        help="The path of checkpoint to be loaded.")
parser.add_argument(
"--infer_file", type=str, help="file name for inference")
    parser.add_argument(
        "--infer_output_file",
        type=str,
......
@@ -23,9 +23,9 @@
This tutorial uses two text datasets:

The PTB dataset consists of Wall Street Journal articles and contains 929k training tokens with a vocabulary of 10k. Download: https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz

The Yahoo dataset comes from [(Yang et al., 2017) Improved Variational Autoencoders for Text Modeling using Dilated Convolutions](https://arxiv.org/pdf/1702.08139.pdf). It samples 100k documents from the raw Yahoo Answer data; the average document length is 78 and the vocabulary size is 200k. Download: https://paddlenlp.bj.bcebos.com/datasets/yahoo-answer-100k.tar.gz
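Both archives can also be fetched programmatically with the same helper the updated `download.py` uses, `paddle.utils.download.get_path_from_url`. A minimal sketch; the target directory is illustrative:

```python
from paddle.utils.download import get_path_from_url

PTB_URL = "https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz"
# Downloads and extracts the archive under ./data, returning the local path.
data_path = get_path_from_url(PTB_URL, "data")
print(data_path)
```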
### Obtaining the Data

@@ -91,7 +91,7 @@ python -m paddle.distributed.launch train.py \
## Model Prediction

After training completes, you can load the model saved at the 50th epoch from the model save directory for prediction and generate batch_size short texts. The generated text is written to the path given by `infer_output_file`. For the ptb dataset, this can be configured with the following command:
```
export CUDA_VISIBLE_DEVICES=0
@@ -101,6 +101,7 @@ python predict.py \
    --max_grad_norm 5.0 \
    --dataset ptb \
    --use_gpu True \
    --infer_output_file infer_output.txt \
    --init_from_ckpt ptb_model/49 \
```
......
@@ -27,12 +27,6 @@ def parse_args():
        default=0.001,
        help="Learning rate of optimizer.")
parser.add_argument(
"--decay_factor",
type=float,
default=0.5,
help="Decay factor of learning rate.")
    parser.add_argument(
        "--num_layers",
        type=int,
@@ -83,12 +77,6 @@ def parse_args():
        default=0.,
        help="Drop probability of encoder.")
parser.add_argument(
"--word_keep_prob",
type=float,
default=0.5,
help="Word keep probability.")
    parser.add_argument(
        "--init_scale",
        type=float,
@@ -122,15 +110,6 @@ def parse_args():
        default=False,
        help='Whether to use gpu [True|False].')
parser.add_argument(
"--enable_ce",
action='store_true',
help="The flag indicating whether to run the task "
"for continuous evaluation.")
parser.add_argument(
"--profile", action='store_true', help="Whether enable the profile.")
    parser.add_argument(
        "--warm_up",
        type=int,
@@ -143,26 +122,6 @@ def parse_args():
        default=0.1,
        help='KL start value, up to 1.0.')
parser.add_argument(
"--attr_init",
type=str,
default='normal_initializer',
help="Initializer for paramters.")
parser.add_argument(
"--cache_num", type=int, default=1, help='Cache num for reader.')
parser.add_argument(
"--max_decay",
type=int,
default=5,
help='Max decay tries (if exceeds, early stop).')
parser.add_argument(
"--sort_cache",
action='store_true',
help='Sort cache before batch to accelerate training.')
    parser.add_argument(
        "--init_from_ckpt",
        type=str,
......
@@ -117,7 +117,11 @@ class VAEDataset(paddle.io.Dataset):
        return corpus_ids

    def read_raw_data(self, dataset, max_vocab_cnt=-1):
        if dataset == 'yahoo':
            dataset_name = 'yahoo-answer-100k'
        else:
            dataset_name = os.path.join('simple-examples', 'data')
        data_path = os.path.join("data", dataset_name)
        train_file = os.path.join(data_path, dataset + ".train.txt")
        valid_file = os.path.join(data_path, dataset + ".valid.txt")
        test_file = os.path.join(data_path, dataset + ".test.txt")
@@ -282,7 +286,13 @@ def create_data_loader(data_path,
def get_vocab(dataset, batch_size, vocab_file=None, max_sequence_len=50):
    train_dataset = VAEDataset(dataset, batch_size, mode='train')
    if dataset == 'yahoo':
        dataset_name = 'yahoo-answer-100k'
    else:
        dataset_name = os.path.join('simple-examples', 'data')
    dataset_prefix = os.path.join("data", dataset_name)
    train_file = os.path.join(dataset_prefix, dataset + ".train.txt")
    vocab_file = None
    if "yahoo" in dataset:
......
@@ -12,110 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Script for downloading training data.
'''
import os
import sys
import shutil
import argparse
import urllib
import tarfile
import urllib.request
import zipfile

from paddle.utils.download import get_path_from_url

TASKS = ['ptb', 'yahoo']
URL = {
    'ptb': 'https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz',
    'yahoo':
    'https://paddlenlp.bj.bcebos.com/datasets/yahoo-answer-100k.tar.gz',
}
def un_tar(tar_name, dir_name):
try:
t = tarfile.open(tar_name)
t.extractall(path=dir_name)
return True
except Exception as e:
print(e)
return False
def un_zip(filepath, dir_name):
z = zipfile.ZipFile(filepath, 'r')
for file in z.namelist():
z.extract(file, dir_name)
def download_ptb_and_extract(task, data_path):
print('Downloading and extracting %s...' % task)
data_file = os.path.join(data_path, TASK2PATH[task].split('/')[-1])
URLLIB.urlretrieve(TASK2PATH[task], data_file)
un_tar(data_file, data_path)
os.remove(data_file)
src_dir = os.path.join(data_path, 'simple-examples')
dst_dir = os.path.join(data_path, 'ptb')
if not os.path.exists(dst_dir):
os.mkdir(dst_dir)
shutil.copy(os.path.join(src_dir, 'data/ptb.train.txt'), dst_dir)
shutil.copy(os.path.join(src_dir, 'data/ptb.valid.txt'), dst_dir)
shutil.copy(os.path.join(src_dir, 'data/ptb.test.txt'), dst_dir)
print('\tCompleted!')
def download_yahoo_dataset(task, data_dir):
url = TASK2PATH[task]
# id is between `/d/` and '/'
url_suffix = url[url.find('/d/') + 3:]
if url_suffix.find('/') == -1:
# if there's no trailing '/'
file_id = url_suffix
else:
file_id = url_suffix[:url_suffix.find('/')]
try:
import requests
except ImportError:
print("The requests library must be installed to download files from "
"Google drive. Please see: https://github.com/psf/requests")
raise
def _get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
gurl = "https://docs.google.com/uc?export=download"
sess = requests.Session()
response = sess.get(gurl, params={'id': file_id}, stream=True)
token = None
for key, value in response.cookies.items():
if key.startswith('download_warning'):
token = value
if token:
params = {'id': file_id, 'confirm': token}
response = sess.get(gurl, params=params, stream=True)
filename = 'yahoo.zip'
filepath = os.path.join(data_dir, filename)
CHUNK_SIZE = 32768
with open(filepath, "wb") as f:
for chunk in response.iter_content(CHUNK_SIZE):
if chunk:
f.write(chunk)
un_zip(filepath, data_dir)
os.remove(filepath)
print('Successfully downloaded yahoo')
def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument(
@@ -131,15 +42,7 @@ def main(arguments):
        type=str,
        default='ptb')
    args = parser.parse_args(arguments)
    get_path_from_url(URL[args.task], args.data_dir)
if not os.path.isdir(args.data_dir):
os.mkdir(args.data_dir)
if args.task == 'yahoo':
if args.data_dir == 'data':
args.data_dir = './'
download_yahoo_dataset(args.task, args.data_dir)
else:
download_ptb_and_extract(args.task, args.data_dir)
if __name__ == '__main__':
......