diff --git a/PaddleNLP/examples/machine_translation/seq2seq/README.md b/PaddleNLP/examples/machine_translation/seq2seq/README.md
index d29ee7d0670975a17b69039e15966619b18cae0c..72c1acb28a14da2a9de8e96543ebb2691850a6ff 100644
--- a/PaddleNLP/examples/machine_translation/seq2seq/README.md
+++ b/PaddleNLP/examples/machine_translation/seq2seq/README.md
@@ -19,88 +19,63 @@ Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder)
本目录包含Seq2Seq的一个经典样例:机器翻译,带attention机制的翻译模型。Seq2Seq翻译模型,模拟了人类在进行翻译类任务时的行为:先解析源语言,理解其含义,再根据该含义来写出目标语言的语句。更多关于机器翻译的具体原理和数学表达式,我们推荐参考飞桨官网[机器翻译案例](https://www.paddlepaddle.org.cn/documentation/docs/zh/user_guides/nlp_case/machine_translation/README.cn.html)。
-运行本目录下的范例模型需要安装PaddlePaddle 2.0版。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](https://www.paddlepaddle.org.cn/#quick-start)中的说明更新 PaddlePaddle 安装版本。
+运行本目录下的范例模型需要安装PaddlePaddle 2.0-rc版。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](https://www.paddlepaddle.org.cn/#quick-start)中的说明更新 PaddlePaddle 安装版本。
## 模型概览
本模型中,在编码器方面,我们采用了基于LSTM的多层的RNN encoder;在解码器方面,我们使用了带注意力(Attention)机制的RNN decoder,在预测时我们使用柱搜索(beam search)算法来生成翻译的目标语句。
-## 代码下载
-
-克隆代码库到本地,并设置`PYTHONPATH`环境变量
-
-```shell
-git clone http://gitlab.baidu.com/PaddleSL/PaddleNLP
-
-cd PaddleNLP
-export PYTHONPATH=$PYTHONPATH:`pwd`
-cd examples/machine_translation/seq2seq
-```
-
## 数据介绍
-本教程使用[IWSLT'15 English-Vietnamese data ](https://nlp.stanford.edu/projects/nmt/)数据集中的英语到越南语的数据作为训练语料,tst2012的数据作为开发集,tst2013的数据作为测试集
+本教程使用[IWSLT'15 English-Vietnamese data ](https://nlp.stanford.edu/projects/nmt/)数据集中的英语到越南语的数据作为训练语料,tst2012的数据作为开发集,tst2013的数据作为测试集。
### 数据获取
-
-```
-python download.py
-```
+如果用户在初始化数据集时没有提供路径,数据集会自动下载到`paddlenlp.utils.env.DATA_HOME`的`/machine_translation/IWSLT15/`路径下,例如在linux系统下,默认存储路径是`/root/.paddlenlp/datasets/machine_translation/IWSLT15`。
## 模型训练
执行以下命令即可训练带有注意力机制的Seq2Seq机器翻译模型:
```sh
-export CUDA_VISIBLE_DEVICES=0
-
python train.py \
- --src_lang en --trg_lang vi \
--num_layers 2 \
--hidden_size 512 \
--batch_size 128 \
--dropout 0.2 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
- --train_data_prefix data/en-vi/train \
- --eval_data_prefix data/en-vi/tst2012 \
- --test_data_prefix data/en-vi/tst2013 \
- --vocab_prefix data/en-vi/vocab \
--use_gpu True \
--model_path ./attention_models
+
```
各参数的具体说明请参阅 `args.py` 。训练程序会在每个epoch训练结束之后,save一次模型。
+**NOTE:** 如需恢复模型训练,则`init_from_ckpt`只需指定到文件名即可,不需要添加文件后缀。如`--init_from_ckpt=attention_models/5`即可,程序会自动加载模型参数`attention_models/5.pdparams`,也会自动加载优化器状态`attention_models/5.pdopt`。
## 模型预测
-训练完成之后,可以使用保存的模型(由 `--reload_model` 指定)对test的数据集(由 `--infer_file` 指定)进行beam search解码,命令如下:
+训练完成之后,可以使用保存的模型(由 `--init_from_ckpt` 指定)对测试集数据进行beam search解码,其中译文数据由 `--infer_target_file` 指定,在linux系统下,默认存储路径为`/root/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi`,如果您使用的是Windows系统,需要更改下面的路径。预测命令如下:
```sh
-export CUDA_VISIBLE_DEVICES=0
-
python predict.py \
- --src_lang en --trg_lang vi \
- --num_layers 2 \
- --hidden_size 512 \
- --batch_size 128 \
- --dropout 0.2 \
- --init_scale 0.1 \
- --max_grad_norm 5.0 \
- --vocab_prefix data/en-vi/vocab \
- --infer_file data/en-vi/tst2013.en \
- --reload_model attention_models/9 \
- --infer_output_file infer_output.txt \
- --beam_size 10 \
- --use_gpu True
+ --num_layers 2 \
+ --hidden_size 512 \
+ --batch_size 128 \
+ --dropout 0.2 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --init_from_ckpt attention_models/9 \
+ --infer_target_file /root/.paddlenlp/datasets/machine_translation/IWSLT15/iwslt15.en-vi/tst2013.vi \
+ --infer_output_file infer_output.txt \
+ --beam_size 10 \
+ --use_gpu True
```
各参数的具体说明请参阅 `args.py` ,注意预测时所用模型超参数需和训练时一致。
## 效果评价
-
使用 [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) 工具来评价模型预测的翻译质量,使用方法如下:
```sh
@@ -110,6 +85,6 @@ perl mosesdecoder/scripts/generic/multi-bleu.perl data/en-vi/tst2013.vi < infer_
取第10个epoch保存的模型进行预测,取beam_size=10。效果如下:
```
-tst2013 BLEU:
-25.36
+tst2013 BLEU: 24.40
+
```
diff --git a/PaddleNLP/examples/machine_translation/seq2seq/args.py b/PaddleNLP/examples/machine_translation/seq2seq/args.py
index c69532d6c6b698397437840b458e2a56b024c252..d3c9092baa658372ab2f11636e25ddd385ad5e3f 100644
--- a/PaddleNLP/examples/machine_translation/seq2seq/args.py
+++ b/PaddleNLP/examples/machine_translation/seq2seq/args.py
@@ -17,16 +17,6 @@ import argparse
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
- parser.add_argument(
- "--train_data_prefix", type=str, help="file prefix for train data")
- parser.add_argument(
- "--eval_data_prefix", type=str, help="file prefix for eval data")
- parser.add_argument(
- "--test_data_prefix", type=str, help="file prefix for test data")
- parser.add_argument(
- "--vocab_prefix", type=str, help="file prefix for vocab")
- parser.add_argument("--src_lang", type=str, help="source language suffix")
- parser.add_argument("--trg_lang", type=str, help="target language suffix")
parser.add_argument(
"--optimizer",
@@ -45,13 +35,12 @@ def parse_args():
type=int,
default=1,
help="layers number of encoder and decoder")
+
parser.add_argument(
"--hidden_size",
type=int,
default=100,
help="hidden size of encoder and decoder")
- parser.add_argument("--src_vocab_size", type=int, help="source vocab size")
- parser.add_argument("--trg_vocab_size", type=int, help="target vocab size")
parser.add_argument(
"--batch_size", type=int, help="batch size of each step")
@@ -64,13 +53,16 @@ def parse_args():
type=int,
default=50,
help="max length for source and target sentence")
+
parser.add_argument(
"--dropout", type=float, default=0.0, help="drop probability")
+
parser.add_argument(
"--init_scale",
type=float,
default=0.0,
help="init scale for parameter")
+
parser.add_argument(
"--max_grad_norm",
type=float,
@@ -90,15 +82,13 @@ def parse_args():
help="model path for model to save")
parser.add_argument(
- "--reload_model", type=str, help="reload model to inference")
-
- parser.add_argument(
- "--infer_file", type=str, help="file name for inference")
+ "--infer_target_file", type=str, help="target file name for inference")
parser.add_argument(
"--infer_output_file",
type=str,
default='infer_output',
help="file name for inference output")
+
parser.add_argument(
"--beam_size", type=int, default=10, help="file name for inference")
@@ -108,16 +98,6 @@ def parse_args():
default=False,
help='Whether using gpu [True|False]')
- parser.add_argument(
- "--profile", action='store_true', help="Whether enable the profile.")
- # NOTE: profiler args, used for benchmark
-
- parser.add_argument(
- "--profiler_path",
- type=str,
- default='./seq2seq.profile',
- help="the profiler output file path. (used for benchmark)")
-
parser.add_argument(
"--init_from_ckpt",
type=str,
diff --git a/PaddleNLP/examples/machine_translation/seq2seq/data.py b/PaddleNLP/examples/machine_translation/seq2seq/data.py
index 1ad0356693f6d0e1c3410b64b634bd7ab58b4d2a..4dad9b417dade96cab791d78ed3d4402b10d8922 100644
--- a/PaddleNLP/examples/machine_translation/seq2seq/data.py
+++ b/PaddleNLP/examples/machine_translation/seq2seq/data.py
@@ -12,402 +12,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import glob
import io
-import itertools
-from functools import partial
+import os
+from functools import partial
import numpy as np
-import paddle
-
-
-def create_data_loader(args, device, for_train=True):
- data_loaders = [None, None]
- data_prefixes = [args.train_data_prefix, args.eval_data_prefix
- ] if args.eval_data_prefix else [args.train_data_prefix]
- for i, data_prefix in enumerate(data_prefixes):
- dataset = Seq2SeqDataset(
- fpattern=data_prefix + "." + args.src_lang,
- trg_fpattern=data_prefix + "." + args.trg_lang,
- src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
- trg_vocab_fpath=args.vocab_prefix + "." + args.trg_lang,
- token_delimiter=None,
- start_mark="",
- end_mark="",
- unk_mark="",
- max_length=args.max_len if i == 0 else None,
- truncate=True,
- trg_add_bos_eos=True)
- (args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id,
- unk_id) = dataset.get_vocab_summary()
- batch_sampler = Seq2SeqBatchSampler(
- dataset=dataset,
- use_token_batch=False,
- batch_size=args.batch_size,
- pool_size=args.batch_size * 20,
- sort_type=SortType.POOL,
- shuffle_batch=True,
- min_length=1,
- shuffle=True,
- distribute_mode=True if i == 0 else False)
- data_loader = paddle.io.DataLoader(
- dataset=dataset,
- batch_sampler=batch_sampler,
- places=device,
- collate_fn=partial(
- prepare_train_input,
- bos_id=bos_id,
- eos_id=eos_id,
- pad_id=eos_id),
- num_workers=0,
- return_list=True)
- data_loaders[i] = data_loader
- return data_loaders, eos_id
-
-
-def prepare_train_input(insts, bos_id, eos_id, pad_id):
- src, src_length = pad_batch_data([inst[0] for inst in insts], pad_id=pad_id)
- trg, trg_length = pad_batch_data([inst[1] for inst in insts], pad_id=pad_id)
- return src, src_length, trg[:, :-1], trg[:, 1:, np.newaxis]
-
-
-def prepare_infer_input(insts, bos_id, eos_id, pad_id):
- src, src_length = pad_batch_data(insts, pad_id=pad_id)
- return src, src_length
-
-
-def pad_batch_data(insts, pad_id):
- """
- Pad the instances to the max sequence length in batch, and generate the
- corresponding position data and attention bias.
- """
- inst_lens = np.array([len(inst) for inst in insts], dtype="int64")
- max_len = np.max(inst_lens)
- inst_data = np.array(
- [inst + [pad_id] * (max_len - len(inst)) for inst in insts],
- dtype="int64")
- return inst_data, inst_lens
-
-
-class SortType(object):
- GLOBAL = 'global'
- POOL = 'pool'
- NONE = "none"
-
-
-class Converter(object):
- def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
- self._vocab = vocab
- self._beg = beg
- self._end = end
- self._unk = unk
- self._delimiter = delimiter
- self._add_beg = add_beg
- self._add_end = add_end
-
- def __call__(self, sentence):
- return ([self._beg] if self._add_beg else []) + [
- self._vocab.get(w, self._unk)
- for w in sentence.split(self._delimiter)
- ] + ([self._end] if self._add_end else [])
-
-
-class ComposedConverter(object):
- def __init__(self, converters):
- self._converters = converters
-
- def __call__(self, fields):
- return [
- converter(field)
- for field, converter in zip(fields, self._converters)
- ]
+import paddle
+from paddlenlp.data import Vocab, Pad
+from paddlenlp.data import SamplerHelper
-class SentenceBatchCreator(object):
- def __init__(self, batch_size):
- self.batch = []
- self._batch_size = batch_size
-
- def append(self, info):
- self.batch.append(info)
- if len(self.batch) == self._batch_size:
- tmp = self.batch
- self.batch = []
- return tmp
-
-
-class TokenBatchCreator(object):
- def __init__(self, batch_size):
- self.batch = []
- self.max_len = -1
- self._batch_size = batch_size
-
- def append(self, info):
- cur_len = info.max_len
- max_len = max(self.max_len, cur_len)
- if max_len * (len(self.batch) + 1) > self._batch_size:
- result = self.batch
- self.batch = [info]
- self.max_len = cur_len
- return result
- else:
- self.max_len = max_len
- self.batch.append(info)
-
-
-class SampleInfo(object):
- def __init__(self, i, lens):
- self.i = i
- self.lens = lens
- self.max_len = lens[0] # to be consitent with the original reader
- self.min_len = lens[0]
-
- def get_ranges(self, min_length=None, max_length=None, truncate=False):
- ranges = []
- # source
- if (min_length is None or self.lens[0] >= min_length) and (
- max_length is None or self.lens[0] <= max_length or truncate):
- end = max_length if truncate and max_length else self.lens[0]
- ranges.append([0, end])
- # target
- if len(self.lens) == 2:
- if (min_length is None or self.lens[1] >= min_length) and (
- max_length is None or self.lens[1] <= max_length + 2 or
- truncate):
- end = max_length + 2 if truncate and max_length else self.lens[
- 1]
- ranges.append([0, end])
- return ranges if len(ranges) == len(self.lens) else None
-
-
-class MinMaxFilter(object):
- def __init__(self, max_len, min_len, underlying_creator):
- self._min_len = min_len
- self._max_len = max_len
- self._creator = underlying_creator
-
- def append(self, info):
- if (self._min_len is None or info.min_len >= self._min_len) and (
- self._max_len is None or info.max_len <= self._max_len):
- return self._creator.append(info)
-
- @property
- def batch(self):
- return self._creator.batch
-
-
-class Seq2SeqDataset(paddle.io.Dataset):
- def __init__(self,
- src_vocab_fpath,
- trg_vocab_fpath,
- fpattern,
- field_delimiter="\t",
- token_delimiter=" ",
- start_mark="",
- end_mark="",
- unk_mark="",
- trg_fpattern=None,
- trg_add_bos_eos=False,
- min_length=None,
- max_length=None,
- truncate=False):
- self._src_vocab = self.load_dict(src_vocab_fpath)
- self._trg_vocab = self.load_dict(trg_vocab_fpath)
- self._bos_idx = self._src_vocab[start_mark]
- self._eos_idx = self._src_vocab[end_mark]
- self._unk_idx = self._src_vocab[unk_mark]
- self._field_delimiter = field_delimiter
- self._token_delimiter = token_delimiter
- self._min_length = min_length
- self._max_length = max_length
- self._truncate = truncate
- self._trg_add_bos_eos = trg_add_bos_eos
- self.load_src_trg_ids(fpattern, trg_fpattern)
-
- def load_src_trg_ids(self, fpattern, trg_fpattern=None):
- src_converter = Converter(
- vocab=self._src_vocab,
- beg=self._bos_idx,
- end=self._eos_idx,
- unk=self._unk_idx,
- delimiter=self._token_delimiter,
- add_beg=False,
- add_end=False)
-
- trg_converter = Converter(
- vocab=self._trg_vocab,
- beg=self._bos_idx,
- end=self._eos_idx,
- unk=self._unk_idx,
- delimiter=self._token_delimiter,
- add_beg=True if self._trg_add_bos_eos else False,
- add_end=True if self._trg_add_bos_eos else False)
-
- converters = ComposedConverter([src_converter, trg_converter])
-
- self._src_seq_ids = []
- self._trg_seq_ids = []
- self._sample_infos = []
-
- slots = [self._src_seq_ids, self._trg_seq_ids]
- for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
- fields = converters(line)
- lens = [len(field) for field in fields]
- sample = SampleInfo(i, lens)
- field_ranges = sample.get_ranges(self._min_length, self._max_length,
- self._truncate)
- if field_ranges:
- for field, field_range, slot in zip(fields, field_ranges,
- slots):
- slot.append(field[field_range[0]:field_range[1]])
- self._sample_infos.append(sample)
+from paddlenlp.datasets import IWSLT15
- def _load_lines(self, fpattern, trg_fpattern=None):
- fpaths = glob.glob(fpattern)
- fpaths = sorted(fpaths) # TODO: Add custum sort
- assert len(fpaths) > 0, "no matching file to the provided data path"
+trans_func_tuple = IWSLT15.get_default_transform_func()
- (f_mode, f_encoding, endl) = ("r", "utf8", "\n")
- if trg_fpattern is None:
- for fpath in fpaths:
- with io.open(fpath, f_mode, encoding=f_encoding) as f:
- for line in f:
- fields = line.strip(endl).split(self._field_delimiter)
- yield fields
- else:
- trg_fpaths = glob.glob(trg_fpattern)
- trg_fpaths = sorted(trg_fpaths)
- assert len(fpaths) == len(
- trg_fpaths
- ), "the number of source language data files must equal \
- with that of source language"
- for fpath, trg_fpath in zip(fpaths, trg_fpaths):
- with io.open(fpath, f_mode, encoding=f_encoding) as f:
- with io.open(
- trg_fpath, f_mode, encoding=f_encoding) as trg_f:
- for line in zip(f, trg_f):
- fields = [field.strip(endl) for field in line]
- yield fields
+def create_train_loader(args):
+ batch_size = args.batch_size
+ max_len = args.max_len
+ src_vocab, tgt_vocab = IWSLT15.get_vocab()
+ bos_id = src_vocab[src_vocab.bos_token]
+ eos_id = src_vocab[src_vocab.eos_token]
+ pad_id = eos_id
- @staticmethod
- def load_dict(dict_path, reverse=False):
- word_dict = {}
- (f_mode, f_encoding, endl) = ("r", "utf8", "\n")
- with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
- for idx, line in enumerate(fdict):
- if reverse:
- word_dict[idx] = line.strip(endl)
- else:
- word_dict[line.strip(endl)] = idx
- return word_dict
+ train_ds, dev_ds = IWSLT15.get_datasets(
+ mode=["train", "dev"],
+ transform_func=[trans_func_tuple, trans_func_tuple])
- def get_vocab_summary(self):
- return len(self._src_vocab), len(
- self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
+ key = (lambda x, data_source: len(data_source[x][0]))
+ cut_fn = lambda data: (data[0][:max_len], data[1][:max_len])
- def __getitem__(self, idx):
- return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
- ) if self._trg_seq_ids else self._src_seq_ids[idx]
+ train_ds = train_ds.filter(
+ lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
+ dev_ds = dev_ds.filter(
+ lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
+ train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
+ key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
- def __len__(self):
- return len(self._sample_infos)
+ dev_batch_sampler = SamplerHelper(dev_ds).sort(
+ key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
+ train_loader = paddle.io.DataLoader(
+ train_ds,
+ batch_sampler=train_batch_sampler,
+ collate_fn=partial(
+ prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
-class Seq2SeqBatchSampler(paddle.io.BatchSampler):
- def __init__(self,
- dataset,
- batch_size,
- pool_size=10000,
- sort_type=SortType.NONE,
- min_length=None,
- max_length=None,
- shuffle=False,
- shuffle_batch=False,
- use_token_batch=False,
- clip_last_batch=False,
- distribute_mode=True,
- seed=0):
- for arg, value in locals().items():
- if arg != "self":
- setattr(self, "_" + arg, value)
- self._random = np.random
- self._random.seed(seed)
- # For multi-devices
- self._distribute_mode = distribute_mode
- self._nranks = paddle.distributed.get_world_size()
- self._local_rank = paddle.distributed.get_rank()
+ dev_loader = paddle.io.DataLoader(
+ dev_ds,
+ batch_sampler=dev_batch_sampler,
+ collate_fn=partial(
+ prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
- def __iter__(self):
- # Global sort or global shuffle
- if self._sort_type == SortType.GLOBAL:
- infos = sorted(
- self._dataset._sample_infos,
- key=lambda x: x.max_len,
- reverse=True)
- else:
- if self._shuffle:
- infos = self._dataset._sample_infos
- self._random.shuffle(infos)
- else:
- infos = self._dataset._sample_infos
+ return train_loader, dev_loader, len(src_vocab), len(tgt_vocab), pad_id
- if self._sort_type == SortType.POOL:
- reverse = True
- for i in range(0, len(infos), self._pool_size):
- # To avoid placing short next to long sentences
- infos[i:i + self._pool_size] = sorted(
- infos[i:i + self._pool_size],
- key=lambda x: x.max_len,
- reverse=reverse)
- batches = []
- batch_creator = TokenBatchCreator(
- self.
- _batch_size) if self._use_token_batch else SentenceBatchCreator(
- self._batch_size * self._nranks)
- batch_creator = MinMaxFilter(self._max_length, self._min_length,
- batch_creator)
+def create_infer_loader(args):
+ batch_size = args.batch_size
+ max_len = args.max_len
+ trans_func_tuple = IWSLT15.get_default_transform_func()
+ test_ds = IWSLT15.get_datasets(
+ mode=["test"], transform_func=[trans_func_tuple])
+ src_vocab, tgt_vocab = IWSLT15.get_vocab()
+ bos_id = src_vocab[src_vocab.bos_token]
+ eos_id = src_vocab[src_vocab.eos_token]
+ pad_id = eos_id
- for info in infos:
- batch = batch_creator.append(info)
- if batch is not None:
- batches.append(batch)
+ test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)
- if not self._clip_last_batch and len(batch_creator.batch) != 0:
- batches.append(batch_creator.batch)
+ test_loader = paddle.io.DataLoader(
+ test_ds,
+ batch_sampler=test_batch_sampler,
+ collate_fn=partial(
+ prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
+ return test_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id
- if self._shuffle_batch:
- self._random.shuffle(batches)
- if not self._use_token_batch:
- # We take them as a whole and shuffle and split here to confirm
- # neighbor batches have similar length (for similar computational
- # cost) after shuffling while generating batches according to
- # sequence number.
- batches = [[
- batch[self._batch_size * i:self._batch_size * (i + 1)]
- for i in range(self._nranks)
- ] for batch in batches]
- batches = list(itertools.chain.from_iterable(batches))
+def prepare_infer_input(insts, bos_id, eos_id, pad_id):
+ insts = [([bos_id] + inst[0] + [eos_id], [bos_id] + inst[1] + [eos_id])
+ for inst in insts]
+ src, src_length = Pad(pad_val=pad_id, ret_length=True)(
+ [inst[0] for inst in insts])
+ return src, src_length
- # For multi-device
- for batch_id, batch in enumerate(batches):
- if not self._distribute_mode or (
- batch_id % self._nranks == self._local_rank):
- batch_indices = [info.i for info in batch]
- yield batch_indices
- if self._distribute_mode and len(batches) % self._nranks != 0:
- if self._local_rank >= len(batches) % self._nranks:
- # Use previous data to pad
- yield batch_indices
- def __len__(self):
- if not self._use_token_batch:
- batch_number = (
- len(self._dataset) + self._batch_size * self._nranks - 1) // (
- self._batch_size * self._nranks)
- else:
- # TODO(guosheng): fix the uncertain length
- batch_number = 1
- return batch_number
+def prepare_train_input(insts, bos_id, eos_id, pad_id):
+ # Add eos token id and bos token id.
+ insts = [([bos_id] + inst[0] + [eos_id], [bos_id] + inst[1] + [eos_id])
+ for inst in insts]
+ # Pad sequence using eos id.
+ src, src_length = Pad(pad_val=pad_id, ret_length=True)(
+ [inst[0] for inst in insts])
+ tgt, tgt_length = Pad(pad_val=pad_id, ret_length=True)(
+ [inst[1] for inst in insts])
+ tgt_mask = (tgt[:, :-1] != pad_id).astype("float32")
+ return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis], tgt_mask
diff --git a/PaddleNLP/examples/machine_translation/seq2seq/download.py b/PaddleNLP/examples/machine_translation/seq2seq/download.py
deleted file mode 100644
index 15bcf7fba19ff9ce1c5eb06b0e07ca7472f1f81b..0000000000000000000000000000000000000000
--- a/PaddleNLP/examples/machine_translation/seq2seq/download.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-'''
-Script for downloading training data.
-'''
-import os
-import urllib
-import sys
-
-if sys.version_info >= (3, 0):
- import urllib.request
-import zipfile
-
-URLLIB = urllib
-if sys.version_info >= (3, 0):
- URLLIB = urllib.request
-
-remote_path = 'https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi'
-base_path = 'data'
-trg_path = os.path.join(base_path, 'en-vi')
-filenames = [
- 'train.en', 'train.vi', 'tst2012.en', 'tst2012.vi', 'tst2013.en',
- 'tst2013.vi', 'vocab.en', 'vocab.vi'
-]
-
-
-def main(arguments):
- print("Downloading data......")
-
- if not os.path.exists(trg_path):
- if not os.path.exists(base_path):
- os.mkdir(base_path)
- os.mkdir(trg_path)
-
- for filename in filenames:
- url = os.path.join(remote_path, filename)
- trg_file = os.path.join(trg_path, filename)
- URLLIB.urlretrieve(url, trg_file)
- print("Downloaded success......")
-
-
-if __name__ == '__main__':
- sys.exit(main(sys.argv[1:]))
diff --git a/PaddleNLP/examples/machine_translation/seq2seq/predict.py b/PaddleNLP/examples/machine_translation/seq2seq/predict.py
index c653c610e9789229859db0a2b622e655632d63a3..839bda0bdd6094c617c655401d1c1dde42850224 100644
--- a/PaddleNLP/examples/machine_translation/seq2seq/predict.py
+++ b/PaddleNLP/examples/machine_translation/seq2seq/predict.py
@@ -13,18 +13,17 @@
# limitations under the License.
import io
-from functools import partial
import numpy as np
import paddle
from args import parse_args
from seq2seq_attn import Seq2SeqAttnInferModel
-from data import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_infer_input
+from data import create_infer_loader
+from paddlenlp.datasets import IWSLT15
-def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
- output_eos=False):
+def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
"""
Post-process the decoded sequence.
"""
@@ -43,35 +42,15 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
def do_predict(args):
device = paddle.set_device("gpu" if args.use_gpu else "cpu")
- # Define dataloader
- dataset = Seq2SeqDataset(
- fpattern=args.infer_file,
- src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
- trg_vocab_fpath=args.vocab_prefix + "." + args.trg_lang,
- token_delimiter=None,
- start_mark="",
- end_mark="",
- unk_mark="")
- trg_idx2word = Seq2SeqDataset.load_dict(
- dict_path=args.vocab_prefix + "." + args.trg_lang, reverse=True)
- (args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id,
- unk_id) = dataset.get_vocab_summary()
- batch_sampler = Seq2SeqBatchSampler(
- dataset=dataset, use_token_batch=False,
- batch_size=args.batch_size) #, min_length=1)
- data_loader = paddle.io.DataLoader(
- dataset=dataset,
- batch_sampler=batch_sampler,
- places=device,
- collate_fn=partial(
- prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=eos_id),
- num_workers=0,
- return_list=True)
+ test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader(
+ args)
+ _, vocab = IWSLT15.get_vocab()
+ trg_idx2word = vocab.idx_to_token
model = paddle.Model(
Seq2SeqAttnInferModel(
- args.src_vocab_size,
- args.trg_vocab_size,
+ src_vocab_size,
+ tgt_vocab_size,
args.hidden_size,
args.hidden_size,
args.num_layers,
@@ -84,14 +63,14 @@ def do_predict(args):
model.prepare()
# Load the trained model
- assert args.reload_model, (
+ assert args.init_from_ckpt, (
"Please set reload_model to load the infer model.")
- model.load(args.reload_model)
+ model.load(args.init_from_ckpt)
- # TODO(guosheng): use model.predict when support variant length
with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
- for data in data_loader():
- finished_seq = model.predict_batch(inputs=list(data))[0]
+ for data in test_loader():
+ with paddle.no_grad():
+ finished_seq = model.predict_batch(inputs=data)[0]
finished_seq = finished_seq[:, :, np.newaxis] if len(
finished_seq.shape) == 2 else finished_seq
finished_seq = np.transpose(finished_seq, [0, 2, 1])
diff --git a/PaddleNLP/examples/machine_translation/seq2seq/seq2seq_attn.py b/PaddleNLP/examples/machine_translation/seq2seq/seq2seq_attn.py
index ead6692ec52d3f3e8c46f2fd5d50fa7d94c01ffe..f97a101a187fb38d5269387c891e4e8f5e99093e 100644
--- a/PaddleNLP/examples/machine_translation/seq2seq/seq2seq_attn.py
+++ b/PaddleNLP/examples/machine_translation/seq2seq/seq2seq_attn.py
@@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
+
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
@@ -23,7 +24,7 @@ class CrossEntropyCriterion(nn.Layer):
def __init__(self):
super(CrossEntropyCriterion, self).__init__()
- def forward(self, predict, trg_mask, label):
+ def forward(self, predict, label, trg_mask):
cost = F.softmax_with_cross_entropy(
logits=predict, label=label, soft_label=False)
cost = paddle.squeeze(cost, axis=[2])
@@ -200,7 +201,6 @@ class Seq2SeqAttnModel(nn.Layer):
(encoder_final_state[0][i], encoder_final_state[1][i])
for i in range(self.num_layers)
]
-
# Construct decoder initial states: use input_feed and the shape is
# [[h,c] * num_layers, input_feed], consistent with Seq2SeqDecoderCell.states
decoder_initial_states = [
@@ -215,8 +215,7 @@ class Seq2SeqAttnModel(nn.Layer):
predict = self.decoder(trg, decoder_initial_states, encoder_output,
encoder_padding_mask)
- trg_mask = (trg != self.eos_id).astype(paddle.get_default_dtype())
- return predict, trg_mask
+ return predict
class Seq2SeqAttnInferModel(Seq2SeqAttnModel):
diff --git a/PaddleNLP/examples/machine_translation/seq2seq/train.py b/PaddleNLP/examples/machine_translation/seq2seq/train.py
index 292cf8ffd02e83ef51bf168e56c69435936295e0..834fcf19d51c877f7e15e4bd4f7263f279b3bf30 100644
--- a/PaddleNLP/examples/machine_translation/seq2seq/train.py
+++ b/PaddleNLP/examples/machine_translation/seq2seq/train.py
@@ -12,88 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import math
from args import parse_args
-import numpy as np
import paddle
import paddle.nn as nn
-from paddle.metric import Metric
+from paddlenlp.metrics import Perplexity
from seq2seq_attn import Seq2SeqAttnModel, CrossEntropyCriterion
-from data import create_data_loader
-
-
-class TrainCallback(paddle.callbacks.ProgBarLogger):
- def __init__(self, ppl, log_freq, verbose=2):
- super(TrainCallback, self).__init__(log_freq, verbose)
- self.ppl = ppl
-
- def on_train_begin(self, logs=None):
- super(TrainCallback, self).on_train_begin(logs)
- self.train_metrics = ["loss", "ppl"]
-
- def on_epoch_begin(self, epoch=None, logs=None):
- super(TrainCallback, self).on_epoch_begin(epoch, logs)
- self.ppl.reset()
-
- def on_train_batch_end(self, step, logs=None):
- logs["ppl"] = self.ppl.cal_acc_ppl(logs["loss"][0], logs["batch_size"])
- if step > 0 and step % self.ppl.reset_freq == 0:
- self.ppl.reset()
- super(TrainCallback, self).on_train_batch_end(step, logs)
-
- def on_eval_begin(self, logs=None):
- super(TrainCallback, self).on_eval_begin(logs)
- self.eval_metrics = ["ppl"]
- self.ppl.reset()
-
- def on_eval_batch_end(self, step, logs=None):
- logs["ppl"] = self.ppl.cal_acc_ppl(logs["loss"][0], logs["batch_size"])
- super(TrainCallback, self).on_eval_batch_end(step, logs)
-
-
-class Perplexity(Metric):
- def __init__(self, reset_freq=100, name=None):
- super(Perplexity, self).__init__()
- self._name = name or "Perplexity"
- self.reset_freq = reset_freq
- self.reset()
-
- def compute(self, pred, seq_length, label):
- word_num = paddle.sum(seq_length)
- return word_num
-
- def update(self, word_num):
- self.word_count += word_num
- return word_num
-
- def reset(self):
- self.total_loss = 0
- self.word_count = 0
-
- def accumulate(self):
- return self.word_count
-
- def name(self):
- return self._name
-
- def cal_acc_ppl(self, batch_loss, batch_size):
- self.total_loss += batch_loss * batch_size
- ppl = math.exp(self.total_loss / self.word_count)
- return ppl
+from data import create_train_loader
def do_train(args):
device = paddle.set_device("gpu" if args.use_gpu else "cpu")
# Define dataloader
- (train_loader, eval_loader), eos_id = create_data_loader(args, device)
+ train_loader, eval_loader, src_vocab_size, tgt_vocab_size, eos_id = create_train_loader(
+ args)
model = paddle.Model(
- Seq2SeqAttnModel(args.src_vocab_size, args.trg_vocab_size,
- args.hidden_size, args.hidden_size, args.num_layers,
- args.dropout, eos_id))
+ Seq2SeqAttnModel(src_vocab_size, tgt_vocab_size, args.hidden_size, args.
+ hidden_size, args.num_layers, args.dropout, eos_id))
grad_clip = nn.ClipGradByGlobalNorm(args.max_grad_norm)
optimizer = paddle.optimizer.Adam(
@@ -101,7 +39,7 @@ def do_train(args):
parameters=model.parameters(),
grad_clip=grad_clip)
- ppl_metric = Perplexity(reset_freq=args.log_freq)
+ ppl_metric = Perplexity()
model.prepare(optimizer, CrossEntropyCriterion(), ppl_metric)
print(args)
@@ -115,8 +53,7 @@ def do_train(args):
eval_freq=1,
save_freq=1,
save_dir=args.model_path,
- log_freq=args.log_freq,
- callbacks=[TrainCallback(ppl_metric, args.log_freq)])
+ log_freq=args.log_freq)
if __name__ == "__main__":
diff --git a/PaddleNLP/examples/text_generation/couplet/README.md b/PaddleNLP/examples/text_generation/couplet/README.md
index 3a7b1bc3677ab8008e8a8cf7e432f5efb1dec6d5..c4dcc9bb3eab9e65b80f6fa594ba321ebbf711f4 100644
--- a/PaddleNLP/examples/text_generation/couplet/README.md
+++ b/PaddleNLP/examples/text_generation/couplet/README.md
@@ -30,7 +30,8 @@ Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder)
本教程使用[couplet数据集](https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz)数据集作为训练语料,train_src.tsv及train_tgt.tsv为训练集,dev_src.tsv及dev_tgt.tsv为开发集,test_src.tsv及test_tgt.tsv为测试集。
-数据集会在`CoupletDataset`初始化时自动下载
+数据集会在`CoupletDataset`初始化时自动下载,如果用户在初始化数据集时没有提供路径,在linux系统下,数据集会自动下载到`/root/.paddlenlp/datasets/machine_translation/CoupletDataset/`目录下
+
## 模型训练
@@ -69,24 +70,24 @@ python predict.py \
## 生成对联样例
-上联:崖悬风雨骤 下联:月落水云寒
+上联:崖悬风雨骤 下联:月落水云寒
-上联:约春章柳下 下联:邀月醉花间
+上联:约春章柳下 下联:邀月醉花间
-上联:箬笠红尘外 下联:扁舟明月中
+上联:箬笠红尘外 下联:扁舟明月中
-上联:书香醉倒窗前月 下联:烛影摇红梦里人
+上联:书香醉倒窗前月 下联:烛影摇红梦里人
-上联:踏雪寻梅求雅趣 下联:临风把酒觅知音
+上联:踏雪寻梅求雅趣 下联:临风把酒觅知音
-上联:未出南阳天下论 下联:先登北斗汉中书
+上联:未出南阳天下论 下联:先登北斗汉中书
-上联:朱联妙语千秋颂 下联:赤胆忠心万代传
+上联:朱联妙语千秋颂 下联:赤胆忠心万代传
-上联:月半举杯圆月下 下联:花间对酒醉花间
+上联:月半举杯圆月下 下联:花间对酒醉花间
-上联:挥笔如剑倚麓山豪气干云揽月去 下联:落笔似龙飞沧海龙吟破浪乘风来
+上联:挥笔如剑倚麓山豪气干云揽月去 下联:落笔似龙飞沧海龙吟破浪乘风来
diff --git a/PaddleNLP/examples/text_generation/couplet/args.py b/PaddleNLP/examples/text_generation/couplet/args.py
index 035f3d92d0cdf678ae5fab187eb970305ea4187b..8cbe0abe750fe1f137742c8d045af9524d483a38 100644
--- a/PaddleNLP/examples/text_generation/couplet/args.py
+++ b/PaddleNLP/examples/text_generation/couplet/args.py
@@ -78,9 +78,6 @@ def parse_args():
default=None,
help="The path of checkpoint to be loaded.")
- parser.add_argument(
- "--infer_file", type=str, help="file name for inference")
-
parser.add_argument(
"--infer_output_file",
type=str,
diff --git a/PaddleNLP/examples/text_generation/vae-seq2seq/README.md b/PaddleNLP/examples/text_generation/vae-seq2seq/README.md
index 7682d4548100c29443d5d57a8aa9bc83984c0f0c..b403a32aefd51cdc9a3089fe4f9cb3a5c1371157 100644
--- a/PaddleNLP/examples/text_generation/vae-seq2seq/README.md
+++ b/PaddleNLP/examples/text_generation/vae-seq2seq/README.md
@@ -23,9 +23,9 @@
本教程使用了两个文本数据集:
-PTB dataset,原始下载地址为: http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
+PTB数据集由华尔街日报的文章组成,包含929k个训练tokens,词汇量为10k。下载地址为: https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz。
-yahoo,原始下载地址为:https://drive.google.com/file/d/13IsiffVjcQ-wrrbBGMwiG3sYf-DFxtXH/view?usp=sharing/
+Yahoo数据集来自[(Yang et al., 2017) Improved Variational Autoencoders for Text Modeling using Dilated Convolutions](https://arxiv.org/pdf/1702.08139.pdf),该数据集从原始Yahoo Answer数据中采样100k个文档,数据集的平均文档长度为78,词汇量为200k。下载地址为:https://paddlenlp.bj.bcebos.com/datasets/yahoo-answer-100k.tar.gz
### 数据获取
@@ -91,7 +91,7 @@ python -m paddle.distributed.launch train.py \
## 模型预测
-当模型训练完成之后,可以选择加载模型保存目录下的第 50 个epoch的模型进行预测,生成batch_size条短文本。如果使用ptb数据集,可以通过下面命令配置:
+当模型训练完成之后,可以选择加载模型保存目录下的第 50 个epoch的模型进行预测,生成batch_size条短文本。生成的文本位于参数`infer_output_file`指定的路径下。如果使用ptb数据集,可以通过下面命令配置:
```
export CUDA_VISIBLE_DEVICES=0
@@ -101,6 +101,7 @@ python predict.py \
--max_grad_norm 5.0 \
--dataset ptb \
--use_gpu True \
+ --infer_output_file infer_output.txt \
--init_from_ckpt ptb_model/49 \
```
diff --git a/PaddleNLP/examples/text_generation/vae-seq2seq/args.py b/PaddleNLP/examples/text_generation/vae-seq2seq/args.py
index c5d8796db67b97e8588b3129165bd01bdf1a98e0..881748d0a8dc05e0545b58be73a24815a6df9980 100644
--- a/PaddleNLP/examples/text_generation/vae-seq2seq/args.py
+++ b/PaddleNLP/examples/text_generation/vae-seq2seq/args.py
@@ -27,12 +27,6 @@ def parse_args():
default=0.001,
help="Learning rate of optimizer.")
- parser.add_argument(
- "--decay_factor",
- type=float,
- default=0.5,
- help="Decay factor of learning rate.")
-
parser.add_argument(
"--num_layers",
type=int,
@@ -83,12 +77,6 @@ def parse_args():
default=0.,
help="Drop probability of encoder.")
- parser.add_argument(
- "--word_keep_prob",
- type=float,
- default=0.5,
- help="Word keep probability.")
-
parser.add_argument(
"--init_scale",
type=float,
@@ -122,15 +110,6 @@ def parse_args():
default=False,
help='Whether to use gpu [True|False].')
- parser.add_argument(
- "--enable_ce",
- action='store_true',
- help="The flag indicating whether to run the task "
- "for continuous evaluation.")
-
- parser.add_argument(
- "--profile", action='store_true', help="Whether enable the profile.")
-
parser.add_argument(
"--warm_up",
type=int,
@@ -143,26 +122,6 @@ def parse_args():
default=0.1,
help='KL start value, up to 1.0.')
- parser.add_argument(
- "--attr_init",
- type=str,
- default='normal_initializer',
- help="Initializer for paramters.")
-
- parser.add_argument(
- "--cache_num", type=int, default=1, help='Cache num for reader.')
-
- parser.add_argument(
- "--max_decay",
- type=int,
- default=5,
- help='Max decay tries (if exceeds, early stop).')
-
- parser.add_argument(
- "--sort_cache",
- action='store_true',
- help='Sort cache before batch to accelerate training.')
-
parser.add_argument(
"--init_from_ckpt",
type=str,
diff --git a/PaddleNLP/examples/text_generation/vae-seq2seq/data.py b/PaddleNLP/examples/text_generation/vae-seq2seq/data.py
index 6656a3cc75347afc152ff4169b3645022f665f96..6c346810b4f9648b06e5d96b08d31050b2f05892 100644
--- a/PaddleNLP/examples/text_generation/vae-seq2seq/data.py
+++ b/PaddleNLP/examples/text_generation/vae-seq2seq/data.py
@@ -117,7 +117,11 @@ class VAEDataset(paddle.io.Dataset):
return corpus_ids
def read_raw_data(self, dataset, max_vocab_cnt=-1):
- data_path = os.path.join("data", dataset)
+ if dataset == 'yahoo':
+ dataset_name = 'yahoo-answer-100k'
+ else:
+ dataset_name = os.path.join('simple-examples', 'data')
+ data_path = os.path.join("data", dataset_name)
train_file = os.path.join(data_path, dataset + ".train.txt")
valid_file = os.path.join(data_path, dataset + ".valid.txt")
test_file = os.path.join(data_path, dataset + ".test.txt")
@@ -282,7 +286,13 @@ def create_data_loader(data_path,
def get_vocab(dataset, batch_size, vocab_file=None, max_sequence_len=50):
train_dataset = VAEDataset(dataset, batch_size, mode='train')
- dataset_prefix = os.path.join("data", dataset)
+ if dataset == 'yahoo':
+ dataset_name = 'yahoo-answer-100k'
+ else:
+ dataset_name = os.path.join('simple-examples', 'data')
+
+ dataset_prefix = os.path.join("data", dataset_name)
+
train_file = os.path.join(dataset_prefix, dataset + ".train.txt")
vocab_file = None
if "yahoo" in dataset:
diff --git a/PaddleNLP/examples/text_generation/vae-seq2seq/download.py b/PaddleNLP/examples/text_generation/vae-seq2seq/download.py
index 97e1114b86b748678f1cfb46b98b4ef05b29a026..9e78dbfcb89a5e6af796b9ce4d538eba23a477bd 100644
--- a/PaddleNLP/examples/text_generation/vae-seq2seq/download.py
+++ b/PaddleNLP/examples/text_generation/vae-seq2seq/download.py
@@ -12,110 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-'''
-Script for downloading training data.
-'''
+
import os
import sys
-import shutil
import argparse
-import urllib
-import tarfile
-import urllib.request
-import zipfile
-URLLIB = urllib.request
+from paddle.utils.download import get_path_from_url
TASKS = ['ptb', 'yahoo']
-TASK2PATH = {
- 'ptb': 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz',
+URL = {
+ 'ptb': 'https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz',
'yahoo':
- 'https://drive.google.com/file/d/13IsiffVjcQ-wrrbBGMwiG3sYf-DFxtXH/view?usp=sharing/yahoo.zip',
+ 'https://paddlenlp.bj.bcebos.com/datasets/yahoo-answer-100k.tar.gz',
}
-def un_tar(tar_name, dir_name):
- try:
- t = tarfile.open(tar_name)
- t.extractall(path=dir_name)
- return True
- except Exception as e:
- print(e)
- return False
-
-
-def un_zip(filepath, dir_name):
- z = zipfile.ZipFile(filepath, 'r')
- for file in z.namelist():
- z.extract(file, dir_name)
-
-
-def download_ptb_and_extract(task, data_path):
- print('Downloading and extracting %s...' % task)
- data_file = os.path.join(data_path, TASK2PATH[task].split('/')[-1])
- URLLIB.urlretrieve(TASK2PATH[task], data_file)
- un_tar(data_file, data_path)
- os.remove(data_file)
- src_dir = os.path.join(data_path, 'simple-examples')
- dst_dir = os.path.join(data_path, 'ptb')
- if not os.path.exists(dst_dir):
- os.mkdir(dst_dir)
- shutil.copy(os.path.join(src_dir, 'data/ptb.train.txt'), dst_dir)
- shutil.copy(os.path.join(src_dir, 'data/ptb.valid.txt'), dst_dir)
- shutil.copy(os.path.join(src_dir, 'data/ptb.test.txt'), dst_dir)
- print('\tCompleted!')
-
-
-def download_yahoo_dataset(task, data_dir):
- url = TASK2PATH[task]
- # id is between `/d/` and '/'
- url_suffix = url[url.find('/d/') + 3:]
- if url_suffix.find('/') == -1:
- # if there's no trailing '/'
- file_id = url_suffix
- else:
- file_id = url_suffix[:url_suffix.find('/')]
-
- try:
- import requests
- except ImportError:
- print("The requests library must be installed to download files from "
- "Google drive. Please see: https://github.com/psf/requests")
- raise
-
- def _get_confirm_token(response):
- for key, value in response.cookies.items():
- if key.startswith('download_warning'):
- return value
- return None
-
- gurl = "https://docs.google.com/uc?export=download"
- sess = requests.Session()
- response = sess.get(gurl, params={'id': file_id}, stream=True)
-
- token = None
- for key, value in response.cookies.items():
- if key.startswith('download_warning'):
- token = value
-
- if token:
- params = {'id': file_id, 'confirm': token}
- response = sess.get(gurl, params=params, stream=True)
-
- filename = 'yahoo.zip'
- filepath = os.path.join(data_dir, filename)
- CHUNK_SIZE = 32768
- with open(filepath, "wb") as f:
- for chunk in response.iter_content(CHUNK_SIZE):
- if chunk:
- f.write(chunk)
-
- un_zip(filepath, data_dir)
- os.remove(filepath)
-
- print('Successfully downloaded yahoo')
-
-
def main(arguments):
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -131,15 +42,7 @@ def main(arguments):
type=str,
default='ptb')
args = parser.parse_args(arguments)
-
- if not os.path.isdir(args.data_dir):
- os.mkdir(args.data_dir)
- if args.task == 'yahoo':
- if args.data_dir == 'data':
- args.data_dir = './'
- download_yahoo_dataset(args.task, args.data_dir)
- else:
- download_ptb_and_extract(args.task, args.data_dir)
+ get_path_from_url(URL[args.task], args.data_dir)
if __name__ == '__main__':