Commit a95d4932 authored by: S sserdoubleh; committed by: Yibing Liu

Update Dialog-PLATO: Support paddlepaddle 1.6. Release PLATO w/o latent. (#3931)

* Upload model: Dialogue-PLATO.

* Update README.md.

* Update Dialog-PLATO: Support APIs in paddlepaddle 1.6 and more features. Release PLATO w/o latent.
Parent 2c5bf11a
@@ -2,19 +2,25 @@
**PLATO: Pre-trained Dialogue Generation Model with Discrete Latent Variable**
[paper link](http://arxiv.org/abs/1910.07931)
**\*\*\*\*\* Update \*\*\*\*\***
Nov. 14: Support the new APIs in paddlepaddle 1.6.0 (the model files in the links have been updated accordingly), multi-GPU training, and a top-k sampling decoding strategy. Release our baseline model `PLATO w/o latent`.
## Requirements
```
- python >= 3.6
- paddlepaddle >= 1.6.0
- numpy
- nltk
- tqdm
- visualdl >= 1.3.0 (optional)
- regex
```
## Pre-trained dialogue generation model
A novel pre-training model for dialogue generation is introduced in this work, which incorporates discrete latent variables for one-to-many relationship modeling. Our model is flexible enough to support various kinds of conversations, including chit-chat, knowledge grounded dialogues, and conversational question answering. The pre-training is carried out with Reddit and Twitter corpora. You can download the uncased pre-trained models from:
* PLATO, uncased [model](https://baidu-nlp.bj.bcebos.com/PLATO/model.tar.gz): 12-layers, 768-hidden, 12-heads, 132M parameters
* PLATO w/o latent, uncased [model](https://baidu-nlp.bj.bcebos.com/PLATO/model-baseline.tar.gz): 12-layers, 768-hidden, 12-heads, 109M parameters
```bash
mv /path/to/model.tar.gz .
@@ -26,19 +32,19 @@ We also provide instructions to fine-tune PLATO on different conversation datasets.
### Data preparation
Download data from the [link](https://baidu-nlp.bj.bcebos.com/PLATO/data.tar.gz).
The tar file contains three processed datasets: `DailyDialog`, `PersonaChat` and `DSTC7_AVSD`.
```bash
mv /path/to/data.tar.gz .
tar xzf data.tar.gz
```
### Data format
Our model supports two kinds of data formats for dialogue context: `multi` and `multi_knowledge`. A short sketch of assembling such lines is given after the two format descriptions.
* `multi`: multi-turn dialogue context.
```txt
u_1 __eou__ u_2 __eou__ ... u_n \t r
```
* `multi_knowledge`: multi-turn dialogue context with background knowledge.
```txt
k_1 __eou__ k_2 __eou__ ... k_m \t u_1 __eou__ u_2 __eou__ ... u_n \t r
```
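Both formats are plain text with one example per line: utterances (and knowledge sentences) are joined with ` __eou__ ` and the fields are separated by tabs. The helper below is only an illustrative sketch; the function name and the toy utterances are made up here and are not part of the released scripts.
```python
def to_plato_line(context, response, knowledge=None):
    """Format one training example in the "multi" / "multi_knowledge" layout."""
    fields = []
    if knowledge is not None:
        fields.append(" __eou__ ".join(knowledge))   # k_1 __eou__ ... __eou__ k_m
    fields.append(" __eou__ ".join(context))         # u_1 __eou__ ... __eou__ u_n
    fields.append(response)                          # r
    return "\t".join(fields)

# "multi" format
print(to_plato_line(["hi !", "hello , how are you ?"], "i am fine , thanks ."))
# "multi_knowledge" format
print(to_plato_line(["do you have pets ?"], "yes , a dog .", knowledge=["i have a dog .", "i like music ."]))
```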
@@ -46,7 +52,7 @@ k_1 __eou__ k_2 __eou__ ... k_m \t u_1 __eou__ u_2 __eou__ ... u_n \t r
If you want to use this model on other datasets, you can process your data accordingly.
### Train
Fine-tune the pre-trained model on a chosen `${DATASET}`.
```bash
# DailyDialog / PersonaChat / DSTC7_AVSD
DATASET=DailyDialog
@@ -54,11 +60,24 @@ sh scripts/${DATASET}/train.sh
```
After training, you can find the output folder `outputs/${DATASET}` (by default). It contains `best.model` (best results on the validation dataset), `hparams.json` (hyper-parameters of the training script) and `trainer.log` (training log).
Fine-tune the pre-trained model on multiple GPUs.
Note: You need to install the NCCL library and set the environment variable `LD_LIBRARY_PATH` properly.
```bash
sh scripts/DailyDialog/multi_gpu_train.sh
```
You can also fine-tune PLATO w/o latent on different `${DATASET}`s. We provide an example script for the DailyDialog dataset.
```bash
sh scripts/DailyDialog/baseline_train.sh
```
#### Recommended settings
Fine-tuning the pre-trained model usually requires about 10 epochs to reach convergence with learning rate = 1e-5 and about 2-3 epochs with learning rate = 5e-5.
GPU Memory | batch size | max len
------|------|------
16G | 6 | 256
32G | 12 | 256
@@ -69,9 +88,17 @@ Running inference on test dataset.
# DailyDialog / PersonaChat / DSTC7_AVSD
DATASET=DailyDialog
sh scripts/${DATASET}/infer.sh
# Running inference of PLATO w/o latent
sh scripts/DailyDialog/baseline_infer.sh
```
After inference, you can find the output folder `outputs/${DATASET}.infer` (by default). It contains `infer_0.result.json` (the inference result), `hparams.json` (hyper-parameters of the inference script) and `trainer.log` (inference log).
If you want to use top-k sampling (beam search by default), you can follow the example script:
```bash
sh scripts/DailyDialog/topk_infer.sh
```
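For background, top-k sampling keeps only the k most probable next tokens, renormalizes their probabilities and samples from them. The project's actual implementation is the `TopKSampling` generator added in this commit; the snippet below is just a small standalone numpy illustration of the idea.
```python
import numpy as np

def top_k_sample(logits, k):
    """Sample one token id from the k highest-probability tokens."""
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    top_ids = np.argsort(-probs)[:k]                   # k most probable token ids
    top_probs = probs[top_ids] / probs[top_ids].sum()  # renormalize over the kept tokens
    return int(np.random.choice(top_ids, p=top_probs))

print(top_k_sample(np.array([2.0, 1.0, 0.5, -1.0]), k=2))
```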
## Result
### DailyDialog
@@ -79,37 +106,41 @@ Model | BLEU-1/2 | Distinct-1/2 | Fluency | Coherence | Informativeness | Overall
------|------|------|------|------|------|-------
Seq2Seq | 0.336/0.268 | 0.030/0.128 | 1.85 | 0.37 | 0.44 | 0.33
iVAE_MI | 0.309/0.249 | 0.029/0.250 | 1.53 | 0.34 | 0.59 | 0.30
Our w/o Latent | **0.405/0.322** | 0.046/0.246 | 1.91 | **1.58** | 1.03 | 1.44
Our Method | 0.397/0.311 | **0.053/0.291** | **1.97** | 1.57 | **1.23** | **1.48**
### PersonaChat
Model | BLEU-1/2 | Distinct-1/2 | Knowledge R/P/F1 | Fluency | Coherence | Informativeness | Overall
------|------|------|------|------|------|-------|-------
Seq2Seq | 0.448/0.353 | 0.004/0.016 | 0.004/0.016/0.006 | 1.82 | 0.37 | 0.85 | 0.34
LIC | 0.405/0.320 | 0.019/0.113 | 0.042/0.154/0.064 | 1.95 | 1.34 | 1.09 | 1.29
Our w/o Latent | **0.458/0.357** | 0.012/0.064 | 0.085/0.263/0.125 | 1.98 | 1.36 | 1.04 | 1.30
Our Method | 0.406/0.315 | **0.021/0.121** | **0.142/0.461/0.211** | **1.99** | **1.51** | **1.70** | **1.50**
### DSTC7_AVSD
Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr
------|------|------|------|------|------|-------|-------
Baseline | 0.629 | 0.485 | 0.383 | 0.309 | 0.215 | 0.487 | 0.746
CMU | 0.718 | 0.584 | 0.478 | 0.394 | 0.267 | 0.563 | 1.094
Our Method | **0.784** | **0.637** | **0.525** | **0.435** | **0.286** | **0.596** | **1.209**
Our Method Upper Bound | 0.925 | 0.843 | 0.767 | 0.689 | 0.361 | 0.731 | 1.716
Note: In the experiments on `DSTC7_AVSD`, the response selection of our method is strengthened with an extra ranking step, which ranks the candidates according to the automatic scores and selects the top one as the final answer.
## Citation
If you find PLATO useful in your work, please cite the following arXiv paper:
```
@article{bao2019plato,
  title={PLATO: Pre-trained Dialogue Generation Model with Discrete Latent Variable},
  author={Bao, Siqi and He, Huang and Wang, Fan and Wu, Hua and Wang, Haifeng},
  journal={arXiv preprint arXiv:1910.07931},
  year={2019}
}
```
## Disclaimer
This project aims to facilitate further research progress in dialogue generation. Baidu is not responsible for third-party generations produced with the pre-trained system.
## Contact information
For help or issues using PLATO, please submit a GitHub issue.
...
@@ -56,7 +56,7 @@ class HParams(dict):
params_dict = json.load(fp)
for k, v in params_dict.items():
if isinstance(v, dict):
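# update the existing nested HParams in place (instead of replacing it), so nested defaults absent from the loaded JSON are preserved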
self[k].update(HParams(v))
else:
self[k] = v
...
@@ -20,9 +20,11 @@ import math
import paddle.fluid as fluid
import paddle.batch
from plato.args import str2bool
from plato.data.sampler import RandomSampler
from plato.data.sampler import SequentialSampler
from plato.data.sampler import SortedSampler
import plato.modules.parallel as parallel
class DataLoader(object):
@@ -31,11 +33,13 @@ class DataLoader(object):
@classmethod
def add_cmdline_argument(cls, group):
group.add_argument("--shuffle", type=str2bool, default=True)
group.add_argument("--sort_pool_size", type=int, default=0)
return group
def __init__(self, dataset, hparams, collate_fn=None, sampler=None, is_test=False, is_train=False):
self.dataset = dataset
self.collate_fn = collate_fn
self.sort_pool_size = hparams.sort_pool_size
if sampler is None:
if hparams.shuffle and not is_test:
@@ -43,6 +47,9 @@ class DataLoader(object):
else:
sampler = SequentialSampler(dataset)
if self.sort_pool_size > 0 and not is_test:
sampler = SortedSampler(sampler, self.sort_pool_size)
def reader():
for idx in sampler:
yield idx
@@ -50,7 +57,7 @@ class DataLoader(object):
self.reader = paddle.batch(reader, batch_size=hparams.batch_size, drop_last=False)
self.num_batches = math.ceil(len(dataset) / hparams.batch_size)
if hparams.use_data_distributed and parallel.Env().nranks > 1 and is_train:
self.reader = fluid.contrib.reader.distributed_batch_reader(self.reader)
self.num_batches = self.num_batches // fluid.dygraph.parallel.Env().nranks
...
@@ -22,8 +22,8 @@ import pickle
import time
from tqdm import tqdm
from plato.args import str2bool
from plato.data.tokenizer import Tokenizer
def max_lens(X):
@@ -77,21 +77,26 @@ class BPETextField(object):
group.add_argument("--max_knowledge_num", type=int, default=16,
help="The maximum number of knowledges.")
group.add_argument("--max_knowledge_len", type=int, default=16,
help="The maximum length of each knowledges.")
group.add_argument("--tokenizer_type", type=str, default="Bert",
choices=["Bert", "GPT2"],
help="The type of tokenizer.")
return group
def __init__(self, hparams):
special_tokens = [self.pad_token, self.bos_token, self.eos_token, self.unk_token]
self.tokenizer = Tokenizer(vocab_path=hparams.vocab_path,
special_tokens=special_tokens,
tokenizer_type=hparams.tokenizer_type)
self.filtered = hparams.filtered
self.max_len = hparams.max_len
self.min_utt_len = hparams.min_utt_len
self.max_utt_len = hparams.max_utt_len
self.min_ctx_turn = hparams.min_ctx_turn
self.max_ctx_turn = hparams.max_ctx_turn - 1 # subtract reply turn
self.max_knowledge_num = hparams.max_knowledge_num
self.max_knowledge_len = hparams.max_knowledge_len
return
@property
@@ -187,6 +192,27 @@
return self.min_ctx_turn <= len(utts) \
and (not self.filtered or len(utts) <= self.max_ctx_turn)
def build_example_multi_turn(self, req):
examples = []
src = [self.tokenizer.tokenize(s) for s in req["context"]]
src = [s[-self.max_utt_len:] for s in src[-self.max_ctx_turn:]]
src = [self.numericalize(s) + [self.eos_id] for s in src]
ex = {"src": src}
examples.append(ex)
return examples
def build_example_multi_turn_with_knowledge(self, req):
examples = []
src = [self.tokenizer.tokenize(s) for s in req["context"]]
src = [s[-self.max_utt_len:] for s in src[-self.max_ctx_turn:]]
src = [self.numericalize(s) + [self.eos_id] for s in src]
knowledge = [self.tokenizer.tokenize(k) for k in req["knowledge"]]
knowledge = [k[:self.max_knowledge_len] for k in knowledge]
knowledge = [self.numericalize(k) + [self.eos_id] for k in knowledge]
ex = {"src": src, "knowledge": knowledge}
examples.append(ex)
return examples
def build_examples_multi_turn(self, data_file, data_type="train"):
print(f"Reading examples from '{data_file}' ...")
examples = []
@@ -212,7 +238,7 @@ class BPETextField(object):
print(f"Built {len(examples)} {data_type.upper()} examples ({ignored} filtered)")
return examples
def build_examples_multi_turn_with_knowledge(self, data_file, data_type="train"):
print(f"Reading examples from '{data_file}' ...")
examples = []
ignored = 0
...
@@ -47,10 +47,43 @@ class RandomSampler(Sampler):
def __init__(self, dataset):
self.dataset = dataset
self.epoch = 0
return
def __len__(self):
return len(self.dataset)
def __iter__(self):
np.random.seed(self.epoch)
self.epoch += 1
return iter(np.random.permutation(len(self)))
class SortedSampler(Sampler):
""" Sorted Sampler.
Sort each block of examples by key.
"""
def __init__(self, sampler, sort_pool_size, key="src"):
self.sampler = sampler
self.sort_pool_size = sort_pool_size
self.key = lambda idx: len(self.sampler.dataset[idx][key])
return
def __len__(self):
return len(self.sampler)
def __iter__(self):
pool = []
for idx in self.sampler:
pool.append(idx)
if len(pool) == self.sort_pool_size:
pool = sorted(pool, key=self.key)
for i in pool:
yield i
pool = []
if len(pool) > 0:
pool = sorted(pool, key=self.key)
for i in pool:
yield i
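For intuition, the sampler above buffers `sort_pool_size` indices at a time and sorts each buffer by the length of the keyed field (`src` by default), so consecutive batches contain examples of similar length and need less padding. A self-contained toy illustration of the pooling behavior (plain lists instead of the dataset and sampler classes above) could be:
```python
def pooled_sort(indices, lengths, pool_size):
    """Yield indices pool by pool, each pool sorted by the example's length."""
    pool = []
    for idx in indices:
        pool.append(idx)
        if len(pool) == pool_size:
            yield from sorted(pool, key=lambda i: lengths[i])
            pool = []
    if pool:
        yield from sorted(pool, key=lambda i: lengths[i])

lengths = [5, 2, 9, 1, 7, 3]
print(list(pooled_sort(range(6), lengths, pool_size=3)))  # [1, 0, 2, 3, 5, 4]
```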
@@ -18,8 +18,11 @@ Tokenizer class.
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import json
import logging
import os
import regex as re
import sys
import unicodedata
@@ -41,40 +44,71 @@ def clean_string(string):
class Tokenizer(object):
def __init__(self, vocab_path, special_tokens=[], tokenizer_type="Bert"):
self.tokenizer_type = tokenizer_type
if tokenizer_type == "Bert":
self.spec_convert_dict = {"[BOS]": "[unused0]", "[EOS]": "[unused1]"}
self.spec_revert_dict = {v: k for k,
v in self.spec_convert_dict.items()}
special_tokens = [self.spec_convert_dict.get(tok, tok)
for tok in special_tokens]
self.special_tokens = ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")
self.special_tokens += tuple(x for x in special_tokens if x not in self.special_tokens)
self._tokenizer = BertTokenizer(vocab_path, never_split=self.special_tokens)
for tok in self.special_tokens:
assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary"
self.vocab_size = len(self._tokenizer.vocab)
elif tokenizer_type == "GPT2":
self.spec_convert_dict = {"[UNK]": "<unk>"}
self.spec_revert_dict = {v: k for k,
v in self.spec_convert_dict.items()}
special_tokens = [tok for tok in special_tokens
if tok not in self.spec_convert_dict]
vocab_file = os.path.join(vocab_path, "vocab.json")
merges_file = os.path.join(vocab_path, "merges.txt")
self._tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=special_tokens)
self.num_specials = len(special_tokens)
self.vocab_size = len(self._tokenizer)
else:
raise ValueError
def tokenize(self, text):
return self._tokenizer.tokenize(text)
def convert_tokens_to_ids(self, tokens):
if self.tokenizer_type == "Bert":
tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens]
ids = self._tokenizer.convert_tokens_to_ids(tokens)
return ids
else:
tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens]
ids = self._tokenizer.convert_tokens_to_ids(tokens)
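# shift ids: GPT2Tokenizer appends special tokens at the end of its vocab, so this wraps them around to ids 0..num_specials-1 and moves regular token ids up by num_specials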
ids = [(i + self.num_specials) % self.vocab_size for i in ids]
return ids
def convert_ids_to_tokens(self, ids):
if self.tokenizer_type == "Bert":
tokens = self._tokenizer.convert_ids_to_tokens(ids)
tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens]
return tokens
else:
ids = [(i - self.num_specials) % self.vocab_size for i in ids]
tokens = self._tokenizer.convert_ids_to_tokens(ids)
tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens]
return tokens
def decode(self, ids, ignore_tokens=[]):
tokens = self.convert_ids_to_tokens(ids)
if len(ignore_tokens) > 0:
ignore_tokens = set(ignore_tokens)
tokens = [tok for tok in tokens if tok not in ignore_tokens]
if self.tokenizer_type == "Bert":
string = " ".join(tokens).replace(" ##", "")
else:
string = "".join(tokens)
string = bytearray([self._tokenizer.byte_decoder[c]
for c in string]).decode("utf-8")
string = clean_string(string)
return string
@@ -400,3 +434,195 @@ def _is_punctuation(char):
if cat.startswith("P"):
return True
return False
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[ord(b)] for b in token if ord(b) in self.byte_encoder)
if token == '':
continue
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
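A quick round-trip sanity check of the byte-level BPE above might look like the following; the `gpt2/` paths are placeholders for a locally downloaded GPT-2 `vocab.json` and `merges.txt`, not files shipped with this commit.
```python
tokenizer = GPT2Tokenizer("gpt2/vocab.json", "gpt2/merges.txt", special_tokens=["[BOS]", "[EOS]"])
ids = tokenizer.encode("Hello world!")   # tokenize() followed by convert_tokens_to_ids()
print(tokenizer.decode(ids))             # byte-level BPE decodes back to "Hello world!"
```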
@@ -16,6 +16,7 @@ MetricsTracker class
"""
from collections import defaultdict
import math
class MetricsTracker(object):
@@ -66,6 +67,9 @@ class MetricsTracker(object):
for key, val in self.metrics_val.items():
metric_str = f"{key.upper()}-{val:.3f}"
metric_strs.append(metric_str)
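# token-level perplexity is reported as the exponential of the tracked average token negative log-likelihood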
if "token_nll" in self.metrics_val:
metric_str = f"TOKEN_PPL-{math.exp(self.metrics_val['token_nll']):.3f}"
metric_strs.append(metric_str)
metric_strs = " ".join(metric_strs)
return metric_strs
@@ -74,5 +78,8 @@ class MetricsTracker(object):
for key, val in self.metrics_avg.items():
metric_str = f"{key.upper()}-{val:.3f}"
metric_strs.append(metric_str)
if "token_nll" in self.metrics_avg:
metric_str = f"TOKEN_PPL-{math.exp(self.metrics_avg['token_nll']):.3f}"
metric_strs.append(metric_str)
metric_strs = " ".join(metric_strs)
return metric_strs
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Loading models.
"""
import plato.models.unified_transformer
@@ -15,13 +15,17 @@
Generator class.
"""
import bisect
import math
import sys
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.framework import Variable
from plato.args import str2bool
import plato.modules.functions as F
def repeat(var, times):
@@ -56,34 +60,256 @@ def gather(var, idx):
return var
class Generator(object):
""" Genrator class. """
_registry = dict()
@classmethod
def register(cls, name):
Generator._registry[name] = cls
return
@staticmethod
def by_name(name):
return Generator._registry[name]
@staticmethod
def create(hparams, *args, **kwargs):
""" Create generator. """
generator_cls = Generator.by_name(hparams.generator)
return generator_cls(hparams, *args, **kwargs)
@classmethod
def add_cmdline_argument(cls, parser):
group = parser.add_argument_group("Generator")
group.add_argument("--generator", type=str, default="BeamSearch",
choices=["TopKSampling", "TopPSampling", "GreedySampling",
"BeamSearch"])
group.add_argument("--min_gen_len", type=int, default=1,
help="The minimum length of generated response.")
group.add_argument("--max_gen_len", type=int, default=30,
help="The maximum length of generated response.")
args, _ = parser.parse_known_args()
generator_cls = cls.by_name(args.generator)
generator_cls.add_cmdline_argument(group)
return group
def __init__(self, hparams, bpe):
self.vocab_size = bpe.vocab_size
self.bos_id = bpe.bos_id
self.eos_id = bpe.eos_id
self.unk_id = bpe.unk_id
self.pad_id = bpe.pad_id
self.min_gen_len = hparams.min_gen_len
self.max_gen_len = hparams.max_gen_len
assert 1 <= self.min_gen_len <= self.max_gen_len
return
def __call__(self, step_fn, state):
"""
Running generation.
@param : step_fn : decoding one step
@type : function
@param : state : initial state
@type : dict
"""
raise NotImplementedError
class Sampling(Generator):
""" Sampling Generator. """
@classmethod
def add_cmdline_argument(cls, group):
group.add_argument("--ignore_unk", type=str2bool, default=True,
help="Whether to ignore unkown token in generation.")
group.add_argument("--sampling_temperature", type=float, default=1.0)
return group
def __init__(self, hparams, bpe):
super().__init__(hparams, bpe)
self.ignore_unk = hparams.ignore_unk
self.temperature = hparams.sampling_temperature
return
def _sampling(self, scores):
""" Sampling function. """
raise NotImplementedError
def __call__(self, step_fn, state):
"""
Running generation.
@param : step_fn : decoding one step
@type : function
@param : state : initial state
@type : dict
"""
batch_size = state["batch_size"]
vocab_size = self.vocab_size
pos_index = layers.range(0, batch_size, 1, dtype="int64")
pos_index = layers.scale(pos_index, vocab_size)
# shape: [batch_size, beam_size, 1]
predictions = layers.fill_constant(shape=[batch_size, 1],
dtype="int64",
value=self.bos_id)
sequence_scores = layers.fill_constant(shape=[batch_size],
dtype="float32",
value=0.0)
unk_penalty = np.zeros(vocab_size, dtype="float32")
unk_penalty[self.unk_id] = -1e10
unk_penalty = layers.assign(unk_penalty)
eos_penalty = np.zeros(vocab_size, dtype="float32")
eos_penalty[self.eos_id] = -1e10
eos_penalty = layers.assign(eos_penalty)
scores_after_end = np.full(vocab_size, -1e10, dtype="float32")
scores_after_end[self.pad_id] = 0
scores_after_end = layers.assign(scores_after_end)
# initial input
for step in range(1, self.max_gen_len + 1):
pre_ids = predictions[:, -1:]
state["pred_token"] = F.unsqueeze(pre_ids, [2])
if step > 1:
state["pred_mask"] = 1 - F.equal(state["pred_token"], self.pad_id)
state["pred_pos"] = state["pred_pos"] + 1
scores, state = step_fn(state)
# Generate next
# scores shape: [batch_size, vocab_size]
if self.ignore_unk:
scores = scores + unk_penalty
if step <= self.min_gen_len:
scores = scores + eos_penalty
# previous token is [PAD] or [EOS]
# shape: [batch_size, 1]
pre_eos_mask = F.equal(pre_ids, self.eos_id) + F.equal(pre_ids, self.pad_id)
scores = scores * (1 - pre_eos_mask) + \
layers.expand(pre_eos_mask, [1, vocab_size]) * scores_after_end
scores = scores / self.temperature
preds = self._sampling(scores)
predictions = layers.concat([predictions, F.unsqueeze(preds, [1])], axis=1)
scores = layers.reshape(scores, [batch_size * vocab_size])
preds = preds + pos_index
scores = gather(scores, preds)
sequence_scores = sequence_scores + scores
results = {
"preds": predictions,
"scores": sequence_scores
}
return results
class GreedySampling(Sampling):
""" Greedy sampling. """
@classmethod
def add_cmdline_argument(cls, group):
return Sampling.add_cmdline_argument(group)
def _sampling(self, logits):
""" Implement greedy sampling. """
preds = layers.argmax(logits, axis=1)
return preds
class TopKSampling(Sampling):
""" Top-k sampling. """
@classmethod
def add_cmdline_argument(cls, group):
Sampling.add_cmdline_argument(group)
group.add_argument("--top_k_ratio", type=float, default=None)
group.add_argument("--top_k_num", type=int, default=None)
return group
def __init__(self, hparams, bpe):
super().__init__(hparams, bpe)
assert hparams.top_k_ratio is not None or hparams.top_k_num is not None
if hparams.top_k_num is not None:
self.top_k_num = hparams.top_k_num
else:
self.top_k_num = math.floor(hparams.top_k_ratio * self.vocab_size)
assert self.top_k_num >= 1
return
def _sampling(self, logits):
""" Implement top-k sampling. """
probs = layers.softmax(logits, axis=1)
probs, indices = layers.topk(probs, self.top_k_num)
probs = probs / layers.reduce_sum(probs, dim=1, keep_dim=True)
preds = []
for p, ids in zip(probs.numpy(), indices.numpy()):
o = np.random.choice(ids, p=p)
preds.append(o)
preds = np.array(preds, dtype="int64")
return fluid.dygraph.to_variable(preds)
class TopPSampling(Sampling):
""" Top-p sampling. """
@classmethod
def add_cmdline_argument(cls, group):
Sampling.add_cmdline_argument(group)
group.add_argument("--top_p_ratio", type=float, default=1.0)
return group
def __init__(self, hparams, bpe):
super().__init__(hparams, bpe)
self.top_p_ratio = hparams.top_p_ratio
return
def _sampling(self, logits):
""" Implement top-k sampling. """
probs = layers.softmax(logits, axis=1)
preds = []
for p in probs.numpy():
ids = np.argsort(-p)
p = p[ids]
c_p = np.cumsum(p)
i = bisect.bisect_right(c_p, self.top_p_ratio) + 1
o = np.random.choice(ids[:i], p=p[:i]/np.sum(p[:i]))
preds.append(o)
preds = np.array(preds, dtype="int64")
return fluid.dygraph.to_variable(preds)
class BeamSearch(Generator):
""" BeamSearch generator. """
@classmethod
def add_cmdline_argument(cls, group):
group.add_argument("--beam_size", type=int, default=5,
help="The beam size in beam search.")
group.add_argument("--length_average", type=str2bool, default=False,
help="Whether to use length average.")
group.add_argument("--length_penalty", type=float, default=-1.0,
help="The parameter(alpha) of length penalty.")
group.add_argument("--ignore_unk", type=str2bool, default=True,
help="Whether to ignore unkown token in generation.")
return group
def __init__(self, hparams, bpe):
super().__init__(hparams, bpe)
self.beam_size = hparams.beam_size
self.length_average = hparams.length_average
self.length_penalty = hparams.length_penalty
self.ignore_unk = hparams.ignore_unk
return
@@ -159,21 +385,25 @@ class BeamSearch(object):
# previous token is [PAD] or [EOS]
pre_eos_mask = F.equal(pre_ids, self.eos_id) + F.equal(pre_ids, self.pad_id)
scores = scores * (1 - pre_eos_mask) + \
layers.expand(pre_eos_mask, [1, 1, self.vocab_size]) * scores_after_end
if self.length_average:
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 - 1 / step)
sequence_scores = F.unsqueeze(sequence_scores, [2]) * scaled_value
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step)
scores = scores * scaled_value
elif self.length_penalty >= 0.0:
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
(math.pow((4 + step) / (5 + step), self.length_penalty))
sequence_scores = layers.elementwise_mul(scaled_value, sequence_scores, axis=0)
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
(math.pow(1 / (5 + step), self.length_penalty))
scores = scores * scaled_value
scores = layers.elementwise_add(scores, sequence_scores, axis=0)
scores = layers.reshape(scores, shape=[batch_size, beam_size * self.vocab_size])
topk_scores, topk_indices = layers.topk(scores, beam_size)
vocab_size = layers.fill_constant(shape=[1], dtype="int64", value=self.vocab_size)
parent_idx = layers.elementwise_floordiv(topk_indices, vocab_size)
preds = layers.elementwise_mod(topk_indices, vocab_size)
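An editorial note on the `length_penalty` branch above (inferred from the code, not stated in the source): multiplying the accumulated `sequence_scores` by ((4 + step)/(5 + step))^alpha and the current-step log-probabilities by (1/(5 + step))^alpha keeps the running score equal to the length-normalized quantity

$$\mathrm{score}(Y) = \frac{\sum_{t=1}^{|Y|} \log p(y_t \mid y_{<t}, X)}{(5 + |Y|)^{\alpha}}$$

after every step, i.e. a GNMT-style length penalty (up to the constant $6^{\alpha}$ factor of the original formulation). Finished hypotheses, where `pre_eos_mask` is 1, keep their accumulated score unchanged.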
@@ -208,3 +438,8 @@ class BeamSearch(object):
"scores": sequence_scores[:, -1]
}
return results
BeamSearch.register("BeamSearch")
GreedySampling.register("GreedySampling")
TopKSampling.register("TopKSampling")
TopPSampling.register("TopPSampling")
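The `register` / `by_name` / `create` methods above (and the analogous ones on `ModelBase` below) form a simple name-to-class registry, so the decoding strategy can be picked from the `--generator` command-line flag. A minimal, framework-free sketch of the same pattern, with made-up class names, is:
```python
class Base:
    _registry = {}

    @classmethod
    def register(cls, name):
        Base._registry[name] = cls

    @staticmethod
    def create(name, *args, **kwargs):
        # look up the registered class by name and instantiate it
        return Base._registry[name](*args, **kwargs)

class Echo(Base):
    def __init__(self, text):
        self.text = text

Echo.register("Echo")
print(type(Base.create("Echo", "hi")).__name__)  # Echo
```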
@@ -23,14 +23,39 @@ class ModelBase(fluid.dygraph.Layer):
"""
Basic model wrapper for static graph and dygraph.
"""
_registry = dict()
@classmethod
def register(cls, name):
ModelBase._registry[name] = cls
return
@staticmethod
def by_name(name):
return ModelBase._registry[name]
@staticmethod
def create(name_scope, hparams, *args, **kwargs):
model_cls = ModelBase.by_name(hparams.model)
return model_cls(name_scope, hparams, *args, **kwargs)
@classmethod
def add_cmdline_argument(cls, parser):
""" Add cmdline argument. """
group = parser.add_argument_group("Model")
group.add_argument("--init_checkpoint", type=str, default=None)
group.add_argument("--model", type=str, default="UnifiedTransformer",
choices=["UnifiedTransformer"])
args, _ = parser.parse_known_args()
model_cls = ModelBase.by_name(args.model)
model_cls.add_cmdline_argument(group)
return group
def __init__(self, name_scope, hparams):
super().__init__(name_scope)
self.init_checkpoint = hparams.init_checkpoint
return
def __call__(self, *args, **kwargs):
""" Re-implement __call__ function in dygraph mode. """
if not self._built:
...
@@ -16,11 +16,11 @@ Embedder class.
"""
import paddle.fluid as fluid
from paddle.fluid.dygraph import Embedding
from paddle.fluid.dygraph import Layer
import paddle.fluid.layers as layers
import plato.modules.functions as F
class Embedder(Layer):
...
@@ -16,11 +16,11 @@ FeedForward class.
"""
import paddle.fluid as fluid
from paddle.fluid.dygraph import FC
from paddle.fluid.dygraph import Layer
import paddle.fluid.layers as layers
import plato.modules.functions as F
class FeedForward(Layer):
...
@@ -22,6 +22,7 @@ import paddle.fluid.layers as layers
def unsqueeze(input, axes):
""" Implement unsqueeze in dygraph mode. """
# return layers.unsqueeze(input, axes)
# op:unsqueeze has bug in dygraph
axes = [axis if axis >= 0 else axis + len(input.shape) + 1 for axis in axes]
axes = sorted(axes, reverse=True)
@@ -33,8 +34,9 @@ def unsqueeze(input, axes):
def gumbel_softmax(input, tau=1, eps=1e-10):
""" Basic implementation of gumbel_softmax. """
U = fluid.dygraph.to_variable(np.random.rand(*input.shape))
# U = layers.uniform_random(input.shape, dtype=input.dtype, min=0.0, max=1.0)
# U.stop_gradient = True
gumbel = 0.0 - layers.log(eps - layers.log(U + eps))
y = input + gumbel
return layers.softmax(y / tau)
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
LayerNorm layer.
"""
# from paddle.fluid.dygraph import LayerNorm
from six.moves import reduce
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.dygraph import Layer
import logging
class LayerNorm(Layer):
""" Implement LayerNorm in dygraph mode. """
def __init__(self,
name_scope,
scale=True,
shift=True,
begin_norm_axis=1,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
act=None):
super().__init__(name_scope)
self._scale = scale
self._shift = shift
self._begin_norm_axis = begin_norm_axis
self._epsilon = epsilon
self._param_attr = param_attr
self._bias_attr = bias_attr
self._act = act
return
def _build_once(self, input):
""" Create parameters. """
self._dtype = self._helper.input_dtype(input)
input_shape = input.shape
param_shape = [
reduce(lambda x, y: x * y, input_shape[self._begin_norm_axis:])
]
if self._scale:
self._scale_w = self.create_parameter(
attr=self._param_attr,
shape=param_shape,
dtype=self._dtype,
default_initializer=fluid.initializer.Constant(1.0))
else:
if self._param_attr:
logging.warn("param_attr are only avaliable with scale is True")
if self._shift:
assert self._bias_attr is not False
self._bias_w = self.create_parameter(
attr=self._bias_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=True)
else:
if self._bias_attr:
logging.warn("bias_attr are only avaliable with shift is True")
return
def forward(self, x):
""" Forward process of LayerNorm. """
mean = layers.reduce_mean(x,
dim=list(range(self._begin_norm_axis, len(x.shape))),
keep_dim=True)
shift_x = layers.elementwise_sub(x=x, y=mean, axis=0)
variance = layers.reduce_mean(layers.square(shift_x),
dim=list(range(self._begin_norm_axis, len(x.shape))),
keep_dim=True)
r_stdev = layers.rsqrt(variance + self._epsilon)
norm_x = layers.elementwise_mul(x=shift_x, y=r_stdev, axis=0)
out = layers.elementwise_mul(x=norm_x, y=self._scale_w, axis=-1)
out = layers.elementwise_add(x=out, y=self._bias_w, axis=-1)
return out
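For reference, with the default `scale=True, shift=True` settings the `forward` pass above computes the standard layer normalization over the axes from `begin_norm_axis` onward:

$$y = \frac{x - \mathrm{mean}(x)}{\sqrt{\mathrm{var}(x) + \epsilon}} \cdot \gamma + \beta$$

where $\gamma$ is `_scale_w` and $\beta$ is `_bias_w`.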
@@ -16,11 +16,11 @@ MultiheadAttention class.
"""
import paddle.fluid as fluid
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph import FC
import paddle.fluid.layers as layers
import plato.modules.functions as F
class MultiheadAttention(Layer):
...
@@ -24,7 +24,7 @@ from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
import paddle.fluid.framework as framework
from paddle.fluid.layers import collective
from paddle.fluid.dygraph.base import to_variable, no_grad
ParallelStrategy = core.ParallelStrategy
@@ -179,7 +179,7 @@ class DataParallel(layers.Layer):
if not self._is_data_parallel_mode():
return loss
loss_scale = to_variable(
np.array([self._strategy.nranks]).astype("float32"))
loss_scale.stop_gradient = True
loss = loss / loss_scale
@@ -214,6 +214,7 @@ class DataParallel(layers.Layer):
for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
nn.reshape(x=g_var, shape=g_shape, inplace=True)
@no_grad
def apply_collective_grads(self):
"""
AllReduce the Parameters' gradient.
...
@@ -16,14 +16,14 @@ TransformerBlock class.
"""
import paddle.fluid as fluid
from paddle.fluid.dygraph import FC
from paddle.fluid.dygraph import Layer
import paddle.fluid.layers as layers
from plato.modules.feedforward import FeedForward
from plato.modules.layer_norm import LayerNorm
from plato.modules.multihead_attention import MultiheadAttention
import plato.modules.functions as F
class TransformerBlock(Layer):
...
@@ -22,16 +22,17 @@ import sys
import time
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from tqdm import tqdm
from plato.args import str2bool
from plato.data.data_loader import DataLoader
from plato.metrics.metrics_tracker import MetricsTracker
from plato.metrics.metrics import bleu
from plato.metrics.metrics import distinct
import plato.modules.parallel as parallel
def get_logger(log_path, name="default"):
@@ -54,7 +55,10 @@ def get_logger(log_path, name="default"):
def evaluate_generation_result(results):
tgt = [result["tgt"].split(" ") for result in results]
pred = [result["preds"][np.argmax(result["scores"])]
if isinstance(result["preds"], list)
else result["preds"]
for result in results]
pred = [p.split(" ") for p in pred] pred = [p.split(" ") for p in pred]
metrics = {} metrics = {}
metrics_tracker = MetricsTracker() metrics_tracker = MetricsTracker()
...@@ -78,7 +82,12 @@ def evaluate_generation_result(results): ...@@ -78,7 +82,12 @@ def evaluate_generation_result(results):
def save(model, model_path): def save(model, model_path):
if isinstance(model, parallel.DataParallel): if isinstance(model, parallel.DataParallel):
model = model._layers model = model._layers
dygraph.save_persistables(model.state_dict(), model_path, optimizers=model.optimizer) if hasattr(fluid, "save_dygraph"):
# >= 1.6.0 compatible
fluid.save_dygraph(model.state_dict(), model_path)
fluid.save_dygraph(model.optimizer.state_dict(), model_path)
else:
dygraph.save_persistables(model.state_dict(), model_path, optimizers=model.optimizer)
return return
...@@ -115,10 +124,11 @@ class Trainer(object): ...@@ -115,10 +124,11 @@ class Trainer(object):
# Use data distributed # Use data distributed
if hparams.use_data_distributed: if hparams.use_data_distributed:
strategy = parallel.prepare_context() strategy = parallel.prepare_context()
parallel_model = parallel.DataParallel(model, strategy) if strategy is not None:
model.before_backward_fn = parallel_model.scale_loss parallel_model = parallel.DataParallel(model, strategy)
model.after_backward_fn = parallel_model.apply_collective_grads model.before_backward_fn = parallel_model.scale_loss
model = parallel_model model.after_backward_fn = parallel_model.apply_collective_grads
model = parallel_model
self.model = model self.model = model
self.to_tensor = to_tensor self.to_tensor = to_tensor
...@@ -143,7 +153,8 @@ class Trainer(object): ...@@ -143,7 +153,8 @@ class Trainer(object):
self.train_summary = {} self.train_summary = {}
self.valid_summary = {} self.valid_summary = {}
self.metrics_tracker = MetricsTracker() self.batch_metrics_tracker = MetricsTracker()
self.token_metrics_tracker = MetricsTracker()
self.best_valid_metric = float("inf" if self.is_decreased_valid_metric else "-inf")
self.epoch = 0
@@ -167,33 +178,44 @@
"""
self.epoch += 1
num_batches = len(train_iter)
self.batch_metrics_tracker.clear()
self.token_metrics_tracker.clear()
times = []
for batch_id, (batch, batch_size) in enumerate(train_iter, 1):
batch = type(batch)(map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items()))
batch["epoch"] = self.epoch
batch["num_steps"] = self.batch_num
# measure data loading time
# Do a training iteration
start_time = time.time()
metrics = self.model(batch, is_training=True)
token_num = metrics.pop("token_num", None)
elapsed = time.time() - start_time
times.append(elapsed)
batch_metrics = {k: v for k, v in metrics.items() if "token" not in k}
token_metrics = {k: v for k, v in metrics.items() if "token" in k}
self.batch_metrics_tracker.update(batch_metrics, batch_size)
self.token_metrics_tracker.update(token_metrics, token_num)
self.batch_num += 1
if self.log_steps and batch_id % self.log_steps == 0:
batch_metrics_message = self.batch_metrics_tracker.value()
token_metrics_message = self.token_metrics_tracker.value()
message_prefix = f"[Train][{self.epoch}][{batch_id}/{num_batches}]"
avg_time = f"AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}"
message = " ".join([message_prefix, batch_metrics_message, token_metrics_message,
avg_time])
self.logger.info(message) self.logger.info(message)
if self.save_summary: if self.save_summary:
with self.summary_logger.mode("train"): with self.summary_logger.mode("train"):
for k, v in self.metrics_tracker.items(): for k, v in self.batch_metrics_tracker.items():
if k not in self.train_summary:
self.train_summary[k] = self.summary_logger.scalar(k)
scalar = self.train_summary[k]
scalar.add_record(self.batch_num, v)
for k, v in self.token_metrics_tracker.items():
if k not in self.train_summary: if k not in self.train_summary:
self.train_summary[k] = self.summary_logger.scalar(k) self.train_summary[k] = self.summary_logger.scalar(k)
scalar = self.train_summary[k] scalar = self.train_summary[k]
...@@ -226,9 +248,11 @@ class Trainer(object): ...@@ -226,9 +248,11 @@ class Trainer(object):
""" """
self.logger.info("Generation starts ...") self.logger.info("Generation starts ...")
infer_save_file = os.path.join(self.save_dir, f"infer_{self.epoch}.result.json") infer_save_file = os.path.join(self.save_dir, f"infer_{self.epoch}.result.json")
# Inference # Inference
infer_results = [] infer_results = []
batch_cnt = 0 batch_cnt = 0
begin_time = time.time()
for batch, batch_size in tqdm(data_iter, total=num_batches): for batch, batch_size in tqdm(data_iter, total=num_batches):
batch = type(batch)(map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) batch = type(batch)(map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items()))
...@@ -264,7 +288,8 @@ class Trainer(object): ...@@ -264,7 +288,8 @@ class Trainer(object):
infer_metrics_tracker = evaluate_generation_result(infer_results) infer_metrics_tracker = evaluate_generation_result(infer_results)
metrics_message = infer_metrics_tracker.summary() metrics_message = infer_metrics_tracker.summary()
message_prefix = f"[Infer][{self.epoch}]" message_prefix = f"[Infer][{self.epoch}]"
message = " ".join([message_prefix, metrics_message]) time_cost = f"TIME-{time.time() - begin_time:.3f}"
message = " ".join([message_prefix, metrics_message, time_cost])
self.logger.info(message) self.logger.info(message)
return return
...@@ -282,42 +307,56 @@ class Trainer(object): ...@@ -282,42 +307,56 @@ class Trainer(object):
need_save = need_save and parallel.Env().local_rank == 0 need_save = need_save and parallel.Env().local_rank == 0
# Evaluation # Evaluation
metrics_tracker = MetricsTracker() begin_time = time.time()
batch_metrics_tracker = MetricsTracker()
token_metrics_tracker = MetricsTracker()
for batch, batch_size in data_iter: for batch, batch_size in data_iter:
batch = type(batch)(map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items())) batch = type(batch)(map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items()))
metrics = self.model(batch, is_training=False) metrics = self.model(batch, is_training=False)
metrics_tracker.update(metrics, batch_size) token_num = int(metrics.pop("token_num"))
metrics_message = metrics_tracker.summary() batch_metrics = {k: v for k, v in metrics.items() if "token" not in k}
token_metrics = {k: v for k, v in metrics.items() if "token" in k}
batch_metrics_tracker.update(batch_metrics, batch_size)
token_metrics_tracker.update(token_metrics, token_num)
batch_metrics_message = batch_metrics_tracker.summary()
token_metrics_message = token_metrics_tracker.summary()
message_prefix = f"[Valid][{self.epoch}]" message_prefix = f"[Valid][{self.epoch}]"
message = " ".join([message_prefix, metrics_message]) time_cost = f"TIME-{time.time() - begin_time:.3f}"
message = " ".join([message_prefix, batch_metrics_message, token_metrics_message, time_cost])
self.logger.info(message) self.logger.info(message)
# Check valid metric if need_save:
cur_valid_metric = metrics_tracker.get(self.valid_metric_name) # Check valid metric
if self.is_decreased_valid_metric: cur_valid_metric = batch_metrics_tracker.get(self.valid_metric_name)
is_best = cur_valid_metric < self.best_valid_metric if self.is_decreased_valid_metric:
else: is_best = cur_valid_metric < self.best_valid_metric
is_best = cur_valid_metric > self.best_valid_metric else:
if is_best and need_save: is_best = cur_valid_metric > self.best_valid_metric
# Save current best model if is_best:
self.best_valid_metric = cur_valid_metric # Save current best model
best_model_path = os.path.join(self.save_dir, "best.model") self.best_valid_metric = cur_valid_metric
save(self.model, best_model_path) best_model_path = os.path.join(self.save_dir, "best.model")
self.logger.info( save(self.model, best_model_path)
f"Saved best model to '{best_model_path}' with new best valid metric " self.logger.info(
f"{self.valid_metric_name.upper()}-{self.best_valid_metric:.3f}") f"Saved best model to '{best_model_path}' with new best valid metric "
f"{self.valid_metric_name.upper()}-{self.best_valid_metric:.3f}")
# Save checkpoint
if self.save_checkpoint and need_save: # Save checkpoint
model_file = os.path.join(self.save_dir, f"epoch_{self.epoch}.model") if self.save_checkpoint:
save(self.model, model_file) model_file = os.path.join(self.save_dir, f"epoch_{self.epoch}.model")
save(self.model, model_file)
if self.save_summary and need_save:
with self.summary_logger.mode("valid"): if self.save_summary:
for k, v in self.metrics_tracker.items(): with self.summary_logger.mode("valid"):
if k not in self.valid_summary: for k, v in self.batch_metrics_tracker.items():
self.valid_summary[k] = self.summary_logger.scalar(k) if k not in self.valid_summary:
scalar = self.valid_summary[k] self.valid_summary[k] = self.summary_logger.scalar(k)
scalar.add_record(self.batch_num, v) scalar = self.valid_summary[k]
scalar.add_record(self.batch_num, v)
for k, v in self.token_metrics_tracker.items():
if k not in self.valid_summary:
self.valid_summary[k] = self.summary_logger.scalar(k)
scalar = self.valid_summary[k]
scalar.add_record(self.batch_num, v)
return return
...@@ -18,10 +18,10 @@ Preprocess script. ...@@ -18,10 +18,10 @@ Preprocess script.
import os import os
import argparse import argparse
from args import str2bool from plato.args import str2bool
from args import parse_args from plato.args import parse_args
from dataset import Dataset from plato.data.dataset import Dataset
from field import BPETextField from plato.data.field import BPETextField
def main(): def main():
...@@ -35,15 +35,15 @@ def main(): ...@@ -35,15 +35,15 @@ def main():
raw_train_file = os.path.join(args.data_dir, "dial.train") raw_train_file = os.path.join(args.data_dir, "dial.train")
raw_valid_file = os.path.join(args.data_dir, "dial.valid") raw_valid_file = os.path.join(args.data_dir, "dial.valid")
raw_test_file = os.path.join(args.data_dir, "dial.test") raw_test_file = os.path.join(args.data_dir, "dial.test")
train_file = raw_train_file + ".jsonl" train_file = raw_train_file + f".{args.tokenizer_type}.jsonl"
valid_file = raw_valid_file + ".jsonl" valid_file = raw_valid_file + f".{args.tokenizer_type}.jsonl"
test_file = raw_test_file + ".jsonl" test_file = raw_test_file + f".{args.tokenizer_type}.jsonl"
bpe = BPETextField(args.BPETextField) bpe = BPETextField(args.BPETextField)
BUILD_EXAMPLES_FN = { BUILD_EXAMPLES_FN = {
"multi": bpe.build_examples_multi_turn, "multi": bpe.build_examples_multi_turn,
"multi_knowledge": bpe.build_examples_multi_turn_with_knoledge "multi_knowledge": bpe.build_examples_multi_turn_with_knowledge
} }
build_examples_fn = BUILD_EXAMPLES_FN[args.data_type] build_examples_fn = BUILD_EXAMPLES_FN[args.data_type]
......
...@@ -22,16 +22,16 @@ import os ...@@ -22,16 +22,16 @@ import os
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from args import parse_args from plato.args import parse_args
from args import str2bool from plato.args import str2bool
from dataloader import DataLoader from plato.data.data_loader import DataLoader
from dataset import Dataset from plato.data.dataset import Dataset
from dataset import LazyDataset from plato.data.dataset import LazyDataset
from field import BPETextField from plato.data.field import BPETextField
from trainer import Trainer from plato.trainer import Trainer
from models.unified_transformer import UnifiedTransformer from plato.models.model_base import ModelBase
from models.generator import BeamSearch from plato.models.generator import Generator
import modules.parallel as parallel import plato.modules.parallel as parallel
def main(): def main():
...@@ -39,21 +39,28 @@ def main(): ...@@ -39,21 +39,28 @@ def main():
parser.add_argument("--do_train", type=str2bool, default=False, parser.add_argument("--do_train", type=str2bool, default=False,
help="Whether to run trainning.") help="Whether to run trainning.")
parser.add_argument("--do_valid", type=str2bool, default=False, parser.add_argument("--do_test", type=str2bool, default=False,
help="Whether to run evaluation on the test dataset.") help="Whether to run evaluation on the test dataset.")
parser.add_argument("--do_infer", type=str2bool, default=False, parser.add_argument("--do_infer", type=str2bool, default=False,
help="Whether to run inference on the test dataset.") help="Whether to run inference on the test dataset.")
parser.add_argument("--num_infer_batches", type=int, default=None, parser.add_argument("--num_infer_batches", type=int, default=None,
help="The number of batches need to infer.\n" help="The number of batches need to infer.\n"
"Stay 'None': infer on entrie test dataset.") "Stay 'None': infer on entrie test dataset.")
parser.add_argument("--hparams_file", type=str, default=None,
help="Loading hparams setting from file(.json format).")
BPETextField.add_cmdline_argument(parser) BPETextField.add_cmdline_argument(parser)
Dataset.add_cmdline_argument(parser) Dataset.add_cmdline_argument(parser)
Trainer.add_cmdline_argument(parser) Trainer.add_cmdline_argument(parser)
UnifiedTransformer.add_cmdline_argument(parser) ModelBase.add_cmdline_argument(parser)
BeamSearch.add_cmdline_argument(parser) Generator.add_cmdline_argument(parser)
hparams = parse_args(parser) hparams = parse_args(parser)
if hparams.hparams_file and os.path.exists(hparams.hparams_file):
print(f"Loading hparams from {hparams.hparams_file} ...")
hparams.load(hparams.hparams_file)
print(f"Loaded hparams from {hparams.hparams_file}")
print(json.dumps(hparams, indent=2)) print(json.dumps(hparams, indent=2))
if not os.path.exists(hparams.save_dir): if not os.path.exists(hparams.save_dir):
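Since the parsed hparams are already rendered with json.dumps above, one way to produce an `--hparams_file` for later runs is simply to write that dump to disk. A minimal, hypothetical sketch (the exact key nesting follows whatever parse_args builds, so no concrete contents are shown):

```python
# Hypothetical helper for producing an --hparams_file; assumes hparams is
# JSON-serializable, as implied by the json.dumps(hparams, ...) call above.
import json

def dump_hparams(hparams, path="hparams.json"):
    with open(path, "w") as f:
        f.write(json.dumps(hparams, indent=2))

# Later runs could then pass: python ./run.py --hparams_file hparams.json ...
```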
...@@ -63,7 +70,7 @@ def main(): ...@@ -63,7 +70,7 @@ def main():
bpe = BPETextField(hparams.BPETextField) bpe = BPETextField(hparams.BPETextField)
hparams.Model.num_token_embeddings = bpe.vocab_size hparams.Model.num_token_embeddings = bpe.vocab_size
generator = BeamSearch(bpe, hparams.Generator) generator = Generator.create(hparams.Generator, bpe=bpe)
COLLATE_FN = { COLLATE_FN = {
"multi": bpe.collate_fn_multi_turn, "multi": bpe.collate_fn_multi_turn,
...@@ -74,22 +81,22 @@ def main(): ...@@ -74,22 +81,22 @@ def main():
# Loading datasets # Loading datasets
if hparams.do_train: if hparams.do_train:
raw_train_file = os.path.join(hparams.data_dir, "dial.train") raw_train_file = os.path.join(hparams.data_dir, "dial.train")
train_file = raw_train_file + ".jsonl" train_file = raw_train_file + f".{hparams.tokenizer_type}.jsonl"
        assert os.path.exists(train_file), f"{train_file} doesn't exist" assert os.path.exists(train_file), f"{train_file} doesn't exist"
train_dataset = LazyDataset(train_file) train_dataset = LazyDataset(train_file)
train_loader = DataLoader(train_dataset, hparams.Trainer, collate_fn=collate_fn) train_loader = DataLoader(train_dataset, hparams.Trainer, collate_fn=collate_fn, is_train=True)
raw_valid_file = os.path.join(hparams.data_dir, "dial.valid") raw_valid_file = os.path.join(hparams.data_dir, "dial.valid")
valid_file = raw_valid_file + ".jsonl" valid_file = raw_valid_file + f".{hparams.tokenizer_type}.jsonl"
        assert os.path.exists(valid_file), f"{valid_file} doesn't exist" assert os.path.exists(valid_file), f"{valid_file} doesn't exist"
valid_dataset = LazyDataset(valid_file) valid_dataset = LazyDataset(valid_file)
valid_loader = DataLoader(valid_dataset, hparams.Trainer, collate_fn=collate_fn) valid_loader = DataLoader(valid_dataset, hparams.Trainer, collate_fn=collate_fn)
if hparams.do_infer or hparams.do_valid: if hparams.do_infer or hparams.do_test:
raw_test_file = os.path.join(hparams.data_dir, "dial.test") raw_test_file = os.path.join(hparams.data_dir, "dial.test")
test_file = raw_test_file + ".jsonl" test_file = raw_test_file + f".{hparams.tokenizer_type}.jsonl"
        assert os.path.exists(test_file), f"{test_file} doesn't exist" assert os.path.exists(test_file), f"{test_file} doesn't exist"
test_dataset = LazyDataset(test_file) test_dataset = LazyDataset(test_file)
test_loader = DataLoader(test_dataset, hparams.Trainer, collate_fn=collate_fn, is_test=True) test_loader = DataLoader(test_dataset, hparams.Trainer, collate_fn=collate_fn, is_test=hparams.do_infer)
def to_tensor(array): def to_tensor(array):
array = np.expand_dims(array, -1) array = np.expand_dims(array, -1)
...@@ -102,7 +109,7 @@ def main(): ...@@ -102,7 +109,7 @@ def main():
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
# Construct Model # Construct Model
model = UnifiedTransformer("Model", generator, hparams) model = ModelBase.create("Model", hparams, generator=generator)
# Construct Trainer # Construct Trainer
trainer = Trainer(model, to_tensor, hparams.Trainer) trainer = Trainer(model, to_tensor, hparams.Trainer)
...@@ -112,7 +119,7 @@ def main(): ...@@ -112,7 +119,7 @@ def main():
for epoch in range(hparams.num_epochs): for epoch in range(hparams.num_epochs):
trainer.train_epoch(train_loader, valid_loader) trainer.train_epoch(train_loader, valid_loader)
if hparams.do_valid: if hparams.do_test:
# Validation process # Validation process
trainer.evaluate(test_loader, need_save=False) trainer.evaluate(test_loader, need_save=False)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
set -ux set -ux
SAVE_DIR=outputs/DSTC7_AVSD.infer SAVE_DIR=outputs/DSTC7_AVSD.infer
VOCAB_PATH=data/vocab.txt VOCAB_PATH=model/Bert/vocab.txt
DATA_DIR=data/DSTC7_AVSD DATA_DIR=data/DSTC7_AVSD
INIT_CHECKPOINT=outputs/DSTC7_AVSD/best.model INIT_CHECKPOINT=outputs/DSTC7_AVSD/best.model
DATA_TYPE=multi_knowledge DATA_TYPE=multi_knowledge
...@@ -15,13 +15,11 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.1 ...@@ -15,13 +15,11 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.1
export FLAGS_eager_delete_scope=True export FLAGS_eager_delete_scope=True
export FLAGS_eager_delete_tensor_gb=0.0 export FLAGS_eager_delete_tensor_gb=0.0
if [[ ! -e $DATA_DIR/dial.test.jsonl ]]; then python -u \
python -u \ ./preprocess.py \
./preprocess.py \ --vocab_path $VOCAB_PATH \
--vocab_path $VOCAB_PATH \ --data_dir $DATA_DIR \
--data_dir $DATA_DIR \ --data_type $DATA_TYPE
--data_type $DATA_TYPE
fi
python -u \ python -u \
./run.py \ ./run.py \
...@@ -29,7 +27,7 @@ python -u \ ...@@ -29,7 +27,7 @@ python -u \
--vocab_path $VOCAB_PATH \ --vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \ --data_dir $DATA_DIR \
--data_type $DATA_TYPE \ --data_type $DATA_TYPE \
--batch_size 2 \ --batch_size 4 \
--num_type_embeddings 3 \ --num_type_embeddings 3 \
--use_discriminator true \ --use_discriminator true \
--init_checkpoint $INIT_CHECKPOINT \ --init_checkpoint $INIT_CHECKPOINT \
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
set -ux set -ux
SAVE_DIR=outputs/DSTC7_AVSD SAVE_DIR=outputs/DSTC7_AVSD
VOCAB_PATH=data/vocab.txt VOCAB_PATH=model/Bert/vocab.txt
DATA_DIR=data/DSTC7_AVSD DATA_DIR=data/DSTC7_AVSD
INIT_CHECKPOINT=model/PLATO INIT_CHECKPOINT=model/PLATO
DATA_TYPE=multi_knowledge DATA_TYPE=multi_knowledge
...@@ -33,7 +33,7 @@ python -u \ ...@@ -33,7 +33,7 @@ python -u \
--vocab_path $VOCAB_PATH \ --vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \ --data_dir $DATA_DIR \
--data_type $DATA_TYPE \ --data_type $DATA_TYPE \
--batch_size 8 \ --batch_size 4 \
--valid_steps 2000 \ --valid_steps 2000 \
--num_type_embeddings 3 \ --num_type_embeddings 3 \
--use_discriminator true \ --use_discriminator true \
......
#!/bin/bash
set -ux
SAVE_DIR=outputs/DailyDialog.baseline.infer
VOCAB_PATH=model/Bert/vocab.txt
DATA_DIR=data/DailyDialog
INIT_CHECKPOINT=outputs/DailyDialog.baseline/best.model
DATA_TYPE=multi
# CUDA environment settings.
export CUDA_VISIBLE_DEVICES=0
# Paddle environment settings.
export FLAGS_fraction_of_gpu_memory_to_use=0.1
export FLAGS_eager_delete_scope=True
export FLAGS_eager_delete_tensor_gb=0.0
python -u \
./preprocess.py \
--vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \
--data_type $DATA_TYPE
python -u \
./run.py \
--do_infer true \
--vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \
--data_type $DATA_TYPE \
--batch_size 48 \
--num_latent 0 \
--num_type_embeddings 2 \
--init_checkpoint $INIT_CHECKPOINT \
--length_average true \
--save_dir $SAVE_DIR
#!/bin/bash
set -ux
SAVE_DIR=outputs/DailyDialog.baseline
VOCAB_PATH=model-baseline/Bert/vocab.txt
DATA_DIR=data/DailyDialog
INIT_CHECKPOINT=model-baseline/PLATO.baseline
DATA_TYPE=multi
USE_VISUALDL=false
# CUDA environment settings.
export CUDA_VISIBLE_DEVICES=2
# Paddle environment settings.
export FLAGS_fraction_of_gpu_memory_to_use=0.1
export FLAGS_eager_delete_scope=True
export FLAGS_eager_delete_tensor_gb=0.0
python -u \
./preprocess.py \
--vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \
--data_type $DATA_TYPE
if [[ "$USE_VISUALDL" = true ]]; then
visualdl --logdir=$SAVE_DIR/summary --port=8083 --host=`hostname` &
VISUALDL_PID=$!
fi
python -u \
./run.py \
--do_train true \
--vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \
--data_type $DATA_TYPE \
--batch_size 2 \
--valid_steps 2000 \
--num_type_embeddings 2 \
--num_latent 0 \
--num_epoch 20 \
--lr 1e-5 \
--save_checkpoint false \
--save_summary $USE_VISUALDL \
--init_checkpoint $INIT_CHECKPOINT \
--save_dir $SAVE_DIR
if [[ $USE_VISUALDL = true ]]; then
kill $VISUALDL_PID
fi
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
set -ux set -ux
SAVE_DIR=outputs/DailyDialog.infer SAVE_DIR=outputs/DailyDialog.infer
VOCAB_PATH=data/vocab.txt VOCAB_PATH=model/Bert/vocab.txt
DATA_DIR=data/DailyDialog DATA_DIR=data/DailyDialog
INIT_CHECKPOINT=outputs/DailyDialog/best.model INIT_CHECKPOINT=outputs/DailyDialog/best.model
DATA_TYPE=multi DATA_TYPE=multi
...@@ -15,13 +15,11 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.1 ...@@ -15,13 +15,11 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.1
export FLAGS_eager_delete_scope=True export FLAGS_eager_delete_scope=True
export FLAGS_eager_delete_tensor_gb=0.0 export FLAGS_eager_delete_tensor_gb=0.0
if [[ ! -e $DATA_DIR/dial.test.jsonl ]]; then python -u \
python -u \ ./preprocess.py \
./preprocess.py \ --vocab_path $VOCAB_PATH \
--vocab_path $VOCAB_PATH \ --data_dir $DATA_DIR \
--data_dir $DATA_DIR \ --data_type $DATA_TYPE
--data_type $DATA_TYPE
fi
python -u \ python -u \
./run.py \ ./run.py \
...@@ -29,8 +27,9 @@ python -u \ ...@@ -29,8 +27,9 @@ python -u \
--vocab_path $VOCAB_PATH \ --vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \ --data_dir $DATA_DIR \
--data_type $DATA_TYPE \ --data_type $DATA_TYPE \
--batch_size 2 \ --batch_size 4 \
--num_type_embeddings 2 \ --num_type_embeddings 2 \
--num_latent 20 \
--use_discriminator true \ --use_discriminator true \
--init_checkpoint $INIT_CHECKPOINT \ --init_checkpoint $INIT_CHECKPOINT \
--save_dir $SAVE_DIR --save_dir $SAVE_DIR
#!/bin/bash
set -ux
SAVE_DIR=outputs/DailyDialog
VOCAB_PATH=model/Bert/vocab.txt
DATA_DIR=data/DailyDialog
INIT_CHECKPOINT=model/PLATO
DATA_TYPE=multi
USE_VISUALDL=false
# CUDA environment settings.
export CUDA_VISIBLE_DEVICES=0,1
# Paddle environment settings.
export FLAGS_fraction_of_gpu_memory_to_use=0.1
export FLAGS_eager_delete_scope=True
export FLAGS_eager_delete_tensor_gb=0.0
if [[ ! -e $DATA_DIR/dial.train.jsonl ]]; then
python -u \
./preprocess.py \
--vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \
--data_type $DATA_TYPE
fi
if [[ "$USE_VISUALDL" = true ]]; then
visualdl --logdir=$SAVE_DIR/summary --port=8083 --host=`hostname` &
VISUALDL_PID=$!
fi
python -m \
paddle.distributed.launch \
--log_dir $SAVE_DIR \
--started_port 8888 \
./run.py \
--use_data_distributed true \
--do_train true \
--vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \
--data_type $DATA_TYPE \
--batch_size 6 \
--valid_steps 2000 \
--num_type_embeddings 2 \
--use_discriminator true \
--num_epoch 20 \
--lr 1e-5 \
--save_checkpoint false \
--save_summary $USE_VISUALDL \
--init_checkpoint $INIT_CHECKPOINT \
--save_dir $SAVE_DIR
if [[ $USE_VISUALDL = true ]]; then
kill $VISUALDL_PID
fi
#!/bin/bash
set -ux
SAVE_DIR=outputs/DailyDialog.infer
VOCAB_PATH=model/Bert/vocab.txt
DATA_DIR=data/DailyDialog
INIT_CHECKPOINT=outputs/DailyDialog/best.model
DATA_TYPE=multi
# CUDA environment settings.
export CUDA_VISIBLE_DEVICES=0
# Paddle environment settings.
export FLAGS_fraction_of_gpu_memory_to_use=0.1
export FLAGS_eager_delete_scope=True
export FLAGS_eager_delete_tensor_gb=0.0
if [[ ! -e $DATA_DIR/dial.test.jsonl ]]; then
python -u \
./preprocess.py \
--vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \
--data_type $DATA_TYPE
fi
python -u \
./run.py \
--do_infer true \
--generator TopKSampling \
--top_k_num 10 \
--sampling_temperate 0.8 \
--vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \
--data_type $DATA_TYPE \
--batch_size 16 \
--num_type_embeddings 2 \
--use_discriminator true \
--init_checkpoint $INIT_CHECKPOINT \
--save_dir $SAVE_DIR
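The script above switches decoding to the TopKSampling generator with `--top_k_num 10` and `--sampling_temperate 0.8` (presumably a sampling temperature). As a rough illustration of the technique itself, not of the code in plato/models/generator.py, a single top-k sampling step with temperature might look like:

```python
# Illustrative only: one top-k sampling step with temperature.
import numpy as np

def top_k_sample(logits, k=10, temperature=0.8, rng=None):
    """Sample one token id from the k highest-scoring logits."""
    rng = rng or np.random.default_rng()
    logits = np.asarray(logits, dtype=np.float64) / temperature
    top_k_ids = np.argsort(logits)[-k:]       # indices of the k largest logits
    top_k_logits = logits[top_k_ids]
    probs = np.exp(top_k_logits - top_k_logits.max())
    probs /= probs.sum()                      # softmax restricted to the top k
    return int(rng.choice(top_k_ids, p=probs))

# Tiny example: 5-token vocabulary, sample among the 3 most likely tokens.
print(top_k_sample([0.1, 2.0, -1.0, 3.5, 0.7], k=3))
```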
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
set -ux set -ux
SAVE_DIR=outputs/DailyDialog SAVE_DIR=outputs/DailyDialog
VOCAB_PATH=data/vocab.txt VOCAB_PATH=model/Bert/vocab.txt
DATA_DIR=data/DailyDialog DATA_DIR=data/DailyDialog
INIT_CHECKPOINT=model/PLATO INIT_CHECKPOINT=model/PLATO
DATA_TYPE=multi DATA_TYPE=multi
...@@ -16,13 +16,11 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.1 ...@@ -16,13 +16,11 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.1
export FLAGS_eager_delete_scope=True export FLAGS_eager_delete_scope=True
export FLAGS_eager_delete_tensor_gb=0.0 export FLAGS_eager_delete_tensor_gb=0.0
if [[ ! -e $DATA_DIR/dial.train.jsonl ]]; then python -u \
python -u \ ./preprocess.py \
./preprocess.py \ --vocab_path $VOCAB_PATH \
--vocab_path $VOCAB_PATH \ --data_dir $DATA_DIR \
--data_dir $DATA_DIR \ --data_type $DATA_TYPE
--data_type $DATA_TYPE
fi
if [[ "$USE_VISUALDL" = true ]]; then if [[ "$USE_VISUALDL" = true ]]; then
visualdl --logdir=$SAVE_DIR/summary --port=8083 --host=`hostname` & visualdl --logdir=$SAVE_DIR/summary --port=8083 --host=`hostname` &
...@@ -35,7 +33,7 @@ python -u \ ...@@ -35,7 +33,7 @@ python -u \
--vocab_path $VOCAB_PATH \ --vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \ --data_dir $DATA_DIR \
--data_type $DATA_TYPE \ --data_type $DATA_TYPE \
--batch_size 12 \ --batch_size 6 \
--valid_steps 2000 \ --valid_steps 2000 \
--num_type_embeddings 2 \ --num_type_embeddings 2 \
--use_discriminator true \ --use_discriminator true \
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
set -ux set -ux
SAVE_DIR=outputs/PersonaChat.infer SAVE_DIR=outputs/PersonaChat.infer
VOCAB_PATH=data/vocab.txt VOCAB_PATH=model/Bert/vocab.txt
DATA_DIR=data/PersonaChat DATA_DIR=data/PersonaChat
INIT_CHECKPOINT=outputs/PersonaChat/best.model INIT_CHECKPOINT=outputs/PersonaChat/best.model
DATA_TYPE=multi_knowledge DATA_TYPE=multi_knowledge
...@@ -15,13 +15,11 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.1 ...@@ -15,13 +15,11 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.1
export FLAGS_eager_delete_scope=True export FLAGS_eager_delete_scope=True
export FLAGS_eager_delete_tensor_gb=0.0 export FLAGS_eager_delete_tensor_gb=0.0
if [[ ! -e $DATA_DIR/dial.test.jsonl ]]; then python -u \
python -u \ ./preprocess.py \
./preprocess.py \ --vocab_path $VOCAB_PATH \
--vocab_path $VOCAB_PATH \ --data_dir $DATA_DIR \
--data_dir $DATA_DIR \ --data_type $DATA_TYPE
--data_type $DATA_TYPE
fi
python -u \ python -u \
./run.py \ ./run.py \
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
set -ux set -ux
SAVE_DIR=outputs/PersonaChat SAVE_DIR=outputs/PersonaChat
VOCAB_PATH=data/vocab.txt VOCAB_PATH=model/Bert/vocab.txt
DATA_DIR=data/PersonaChat DATA_DIR=data/PersonaChat
INIT_CHECKPOINT=model/PLATO INIT_CHECKPOINT=model/PLATO
DATA_TYPE=multi_knowledge DATA_TYPE=multi_knowledge
...@@ -33,7 +33,7 @@ python -u \ ...@@ -33,7 +33,7 @@ python -u \
--vocab_path $VOCAB_PATH \ --vocab_path $VOCAB_PATH \
--data_dir $DATA_DIR \ --data_dir $DATA_DIR \
--data_type $DATA_TYPE \ --data_type $DATA_TYPE \
--batch_size 12 \ --batch_size 4 \
--valid_steps 2000 \ --valid_steps 2000 \
--num_type_embeddings 3 \ --num_type_embeddings 3 \
--use_discriminator true \ --use_discriminator true \
......