From d7009805f1641322015ecdd115daf91974477a19 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Fri, 18 Dec 2020 19:38:34 +0800 Subject: [PATCH] Add transformer-xl for language model (#4987) * add transformer-xl for language model --- .../language_model/transformer-xl/README.md | 89 ++ .../transformer-xl/configs/enwik8.yaml | 112 ++ .../transformer-xl/configs/text8.yaml | 112 ++ .../transformer-xl/configs/wt103.yaml | 112 ++ .../language_model/transformer-xl/eval.py | 134 ++ .../language_model/transformer-xl/gen_data.sh | 55 + .../transformer-xl/mem_transformer.py | 1181 +++++++++++++++++ .../language_model/transformer-xl/reader.py | 197 +++ .../language_model/transformer-xl/train.py | 308 +++++ .../transformer-xl/utils/preprocess_text8.py | 21 + 10 files changed, 2321 insertions(+) create mode 100644 PaddleNLP/examples/language_model/transformer-xl/README.md create mode 100644 PaddleNLP/examples/language_model/transformer-xl/configs/enwik8.yaml create mode 100644 PaddleNLP/examples/language_model/transformer-xl/configs/text8.yaml create mode 100644 PaddleNLP/examples/language_model/transformer-xl/configs/wt103.yaml create mode 100644 PaddleNLP/examples/language_model/transformer-xl/eval.py create mode 100644 PaddleNLP/examples/language_model/transformer-xl/gen_data.sh create mode 100644 PaddleNLP/examples/language_model/transformer-xl/mem_transformer.py create mode 100644 PaddleNLP/examples/language_model/transformer-xl/reader.py create mode 100644 PaddleNLP/examples/language_model/transformer-xl/train.py create mode 100644 PaddleNLP/examples/language_model/transformer-xl/utils/preprocess_text8.py diff --git a/PaddleNLP/examples/language_model/transformer-xl/README.md b/PaddleNLP/examples/language_model/transformer-xl/README.md new file mode 100644 index 00000000..81f326a7 --- /dev/null +++ b/PaddleNLP/examples/language_model/transformer-xl/README.md @@ -0,0 +1,89 @@ +# Language Model + +## Transformer-XL + +以下是本例的简要目录结构及说明: + +```text +. +├── eval.py # 预测脚本 +├── reader.py # 数据读取接口 +├── README.md # 文档 +├── train.py # 训练脚本 +└── configs # 配置文件 +``` + +## 模型简介 + +本项目是语言模型 Transformer-XL 的 PaddlePaddle 实现, 包含模型训练,预测等内容。 + + +## 快速开始 + +### 安装说明 + +1. paddle安装 + + 本项目依赖于 PaddlePaddle 2.0rc及以上版本或适当的develop版本,请参考 [安装指南](https://www.paddlepaddle.org.cn/install/quick) 进行安装 + +2. 下载代码 + + 克隆代码库到本地 + +3. 
环境依赖 + + 该模型使用PaddlePaddle,关于环境依赖部分,请先参考PaddlePaddle[安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/index_cn.html)关于环境依赖部分的内容。 + 此外,需要另外涉及: + * attrdict + * pyyaml + + + +### 数据准备 + +公开数据集:enwik8、text8、wt103 多用于语言模型的 benchmark 测试。输出获取与处理方式如下: + +```shell +bash gen_data.sh +``` + +会在当前路径下的 ./gen_data/ 路径下生成我们需要的数据。 + + +### 单机训练 + +### 单机单卡 + +以提供的 enwik8 数据为例,可以执行以下命令进行模型训练: + +```sh +# setting visible devices for training +export CUDA_VISIBLE_DEVICES=0 +python train.py --config ./configs/enwik8.yaml +``` + +可以在 enwik8.yaml 文件中设置相应的参数,比如 `batch_size`、`epoch` 等。 + +### 单机多卡 + +同样,可以执行如下命令实现八卡训练: + +```sh +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py --config ./configs/enwik8.yaml +``` + + +### 模型推断 + +以 enwik8 数据为例,模型训练完成后可以执行以下命令可以进行预测: + +```sh +# setting visible devices for prediction +export CUDA_VISIBLE_DEVICES=0 +python eval.py --config ./configs/enwik8.yaml +``` + +完成推断之后,会将显示在验证集和测试集上的结果。 + +## 参考文献 diff --git a/PaddleNLP/examples/language_model/transformer-xl/configs/enwik8.yaml b/PaddleNLP/examples/language_model/transformer-xl/configs/enwik8.yaml new file mode 100644 index 00000000..0f672326 --- /dev/null +++ b/PaddleNLP/examples/language_model/transformer-xl/configs/enwik8.yaml @@ -0,0 +1,112 @@ +# The frequency to save trained models when training. +save_step: 10000 +# The frequency to fetch and print output when training. +print_step: 100 +# Path of the checkpoint, to resume the previous training +init_from_checkpoint: "" +# Path of the pretrain model, to better solve the current task +init_from_pretrain_model: "" +# Path of trained parameter, to make prediction +init_from_params: "./trained_models/step_final/" +# The directory for saving model +save_model: "trained_models" +# The directory for saving inference model. +inference_model_dir: "infer_model" +# Set seed for CE or debug +random_seed: None +# The path to data files +data: "./gen_data/enwik8/" +# The name of dataset +dataset: "enwik8" + +# Whether to use cuda +use_gpu: True + +# Args for reader, see reader.py for details +token_delimiter: None +batch_size: 16 +eval_batch_size: 2 + +# Hyparams for training: +# The number of epoches for training +epoch: 30 + +# The hyper parameters for optimizer. +# Type of ptimizer. +optim: adam +# Learning rate schedule. +scheduler: cosine +# This static learning_rate will be applied to the LearningRateScheduler +# derived learning rate the to get the final learning rate. +learning_rate: 0.00025 +# The hyper parameters for Adam optimizer. +beta1: 0.9 +beta2: 0.997 +eps: 1e-9 +# The hyper parameters for Momentum optimizer. +mom: 0.0 +# Global gradient clip. +clip: 0.25 +# The parameters for learning rate scheduling. +warmup_steps: 0 +# The parameters for CosineAnnealingDecay. Minimum learning rate. +eta_min: 0.0 +# The parameters for ReduceLROnPlateau. +# The Ratio that the learning rate will be reduced. +decay_rate: 0.5 +# When loss doesn’t improve for this number of epochs, learing rate will be reduced. +patience: 0 +# The lower bound of the learning rate after reduction. +min_lr: 0.0 + +# Hyparams for model: +# Whe use adaptive softmax. +adaptive: False +# Size of dictionary. This can be obtained automatically. +ntokens: 10000 +# The dimension for word embeddings, which is also the last dimension of +# the input and output of multi-head attention, position-wise feed-forward +# networks, encoder and decoder. +d_model: 512 +# Dimension of heads. 
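+# (Note: in this config n_head * d_head, i.e. 8 * 64, equals d_model = 512, the
+# usual Transformer-XL setting; the projection layers do not strictly require it.)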
+d_head: 64 +# Size of the hidden layer in position-wise feed-forward networks. +d_inner_hid: 2048 +# Number of head used in multi-head attention. +n_head: 8 +# Number of sub-layers to be stacked in the encoder and decoder. +n_layer: 12 +# Dropout rates. +dropout: 0.1 +# Attention dropout +attn_dropout: 0.0 +# Attention type for decoder. +# 0 for relative partial MHA (in Transformer-XL). +# 1 for relative MHA (in Shaw et al). +attn_type: 0 +# Apply layer normalization before or after sublayers. +normalize_before: False +# Whether to tie weight or not. +tie_weight: True +# The length of the extended context. +ext_len: 0 +# The divident value for softmax and adapative input. +div_val: 1 +# Target length. The number of tokens to predict. +tgt_len: 512 +# Memory length. The length of the retained previous heads. +mem_len: 512 +# Use the same attention length for all tokens. +same_length: False +# Use the same positional encoding after clamp len. +clamp_len: -1 +# The number of samples in sample softmax. -1 means do not use sampled softmax. +sample_softmax: -1 +# Max step for training. +max_step: 400000 +# Target length for evaluation. That is, the number of tokens to predict for evaluation. +eval_tgt_len: 128 +# What kind of mode for evaluation. valid, test or both("all"). +mode: "all" +# Maximum evaluation step. +max_eval_steps: -1 diff --git a/PaddleNLP/examples/language_model/transformer-xl/configs/text8.yaml b/PaddleNLP/examples/language_model/transformer-xl/configs/text8.yaml new file mode 100644 index 00000000..eb84ff3b --- /dev/null +++ b/PaddleNLP/examples/language_model/transformer-xl/configs/text8.yaml @@ -0,0 +1,112 @@ +# The frequency to save trained models when training. +save_step: 10000 +# The frequency to fetch and print output when training. +print_step: 100 +# Path of the checkpoint, to resume the previous training +init_from_checkpoint: "" +# Path of the pretrain model, to better solve the current task +init_from_pretrain_model: "" +# Path of trained parameter, to make prediction +init_from_params: "./trained_models/step_final/" +# The directory for saving model +save_model: "trained_models" +# The directory for saving inference model. +inference_model_dir: "infer_model" +# Set seed for CE or debug +random_seed: None +# The path to data files +data: "./gen_data/text8/" +# The name of dataset +dataset: "text8" + +# Whether to use cuda +use_gpu: True + +# Args for reader, see reader.py for details +token_delimiter: None +batch_size: 15 +eval_batch_size: 5 + +# Hyparams for training: +# The number of epoches for training +epoch: 30 + +# The hyper parameters for optimizer. +# Type of ptimizer. +optim: adam +# Learning rate schedule. +scheduler: cosine +# This static learning_rate will be applied to the LearningRateScheduler +# derived learning rate the to get the final learning rate. +learning_rate: 0.00025 +# The hyper parameters for Adam optimizer. +beta1: 0.9 +beta2: 0.997 +eps: 1e-9 +# The hyper parameters for Momentum optimizer. +mom: 0.0 +# Global gradient clip. +clip: 0.25 +# The parameters for learning rate scheduling. +warmup_steps: 0 +# The parameters for CosineAnnealingDecay. Minimum learning rate. +eta_min: 0.0 +# The parameters for ReduceLROnPlateau. +# The Ratio that the learning rate will be reduced. +decay_rate: 0.5 +# When loss doesn’t improve for this number of epochs, learing rate will be reduced. +patience: 0 +# The lower bound of the learning rate after reduction. +min_lr: 0.0 + +# Hyparams for model: +# Whe use adaptive softmax. 
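+# (When True, vocabulary cutoffs are configured for the adaptive softmax and the
+# adaptive input embedding; train.py/eval.py only allow this for wt103 and lm1b.)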
+adaptive: False +# Size of dictionary. This can be obtained automatically. +ntokens: 10000 +# The dimension for word embeddings, which is also the last dimension of +# the input and output of multi-head attention, position-wise feed-forward +# networks, encoder and decoder. +d_model: 512 +# Dimension of heads. +d_head: 64 +# Size of the hidden layer in position-wise feed-forward networks. +d_inner_hid: 2048 +# Number of head used in multi-head attention. +n_head: 8 +# Number of sub-layers to be stacked in the encoder and decoder. +n_layer: 12 +# Dropout rates. +dropout: 0.1 +# Attention dropout +attn_dropout: 0.0 +# Attention type for decoder. +# 0 for relative partial MHA (in Transformer-XL). +# 1 for relative MHA (in Shaw et al). +attn_type: 0 +# Apply layer normalization before or after sublayers. +normalize_before: False +# Whether to tie weight or not. +tie_weight: True +# The length of the extended context. +ext_len: 0 +# The divident value for softmax and adapative input. +div_val: 1 +# Target length. The number of tokens to predict. +tgt_len: 512 +# Memory length. The length of the retained previous heads. +mem_len: 512 +# Use the same attention length for all tokens. +same_length: False +# Use the same positional encoding after clamp len. +clamp_len: -1 +# The number of samples in sample softmax. -1 means do not use sampled softmax. +sample_softmax: -1 +# Max step for training. +max_step: 400000 +# Target length for evaluation. That is, the number of tokens to predict for evaluation. +eval_tgt_len: 128 +# What kind of mode for evaluation. valid, test or both("all"). +mode: "all" +# Maximum evaluation step. +max_eval_steps: -1 diff --git a/PaddleNLP/examples/language_model/transformer-xl/configs/wt103.yaml b/PaddleNLP/examples/language_model/transformer-xl/configs/wt103.yaml new file mode 100644 index 00000000..f8cce122 --- /dev/null +++ b/PaddleNLP/examples/language_model/transformer-xl/configs/wt103.yaml @@ -0,0 +1,112 @@ +# The frequency to save trained models when training. +save_step: 10000 +# The frequency to fetch and print output when training. +print_step: 100 +# Path of the checkpoint, to resume the previous training +init_from_checkpoint: "" +# Path of the pretrain model, to better solve the current task +init_from_pretrain_model: "" +# Path of trained parameter, to make prediction +init_from_params: "./trained_models/step_final/" +# The directory for saving model +save_model: "trained_models" +# The directory for saving inference model. +inference_model_dir: "infer_model" +# Set seed for CE or debug +random_seed: None +# The path to data files +data: "./gen_data/wikitext-103/" +# The name of dataset +dataset: "wt103" + +# Whether to use cuda +use_gpu: True + +# Args for reader, see reader.py for details +token_delimiter: None +batch_size: 32 +eval_batch_size: 5 + +# Hyparams for training: +# The number of epoches for training +epoch: 30 + +# The hyper parameters for optimizer. +# Type of ptimizer. +optim: adam +# Learning rate schedule. +scheduler: cosine +# This static learning_rate will be applied to the LearningRateScheduler +# derived learning rate the to get the final learning rate. +learning_rate: 0.00025 +# The hyper parameters for Adam optimizer. +beta1: 0.9 +beta2: 0.997 +eps: 1e-9 +# The hyper parameters for Momentum optimizer. +mom: 0.0 +# Global gradient clip. +clip: 0.25 +# The parameters for learning rate scheduling. +warmup_steps: 0 +# The parameters for CosineAnnealingDecay. Minimum learning rate. +eta_min: 0.0 +# The parameters for ReduceLROnPlateau. 
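+# (decay_rate, patience and min_lr below only take effect when `scheduler` is set
+# to dev_perf.)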
+# The Ratio that the learning rate will be reduced. +decay_rate: 0.5 +# When loss doesn’t improve for this number of epochs, learing rate will be reduced. +patience: 0 +# The lower bound of the learning rate after reduction. +min_lr: 0.0 + +# Hyparams for model: +# Whe use adaptive softmax. +adaptive: True +# Size of dictionary. This can be obtained automatically. +ntokens: 10000 +# The dimension for word embeddings, which is also the last dimension of +# the input and output of multi-head attention, position-wise feed-forward +# networks, encoder and decoder. +d_model: 410 +# Dimension of heads. +d_head: 41 +# Size of the hidden layer in position-wise feed-forward networks. +d_inner_hid: 2100 +# Number of head used in multi-head attention. +n_head: 10 +# Number of sub-layers to be stacked in the encoder and decoder. +n_layer: 16 +# Dropout rates. +dropout: 0.1 +# Attention dropout +attn_dropout: 0.0 +# Attention type for decoder. +# 0 for relative partial MHA (in Transformer-XL). +# 1 for relative MHA (in Shaw et al). +attn_type: 0 +# Apply layer normalization before or after sublayers. +normalize_before: False +# Whether to tie weight or not. +tie_weight: True +# The length of the extended context. +ext_len: 0 +# The divident value for softmax and adapative input. +div_val: 1 +# Target length. The number of tokens to predict. +tgt_len: 150 +# Memory length. The length of the retained previous heads. +mem_len: 150 +# Target length for evaluation. That is, the number of tokens to predict for evaluation. +eval_tgt_len: 150 +# Use the same attention length for all tokens. +same_length: False +# Use the same positional encoding after clamp len. +clamp_len: -1 +# The number of samples in sample softmax. -1 means do not use sampled softmax. +sample_softmax: -1 +# Max step for training. +max_step: 200000 +# What kind of mode for evaluation. valid, test or both("all"). +mode: "all" +# Maximum evaluation step. +max_eval_steps: -1 diff --git a/PaddleNLP/examples/language_model/transformer-xl/eval.py b/PaddleNLP/examples/language_model/transformer-xl/eval.py new file mode 100644 index 00000000..a6c49cf8 --- /dev/null +++ b/PaddleNLP/examples/language_model/transformer-xl/eval.py @@ -0,0 +1,134 @@ +import os +import time +import yaml +import logging +import argparse +import numpy as np +from pprint import pprint +from attrdict import AttrDict + +import paddle + +from reader import get_lm_vocab, get_lm_data_loader +from mem_transformer import MemTransformerLM + +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + default="./configs/enwik8.yaml", + type=str, + help="Path of the config file. ") + args = parser.parse_args() + return args + + +def do_eval(args): + assert args.ext_len >= 0, 'Extended context length must be no less than 0' + + def _evaluate(loader): + total_len, total_loss = 0, 0. 
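+        # Accumulate the loss weighted by the number of predicted tokens so that
+        # the returned value is the average loss per token; the caller converts it
+        # to bpc (loss / ln 2) for enwik8/text8 or to perplexity (exp(loss)).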
+ + eval_mems = tuple() + for i, (src, target, seq_len) in enumerate(loader): + if args.max_eval_steps > 0 and i >= args.max_eval_steps: + break + ret = mem_transformer(src, target, *eval_mems) + loss, eval_mems = ret[0], ret[1:] + seq_len = seq_len.numpy() + eval_cur_loss = seq_len * loss.numpy() + total_loss += eval_cur_loss + total_len += seq_len + return total_loss / total_len + + def _logger(loss): + if args.dataset in ['enwik8', 'text8']: + logger_info = "loss: %f, bpc: %f" % \ + (loss, loss / np.log(2)) + else: + logger_info = "loss: %f, ppl: %.2f" % \ + (loss, np.exp(loss)) + return logger_info + + vocab = get_lm_vocab(args) + eval_loader = get_lm_data_loader(args, vocab, "valid") + test_loader = get_lm_data_loader(args, vocab, "test") + + cutoffs, tie_projs = [], [False] + if args.adaptive: + assert args.dataset in ['wt103', 'lm1b'] + if args.dataset == 'wt103': + cutoffs = [20000, 40000, 200000] + tie_projs += [True] * len(cutoffs) + elif args.dataset == 'lm1b': + cutoffs = [60000, 100000, 640000] + tie_projs += [False] * len(cutoffs) + + mem_transformer = MemTransformerLM( + args.ntokens, + args.n_layer, + args.n_head, + args.d_model, + args.d_head, + args.d_inner_hid, + args.dropout, + args.attn_dropout, + tie_weight=args.tie_weight, + d_embed=args.d_model, + div_val=args.div_val, + tie_projs=tie_projs, + normalize_before=args.normalize_before, + tgt_len=args.tgt_len, + ext_len=args.ext_len, + mem_len=args.mem_len, + cutoffs=cutoffs, + same_length=args.same_length, + attn_type=args.attn_type, + clamp_len=args.clamp_len, + sample_softmax=args.sample_softmax) + + assert args.init_from_params, ( + "Please set init_from_params to load the infer model.") + + model_dict = paddle.load( + os.path.join(args.init_from_params, "mem_transformer.pdparams")) + mem_transformer.load_dict(model_dict) + + logger.info( + "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}". + format(args.batch_size, args.tgt_len, args.ext_len, args.mem_len, + args.clamp_len)) + + mem_transformer.reset_length(args.tgt_len, args.ext_len, args.mem_len) + + test_loss = None + valid_loss = None + if args.mode == 'all': + test_loss = _evaluate(test_loader) + valid_loss = _evaluate(eval_loader) + elif args.mode == 'valid': + valid_loss = _evaluate(eval_loader) + elif args.mode == 'test': + test_loss = _evaluate(test_loader) + + logger_info = '' + if valid_loss is not None: + logger_info = logger_info + _logger(valid_loss) + if test_loss is not None: + logger_info = logger_info + _logger(test_loss) + logger.info(logger_info) + + +if __name__ == "__main__": + ARGS = parse_args() + yaml_file = ARGS.config + with open(yaml_file, 'rt') as f: + args = AttrDict(yaml.safe_load(f)) + pprint(args) + + do_eval(args) diff --git a/PaddleNLP/examples/language_model/transformer-xl/gen_data.sh b/PaddleNLP/examples/language_model/transformer-xl/gen_data.sh new file mode 100644 index 00000000..865a8a58 --- /dev/null +++ b/PaddleNLP/examples/language_model/transformer-xl/gen_data.sh @@ -0,0 +1,55 @@ +echo "Downloading dataset..." + +CUR_DIR=$PWD + +mkdir -p gen_data +cd ./gen_data/ + +if [ ! -d "wikitext-103" ]; then + echo "Downloading wikitext-103..." + wget -O wikitext-103-v1.zip https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip + echo "Unzip wikitext-103..." + unzip wikitext-103-v1.zip + cd wikitext-103 + # Rename + mv wiki.train.tokens train.txt + mv wiki.valid.tokens valid.txt + mv wiki.test.tokens test.txt + cd - +fi + +if [ ! 
-d 'enwik8' ]; then + mkdir -p enwik8 + cd enwik8 + echo "Downloading enwik8..." + wget -O enwik8.zip http://mattmahoney.net/dc/enwik8.zip + wget -O prep_enwik8.py https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py + python3 prep_enwik8.py + rm -f prep_enwik8.py + cd - +fi + +if [ ! -d 'text8' ]; then + mkdir -p text8 + cd text8 + echo "Downloading text8..." + wget -O text8.zip http://mattmahoney.net/dc/text8.zip + python ${CUR_DIR}/utils/preprocess_text8.py 5000000 + cd - +fi + +if [ ! -d 'one-billion-words' ]; then + mkdir -p one-billion-words + cd one-billion-words + echo "Downloading one-billion-words..." + wget -O 1-billion-word-language-modeling-benchmark-r13output.tar.gz http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz + tar xzf 1-billion-word-language-modeling-benchmark-r13output.tar.gz + + dir="./1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/" + cat ${dir}/news.en.heldout-00000-of-00050 > valid.txt + cat ${dir}/news.en.heldout-00000-of-00050 > test.txt + wget -O 1b_word_vocab.txt https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt + cd - +fi + +echo "All done. " diff --git a/PaddleNLP/examples/language_model/transformer-xl/mem_transformer.py b/PaddleNLP/examples/language_model/transformer-xl/mem_transformer.py new file mode 100644 index 00000000..77f3b678 --- /dev/null +++ b/PaddleNLP/examples/language_model/transformer-xl/mem_transformer.py @@ -0,0 +1,1181 @@ +import re +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +global_dtype = paddle.get_default_dtype() + + +def einsum4x4(equation, x, y): + """ + Only works for 4D x 4D. + """ + idx_x, idx_y, idx_z = re.split(",|->", equation) + # Compute repeated index + repeated_idx = list(set(idx_x + idx_y) - set(idx_z)) + + unique_idx_x = list(set(idx_x) - set(idx_y)) + unique_idx_y = list(set(idx_y) - set(idx_x)) + common_idx = list(set(idx_x) & set(idx_y) - set(repeated_idx)) + + new_idx_x = common_idx + unique_idx_x + repeated_idx + new_idx_y = common_idx + unique_idx_y + repeated_idx + new_idx_z = common_idx + unique_idx_x + unique_idx_y + + perm_x = [idx_x.index(i) for i in new_idx_x] + perm_y = [idx_y.index(i) for i in new_idx_y] + perm_z = [new_idx_z.index(i) for i in idx_z] + + x = paddle.transpose(x, perm=perm_x) + y = paddle.transpose(y, perm=perm_y) + z = paddle.matmul(x=x, y=y, transpose_y=True) + z = paddle.transpose(z, perm=perm_z) + return z + + +def sample_logits(embedding, bias, labels, inputs, sampler): + true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) + n_sample = neg_samples.shape[0] + b1, b2 = labels.shape[0], labels.shape[1] + all_ids = paddle.concat([paddle.reshape(labels, shape=[-1]), neg_samples]) + all_w = embedding(all_ids) + true_w = paddle.reshape(all_w[:-n_sample], shape=[b1, b2, -1]) + sample_w = paddle.reshape(all_w[-n_sample:], shape=[n_sample, -1]) + + all_b = paddle.gather(bias, all_ids) + true_b = paddle.reshape(all_b[:-n_sample], shape=[b1, b2]) + sample_b = all_b[-n_sample:] + + hit = paddle.cast( + (labels.unsqueeze([2]) == neg_samples), dtype=global_dtype).detach() + true_logits = paddle.sum(true_w * inputs, axis=-1) + true_b - true_log_probs + sample_logits = paddle.transpose( + paddle.matmul(sample_w, paddle.transpose(inputs, [0, 2, 1])), + [0, 2, 1]) + sample_b - samp_log_probs + sample_logits = sample_logits - 1e30 * hit + logits = paddle.concat([true_logits.unsqueeze([2]), 
sample_logits], -1) + + return logits + + +class ProjAdaptiveSoftmax(nn.Layer): + """ + Combine projection and logsoftmax. + """ + + def __init__(self, + n_token, + d_embed, + d_proj, + cutoffs, + div_val=1, + keep_order=False): + super(ProjAdaptiveSoftmax, self).__init__() + + self.n_token = n_token + self.d_embed = d_embed + self.d_proj = d_proj + + self.cutoffs = cutoffs + [n_token] + self.cutoff_ends = [0] + self.cutoffs + self.div_val = div_val + + self.shortlist_size = self.cutoffs[0] + self.num_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.num_clusters + + if self.num_clusters > 0: + self.cluster_weight = paddle.create_parameter( + shape=[self.num_clusters, self.d_embed], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( + mean=0.0, std=0.01)) + self.cluster_bias = paddle.create_parameter( + shape=[self.num_clusters], + dtype=global_dtype, + is_bias=True, + default_initializer=paddle.nn.initializer.Constant(0.0)) + + self.out_layers_weight = nn.ParameterList() + self.out_layers_bias = nn.ParameterList() + self.out_projs = nn.ParameterList() + + if div_val == 1: + for i in range(len(self.cutoffs)): + if d_proj != d_embed: + self.out_projs.append( + paddle.create_parameter( + shape=[d_proj, d_embed], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( + mean=0.0, std=0.01))) + else: + self.out_projs.append(None) + + self.out_layers_weight.append( + paddle.create_parameter( + shape=[n_token, d_embed], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Constant(0.0))) + self.out_layers_bias.append( + paddle.create_parameter( + shape=[n_token], + dtype=global_dtype, + is_bias=True, + default_initializer=paddle.nn.initializer.Constant(0.0))) + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + d_emb_i = d_embed // (div_val**i) + + self.out_projs.append( + paddle.create_parameter( + shape=[d_proj, d_emb_i], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( + mean=0.0, std=0.01))) + + self.out_layers_weight.append( + paddle.create_parameter( + shape=[r_idx - l_idx, d_emb_i], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Uniform( + low=-(r_idx - l_idx)**(-1.0 / 2.0), + high=(r_idx - l_idx)**(-1.0 / 2.0)))) + self.out_layers_bias.append( + paddle.create_parameter( + shape=[r_idx - l_idx], + dtype=global_dtype, + is_bias=True, + default_initializer=paddle.nn.initializer.Uniform( + low=-(r_idx - l_idx)**(-1.0 / 2.0), + high=(r_idx - l_idx)**(-1.0 / 2.0)))) + + self.keep_order = keep_order + + def _compute_logits(self, hidden, weight, bias, proj=None): + if proj is None: + logit = F.linear(hidden, weight.t(), bias=bias) + else: + proj_hid = F.linear(hidden, proj) + logit = F.linear(proj_hid, weight.t(), bias=bias) + + return logit + + def forward(self, hidden, target, keep_order=False): + assert (hidden.shape[0] == target.shape[0]) + + if self.num_clusters == 0: + logit = self._compute_logits(hidden, self.out_layers_weight[0], + self.out_layers_bias[0], + self.out_projs[0]) + nll = -paddle.log(F.softmax(logit, axis=-1)) + idx = paddle.concat( + [ + paddle.arange(0, nll.shape[0]).unsqueeze([1]), + target.unsqueeze(1) + ], + axis=1) + nll = paddle.gather_nd(nll, idx) + else: + weights, biases = [], [] + for i in range(len(self.cutoffs)): + if self.div_val == 1: + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + weight_i = self.out_layers_weight[0][l_idx:r_idx] + bias_i = 
self.out_layers_bias[0][l_idx:r_idx] + else: + weight_i = self.out_layers_weight[i] + bias_i = self.out_layers_bias[i] + + if i == 0: + weight_i = paddle.concat( + [weight_i, self.cluster_weight], axis=0) + bias_i = paddle.concat([bias_i, self.cluster_bias], axis=0) + + weights.append(weight_i) + biases.append(bias_i) + + head_weight, head_bias, head_proj = weights[0], biases[ + 0], self.out_projs[0] + + head_logit = self._compute_logits(hidden, head_weight, head_bias, + head_proj) + head_logprob = paddle.log(F.softmax(head_logit, axis=-1)) + + nll = paddle.zeros_like(target, dtype=hidden.dtype) + + offset = 0 + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] + + mask_i = paddle.cast( + target >= l_idx, + dtype=paddle.get_default_dtype()) * paddle.cast( + target < r_idx, dtype="int64") + indices_i = paddle.nonzero(mask_i).squeeze([1]) + + if paddle.numel(indices_i) == 0: + continue + target_i = paddle.gather(target, indices_i, axis=0) - l_idx + head_logprob_i = paddle.gather(head_logprob, indices_i, axis=0) + if i == 0: + target_i_idx = paddle.concat( + [ + paddle.arange(0, head_logprob_i.shape[0]).unsqueeze( + [1]), target_i.unsqueeze([1]) + ], + axis=1) + logprob_i = head_logprob_i.gather_nd(target_i_idx) + else: + weight_i, bias_i, proj_i = weights[i], biases[ + i], self.out_projs[i].weight if self.out_projs[ + i] is not None else None + + hidden_i = paddle.gather(hidden, indices_i, axis=0) + + tail_logit_i = self._compute_logits(hidden_i, weight_i, + bias_i, proj_i) + tail_logprob_i = paddle.log( + F.softmax( + tail_logit_i, axis=-1)) + + target_i_idx = paddle.concat( + [ + paddle.arange(0, tail_logprob_i.shape[0]).unsqueeze( + [1]), target_i.unsqueeze([1]) + ], + axis=1) + logprob_i = tail_logprob_i.gather_nd(target_i_idx) + + logprob_i = head_logprob_i[:, -i] + logprob_i + + if self.keep_order or keep_order: + nll = paddle.scatter(nll, indices_i, -logprob_i) + else: + index = paddle.arange(offset, offset + logprob_i.shape[0], + 1) + nll = paddle.scatter(nll, index, -logprob_i) + + offset += logprob_i.shape[0] + + return nll + + +class LogUniformSampler(object): + def __init__(self, range_max, n_sample): + with paddle.no_grad(): + self.range_max = range_max + log_indices = paddle.log( + paddle.arange( + 1., range_max + 2., 1., dtype=global_dtype)) + self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] + + self.log_q = paddle.cast( + paddle.log( + paddle.exp(-(paddle.log1p(-paddle.cast( + self.dist, dtype=global_dtype)) * 2 * n_sample)) - 1), + dtype=global_dtype) + + self.n_sample = n_sample + + def sample(self, labels): + n_sample = self.n_sample + n_tries = 2 * n_sample + batch_size = labels.shape[0] + + with paddle.no_grad(): + neg_samples = paddle.unique( + paddle.multinomial( + self.dist, n_tries, replacement=True)) + true_log_probs = paddle.gather(self.log_q, labels.flatten()) + true_log_probs = paddle.reshape( + true_log_probs, shape=[batch_size, -1]) + samp_log_probs = paddle.gather(self.log_q, neg_samples) + return true_log_probs, samp_log_probs, neg_samples + + +class PositionEmbedding(nn.Layer): + def __init__(self, emb_dim): + super(PositionEmbedding, self).__init__() + self.emb_dim = emb_dim + self.inv_freq = 1.0 / (10000.0**(paddle.arange( + 0.0, emb_dim, 2.0, dtype=global_dtype) / emb_dim)) + + def forward(self, pos_seq, bsz=None): + sinusoid_inp = paddle.matmul( + pos_seq.unsqueeze([1]), self.inv_freq.unsqueeze([0])) + pos_emb = paddle.concat( + [paddle.sin(sinusoid_inp), 
paddle.cos(sinusoid_inp)], axis=-1) + + if bsz is not None: + pos_emb = pos_emb.unsqueeze([0]).expand([bsz, -1, -1]) + pos_emb.stop_gradient = True + return pos_emb + else: + pos_emb = pos_emb.unsqueeze([0]) + pos_emb.stop_gradient = True + return pos_emb + + +class PositionwiseFFN(nn.Layer): + def __init__(self, d_model, d_inner, dropout, normalize_before=False): + super(PositionwiseFFN, self).__init__() + + self.d_model = d_model + self.d_inner = d_inner + + self.CoreNet = nn.Sequential( + nn.Linear( + d_model, + d_inner, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01), + bias_attr=paddle.nn.initializer.Constant(0.0)), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear( + d_inner, + d_model, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01), + bias_attr=paddle.nn.initializer.Constant(0.0)), + nn.Dropout(dropout)) + self.layer_norm = nn.LayerNorm( + d_model, + weight_attr=paddle.nn.initializer.Normal( + mean=1.0, std=0.01), + bias_attr=paddle.nn.initializer.Constant(0.0)) + self.normalize_before = normalize_before + + def forward(self, inp): + if self.normalize_before: + core_out = self.CoreNet(self.layer_norm(inp)) + output = core_out + inp + else: + core_out = self.CoreNet(inp) + output = self.layer_norm(inp + core_out) + return output + + +class MultiHeadAttn(nn.Layer): + def __init__(self, + n_head, + d_model, + d_head, + dropout, + attn_dropout=0, + normalize_before=False): + super(MultiHeadAttn, self).__init__() + self.n_head = n_head + self.d_model = d_model + self.d_head = d_head + + self.q_proj = nn.Linear( + d_model, + n_head * d_head, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01), + bias_attr=False) + self.kv_proj = nn.Linear( + d_model, + 2 * n_head * d_head, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01), + bias_attr=False) + self.drop = nn.Dropout(p=dropout) + self.attn_drop = nn.Dropout(p=attn_dropout) + self.o_proj = nn.Linear( + n_head * d_head, + d_model, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01), + bias_attr=False) + self.layer_norm = nn.LayerNorm( + d_model, + weight_attr=paddle.nn.initializer.Normal( + mean=1.0, std=0.01), + bias_attr=paddle.nn.initializer.Constant(0.0)) + + self.scale = 1 / (d_head**0.5) + self.normalize_before = normalize_before + + def forward(self, h, attn_mask=None, mems=None): + if mems is not None: + c = paddle.concat([mems, h], axis=1) + else: + c = h + + if self.normalize_before: + c = self.layer_norm(c) + + head_q = self.q_proj(h) + head_k, head_v = paddle.chunk(self.kv_proj(c), chunks=2, axis=-1) + + head_q = paddle.reshape( + head_q, shape=[h.shape[0], h.shape[1], self.n_head, self.d_head]) + head_k = paddle.reshape( + head_k, shape=[c.shape[0], c.shape[1], self.n_head, self.d_head]) + head_v = paddle.reshape( + head_v, shape=[c.shape[0], c.shape[1], self.n_head, self.d_head]) + + attn_score = einsum4x4('bind,bjnd->bnij', head_q, head_k) + attn_score = attn_score * self.scale + if attn_mask is not None: + attn_score = attn_score - float('inf') * attn_mask + + attn_prob = F.softmax(attn_score, dim=-1) + attn_prob = self.attn_drop(attn_prob) + + attn_vec = einsum4x4('bnij,bjnd->bind', attn_prob, head_v) + attn_vec = paddle.reshape( + attn_vec, + shape=[ + attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head + ]) + + attn_out = self.o_proj(attn_vec) + attn_out = self.drop(attn_out) + if self.normalize_before: + output = h + attn_out + else: + output = self.layer_norm(h + attn_out) + + return output + + +class RelMultiHeadAttn(nn.Layer): + 
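+    """
+    Base class for multi-head attention with relative position information.
+    It holds the shared qkv/output projections, dropout and layer norm, and
+    provides `_rel_shift`, the Transformer-XL shift trick that realigns the
+    position-based attention scores so that every query position reads its
+    own set of relative offsets.
+    """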
def __init__(self, + n_head, + d_model, + d_head, + dropout, + attn_dropout=0, + tgt_len=None, + ext_len=None, + mem_len=None, + normalize_before=False): + super(RelMultiHeadAttn, self).__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_head = d_head + self.dropout = dropout + + self.qkv_proj = nn.Linear( + d_model, + 3 * n_head * d_head, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01), + bias_attr=False) + + self.drop = nn.Dropout(dropout) + self.attn_drop = nn.Dropout(attn_dropout) + self.o_proj = nn.Linear( + n_head * d_head, + d_model, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01), + bias_attr=False) + + self.layer_norm = nn.LayerNorm( + d_model, + weight_attr=paddle.nn.initializer.Normal( + mean=1.0, std=0.01), + bias_attr=paddle.nn.initializer.Constant(0.0)) + + self.scale = 1 / (d_head**0.5) + + self.normalize_before = normalize_before + + def _rel_shift(self, x, zero_triu=False): + x_shape = x.shape + zero_pad = paddle.zeros( + [x_shape[0], x_shape[1], x_shape[2], 1], dtype=x.dtype) + x_padded = paddle.concat([zero_pad, x], axis=-1) + + x_padded = paddle.reshape( + x_padded, + shape=[x_shape[0], x_shape[1], x_shape[3] + 1, x_shape[2]]) + + x = paddle.reshape(x_padded[:, :, 1:, :], shape=x_shape) + + if zero_triu: + ones = paddle.ones([x_shape[2], x_shape[3]]) + x = x * paddle.tril( + ones, diagonal=x_shape[3] - x_shape[2]).unsqueeze([2, 3]) + + return x + + def forward(self, w, r, attn_mask=None, mems=None): + raise NotImplementedError + + +class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): + def __init__(self, *args, **kwargs): + super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs) + + self.r_proj = nn.Linear( + self.d_model, + self.n_head * self.d_head, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01), + bias_attr=False) + + def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None): + qlen, rlen, bsz = w.shape[1], r.shape[1], w.shape[0] + + if mems is not None: + cat = paddle.concat([mems, w], axis=1) + if self.normalize_before: + w_heads = self.qkv_proj(self.layer_norm(cat)) + else: + w_heads = self.qkv_proj(cat) + r_head_k = self.r_proj(r) + + w_head_q, w_head_k, w_head_v = paddle.chunk( + w_heads, chunks=3, axis=-1) + + w_head_q = w_head_q[:, -qlen:, :] + else: + if self.normalize_before: + w_heads = self.qkv_proj(self.layer_norm(w)) + else: + w_heads = self.qkv_proj(w) + r_head_k = self.r_proj(r) + + w_head_q, w_head_k, w_head_v = paddle.chunk( + w_heads, chunks=3, axis=-1) + + klen = w_head_k.shape[1] + + w_head_q = paddle.reshape( + w_head_q, shape=[bsz, qlen, self.n_head, self.d_head]) + w_head_k = paddle.reshape( + w_head_k, shape=[bsz, klen, self.n_head, self.d_head]) + w_head_v = paddle.reshape( + w_head_v, shape=[bsz, klen, self.n_head, self.d_head]) + + r_head_k = paddle.reshape( + r_head_k, shape=[bsz, rlen, self.n_head, self.d_head]) + + rw_head_q = w_head_q + r_w_bias + + AC = einsum4x4('bind,bjnd->bnij', rw_head_q, w_head_k) + rr_head_q = w_head_q + r_r_bias + + # TODO: use einsum. einsum4x4 only works for 4D tensor. 
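+        # Following the Transformer-XL formulation, the score splits into a
+        # content-based term AC = (q + r_w_bias) . k (computed above) and a
+        # position-based term BD = (q + r_r_bias) . r (computed below), where r
+        # is the relative positional embedding; _rel_shift then realigns BD so
+        # that each query attends over the correct relative offsets.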
+ BD = einsum4x4('bind,bjnd->bnij', rr_head_q, r_head_k) + BD = self._rel_shift(BD) + + attn_score = AC + BD + attn_score = attn_score * self.scale + + if attn_mask is not None: + attn_score = attn_score - 1e30 * attn_mask + + attn_prob = F.softmax(attn_score, axis=-1) + attn_prob = self.attn_drop(attn_prob) + + attn_vec = einsum4x4('bnij,bjnd->bind', attn_prob, w_head_v) + + attn_vec = paddle.reshape( + attn_vec, + shape=[ + attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head + ]) + + attn_out = self.o_proj(attn_vec) + attn_out = self.drop(attn_out) + + if self.normalize_before: + output = w + attn_out + else: + output = self.layer_norm(w + attn_out) + + return output + + +class RelLearnableMultiHeadAttn(RelMultiHeadAttn): + def __init__(self, *args, **kwargs): + super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs) + + def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None): + qlen, bsz = w.shape[1], w.shape[0] + + if mems is not None: + cat = paddle.concat([mems, w], 1) + if self.normalize_before: + w_heads = self.qkv_proj(self.layer_norm(cat)) + else: + w_heads = self.qkv_proj(cat) + w_head_q, w_head_k, w_head_v = paddle.chunk( + w_heads, chunks=3, axis=-1) + + w_head_q = w_head_q[-qlen:] + else: + if self.normalize_before: + w_heads = self.qkv_proj(self.layer_norm(w)) + else: + w_heads = self.qkv_proj(w) + w_head_q, w_head_k, w_head_v = paddle.chunk( + w_heads, chunks=3, axis=-1) + + klen = w_head_k.shape[1] + + w_head_q = paddle.reshape( + w_head_q, + shape=[ + w_head_q.shape[0], w_head_q.shape[1], self.n_head, self.d_head + ]) + w_head_k = paddle.reshape( + w_head_k, + shape=[ + w_head_k.shape[0], w_head_k.shape[1], self.n_head, self.d_head + ]) + w_head_v = paddle.reshape( + w_head_v, + shape=[ + w_head_v.shape[0], w_head_v.shape[1], self.n_head, self.d_head + ]) + + if klen > r_emb.shape[0]: + r_emb_pad = r_emb[0:1].expand(klen - r_emb.shape[0], -1, -1) + r_emb = paddle.concat([r_emb_pad, r_emb], 0) + r_bias_pad = r_bias[0:1].expand(klen - r_bias.shape[0], -1) + r_bias = paddle.concat([r_bias_pad, r_bias], 0) + else: + r_emb = r_emb[-klen:] + r_bias = r_bias[-klen:] + + rw_head_q = w_head_q + r_w_bias.unsqueeze([0]) + + AC = einsum4x4('bind,bjnd->bnij', rw_head_q, w_head_k) + r_emb = r_emb.unsqueeze([0]).expand([bsz, -1, -1, -1]) + B_ = einsum4x4('bind,bjnd->bnij', w_head_q, r_emb) + D_ = r_bias.unsqueeze([0, 2]) + BD = self._rel_shift(B_ + D_) + + attn_score = AC + BD + attn_score = attn_score * self.scale + + if attn_mask is not None: + attn_score = attn_score - float('inf') * attn_mask + + attn_prob = F.softmax(attn_score, dim=-1) + attn_prob = self.attn_drop(attn_prob) + + attn_vec = einsum4x4('bnij,bjnd->bind', attn_prob, w_head_v) + + attn_vec = paddle.reshape( + attn_vec, + shape=[ + attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head + ]) + + attn_out = self.o_net(attn_vec) + attn_out = self.drop(attn_out) + + if self.normalize_before: + output = w + attn_out + else: + output = self.layer_norm(w + attn_out) + + return output + + +class DecoderLayer(nn.Layer): + def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): + super(DecoderLayer, self).__init__() + + self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, + **kwargs) + self.pos_ff = PositionwiseFFN( + d_model, + d_inner, + dropout, + normalize_before=kwargs.get('normalize_before')) + + def forward(self, dec_inp, dec_attn_mask=None, mems=None): + + output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, mems=mems) + output = 
self.pos_ff(output) + + return output + + +class RelLearnableDecoderLayer(nn.Layer): + def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): + super(RelLearnableDecoderLayer, self).__init__() + + self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, + dropout, **kwargs) + self.pos_ff = PositionwiseFFN( + d_model, + d_inner, + dropout, + normalize_before=kwargs.get('normalize_before')) + + def forward(self, + dec_inp, + r_emb, + r_w_bias, + r_bias, + dec_attn_mask=None, + mems=None): + + output = self.dec_attn( + dec_inp, + r_emb, + r_w_bias, + r_bias, + attn_mask=dec_attn_mask, + mems=mems) + output = self.pos_ff(output) + + return output + + +class RelPartialLearnableDecoderLayer(nn.Layer): + def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): + super(RelPartialLearnableDecoderLayer, self).__init__() + + self.dec_attn = RelPartialLearnableMultiHeadAttn( + n_head, d_model, d_head, dropout, **kwargs) + self.pos_ff = PositionwiseFFN( + d_model, + d_inner, + dropout, + normalize_before=kwargs.get('normalize_before')) + + def forward(self, + dec_inp, + r, + r_w_bias, + r_r_bias, + dec_attn_mask=None, + mems=None): + output = self.dec_attn( + dec_inp, r, r_w_bias, r_r_bias, attn_mask=dec_attn_mask, mems=mems) + output = self.pos_ff(output) + + return output + + +class AdaptiveEmbedding(nn.Layer): + def __init__(self, + n_token, + d_embed, + d_proj, + cutoffs, + div_val=1, + sample_softmax=False): + super(AdaptiveEmbedding, self).__init__() + + self.n_token = n_token + self.d_embed = d_embed + + self.cutoffs = cutoffs + [n_token] + self.div_val = div_val + self.d_proj = d_proj + + self.emb_scale = d_proj**0.5 + + self.cutoff_ends = [0] + self.cutoffs + + self.emb_layers = nn.LayerList() + self.emb_projs = nn.ParameterList() + if div_val == 1: + self.emb_layers.append( + nn.Embedding( + n_token, + d_embed, + sparse=sample_softmax > 0, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01))) + if d_proj != d_embed: + self.emb_projs.append( + paddle.create_parameter( + shape=[d_embed, d_proj], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( + mean=0.0, std=0.01))) + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + d_emb_i = d_embed // (div_val**i) + self.emb_layers.append( + nn.Embedding( + r_idx - l_idx, + d_emb_i, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01))) + self.emb_projs.append( + paddle.create_parameter( + shape=[d_emb_i, d_proj], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( + mean=0.0, std=0.01))) + + def forward(self, inp): + if self.div_val == 1: + embed = self.emb_layers[0](inp) + if self.d_proj != self.d_embed: + embed = F.linear(embed, self.emb_projs[0]) + else: + inp_flat = paddle.reshape(inp, shape=[-1]) + emb_flat = paddle.zeros( + [inp_flat.shape[0], self.d_proj], dtype=global_dtype) + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + + mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) + indices_i = paddle.nonzero(mask_i).squeeze([1]) + + if indices_i.numel() == 0: + continue + + inp_i = paddle.gather(inp_flat, indices_i, axis=0) - l_idx + emb_i = self.emb_layers[i](inp_i) + emb_i = F.linear(emb_i, self.emb_projs[i]) + + emb_flat = paddle.scatter(emb_flat, indices_i, emb_i) + + embed = paddle.reshape( + emb_flat, shape=inp.shape.append(self.d_proj)) + + embed = embed * self.emb_scale + + return embed + + +class 
MemTransformerLM(nn.Layer): + def __init__(self, + n_token, + n_layer, + n_head, + d_model, + d_head, + d_inner, + dropout, + attn_dropout, + tie_weight=True, + d_embed=None, + div_val=1, + tie_projs=[False], + normalize_before=False, + tgt_len=None, + ext_len=None, + mem_len=None, + cutoffs=[], + adapt_inp=False, + same_length=False, + attn_type=0, + clamp_len=-1, + sample_softmax=-1): + super(MemTransformerLM, self).__init__() + self.n_token = n_token + + d_embed = d_model if d_embed is None else d_embed + self.d_embed = d_embed + self.d_model = d_model + self.n_head = n_head + self.d_head = d_head + + self.word_emb = AdaptiveEmbedding( + n_token, d_embed, d_model, cutoffs, div_val=div_val) + + self.drop = nn.Dropout(dropout) + + self.n_layer = n_layer + + self.tgt_len = tgt_len + self.mem_len = mem_len + self.ext_len = ext_len + self.max_klen = tgt_len + ext_len + mem_len + + self.attn_type = attn_type + + self.layers = nn.LayerList() + if attn_type == 0: + for i in range(n_layer): + self.layers.append( + RelPartialLearnableDecoderLayer( + n_head, + d_model, + d_head, + d_inner, + dropout, + tgt_len=tgt_len, + ext_len=ext_len, + mem_len=mem_len, + attn_dropout=attn_dropout, + normalize_before=normalize_before)) + elif attn_type == 1: + for i in range(n_layer): + self.layers.append( + RelLearnableDecoderLayer( + n_head, + d_model, + d_head, + d_inner, + dropout, + tgt_len=tgt_len, + ext_len=ext_len, + mem_len=mem_len, + attn_dropout=attn_dropout, + normalize_before=normalize_before)) + elif attn_type in [2, 3]: + for i in range(n_layer): + self.layers.append( + DecoderLayer( + n_head, + d_model, + d_head, + d_inner, + dropout, + attn_dropout=attn_dropout, + normalize_before=normalize_before)) + + self.sample_softmax = sample_softmax + if sample_softmax > 0: + self.out_layer = nn.Linear( + d_model, + n_token, + weight_attr=paddle.nn.initializer.Normal( + mean=0.0, std=0.01), + bias_attr=paddle.nn.initializer.Constant(0.0)) + self.tie_weight = tie_weight + self.sampler = LogUniformSampler(n_token, sample_softmax) + else: + self.crit = ProjAdaptiveSoftmax( + n_token, d_embed, d_model, cutoffs, div_val=div_val) + + if tie_weight: + for i in range(len(self.crit.out_layers_weight)): + self.crit.out_layers_weight[i] = self.word_emb.emb_layers[ + i].weight + + if tie_projs: + for i, tie_proj in enumerate(tie_projs): + if tie_proj and div_val == 1 and d_model != d_embed: + self.crit.out_projs[i] = self.word_emb.emb_projs[0] + elif tie_proj and div_val != 1: + self.crit.out_projs[i] = self.word_emb.emb_projs[i] + + self.same_length = same_length + self.clamp_len = clamp_len + + self._create_params() + + def backward_compatible(self): + self.sample_softmax = -1 + + def _create_params(self): + if self.attn_type == 0: + self.pos_emb = PositionEmbedding(self.d_model) + self.r_w_bias = paddle.create_parameter( + shape=[self.n_head, self.d_head], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( + mean=0.0, std=0.01)) + self.r_r_bias = paddle.create_parameter( + shape=[self.n_head, self.d_head], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( + mean=0.0, std=0.01)) + elif self.attn_type == 1: + self.r_emb = paddle.create_parameter( + shape=[self.n_layer, self.max_klen, self.n_head, self.d_head], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( + mean=0.0, std=0.01)) + self.r_w_bias = paddle.create_parameter( + shape=[self.n_layer, self.n_head, self.d_head], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( 
+ mean=0.0, std=0.01)) + self.r_bias = paddle.create_parameter( + shape=[self.n_layer, self.max_klen, self.n_head], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( + mean=0.0, std=0.01)) + elif self.attn_type == 2: + self.pos_emb = PositionEmbedding(self.d_model) + elif self.attn_type == 3: + self.r_emb = paddle.create_parameter( + shape=[self.n_layer, self.max_klen, self.n_head, self.d_head], + dtype=global_dtype, + default_initializer=paddle.nn.initializer.Normal( + mean=0.0, std=0.01)) + + def reset_length(self, tgt_len, ext_len, mem_len): + self.tgt_len = tgt_len + self.mem_len = mem_len + self.ext_len = ext_len + + def init_mems(self, batch_size, d_model): + if self.mem_len > 0: + mems = [] + for _ in range(self.n_layer + 1): + empty = paddle.empty( + shape=[batch_size, 0, d_model], dtype=global_dtype) + mems.append(empty) + + return mems + else: + return None + + def _update_mems(self, hids, mems, qlen, mlen): + if mems is None: return None + + assert len(hids) == len( + mems), "length of hids and length of mems must be the same. " + + with paddle.no_grad(): + new_mems = [] + end_idx = mlen + max(0, qlen - 0 - self.ext_len) + beg_idx = max(0, end_idx - self.mem_len) + for i in range(len(hids)): + cat = paddle.concat([mems[i], hids[i]], axis=1) + new_mems.append(cat[:, beg_idx:end_idx].detach()) + + return new_mems + + def _forward(self, dec_inputs, mems=None): + bsz, qlen = dec_inputs.shape + + word_emb = self.word_emb(dec_inputs) + + mlen = mems[0].shape[1] if mems is not None else 0 + klen = mlen + qlen + if self.same_length: + all_ones = paddle.ones(shape=[qlen, klen], dtype=word_emb.dtype) + mask_len = klen - self.mem_len + if mask_len > 0: + mask_shift_len = qlen - mask_len + else: + mask_shift_len = qlen + dec_attn_mask = (paddle.triu( + all_ones, diagonal=1 + mlen) + paddle.tril( + all_ones, -mask_shift_len)).unsqueeze([0, 1]) + else: + dec_attn_mask = paddle.ones( + shape=[qlen, klen], dtype=word_emb.dtype) + dec_attn_mask = paddle.triu( + dec_attn_mask, diagonal=1 + mlen).unsqueeze([0, 1]) + + hids = [] + if self.attn_type == 0: + pos_seq = paddle.arange(klen - 1, -1, -1.0, dtype=word_emb.dtype) + if self.clamp_len > 0: + # TODO: clamp and clip + pos_seq = paddle.clip(pos_seq, max=self.clamp_len) + pos_emb = self.pos_emb(pos_seq, bsz) + + core_out = self.drop(word_emb) + pos_emb = self.drop(pos_emb) + + hids.append(core_out) + for i, layer in enumerate(self.layers): + mems_i = None if mems is None else mems[i] + core_out = layer( + core_out, + pos_emb, + self.r_w_bias, + self.r_r_bias, + dec_attn_mask=dec_attn_mask, + mems=mems_i) + hids.append(core_out) + elif self.attn_type == 1: + core_out = self.drop(word_emb) + hids.append(core_out) + for i, layer in enumerate(self.layers): + if self.clamp_len > 0: + r_emb = self.r_emb[i][-self.clamp_len:] + r_bias = self.r_bias[i][-self.clamp_len:] + else: + r_emb, r_bias = self.r_emb[i], self.r_bias[i] + + mems_i = None if mems is None else mems[i] + core_out = layer( + core_out, + r_emb, + self.r_w_bias[i], + r_bias, + dec_attn_mask=dec_attn_mask, + mems=mems_i) + hids.append(core_out) + elif self.attn_type == 2: + pos_seq = paddle.arange(klen - 1, -1, -1.0, dtype=word_emb.dtype) + if self.clamp_len > 0: + pos_seq = paddle.clip(pos_seq, max=self.clamp_len) + pos_emb = self.pos_emb(pos_seq, bsz) + + core_out = self.drop(word_emb + pos_emb[-qlen:]) + + hids.append(core_out) + for i, layer in enumerate(self.layers): + mems_i = None if mems is None else mems[i] + if mems_i is not None and i == 0: + mems_i += 
pos_emb[:mlen] + core_out = layer( + core_out, dec_attn_mask=dec_attn_mask, mems=mems_i) + hids.append(core_out) + elif self.attn_type == 3: + core_out = self.drop(word_emb) + + hids.append(core_out) + for i, layer in enumerate(self.layers): + mems_i = None if mems is None else mems[i] + if mems_i is not None and mlen > 0: + cur_emb = self.r_emb[i][:-qlen] + cur_size = cur_emb.size(0) + if cur_size < mlen: + cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, + -1) + cur_emb = paddle.concat([cur_emb_pad, cur_emb], 0) + else: + cur_emb = cur_emb[-mlen:] + mems_i += cur_emb.view(mlen, 1, -1) + core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1) + + core_out = layer( + core_out, dec_attn_mask=dec_attn_mask, mems=mems_i) + hids.append(core_out) + + core_out = self.drop(core_out) + + new_mems = self._update_mems(hids, mems, mlen, qlen) + + return core_out, new_mems + + def forward(self, data, target, *mems): + if not mems: + batch_size = data.shape[0] + mems = self.init_mems(batch_size, self.d_model) + + hidden, new_mems = self._forward(data, mems=mems) + + # TODO(FrostML): use getitem. + tgt_len = target.shape[1] + pred_hid = paddle.slice(hidden, [1], [-tgt_len], [hidden.shape[1]]) + if self.sample_softmax > 0 and self.training: + assert self.tie_weight, "tie_weight must be True if sample_softmax > 0" + logit = sample_logits(self.word_emb, self.out_layer.bias, target, + pred_hid, self.sampler) + loss = -paddle.log(F.softmax(logit, axis=-1))[:, :, 0] + else: + loss = self.crit( + paddle.reshape( + pred_hid, shape=[-1, pred_hid.shape[-1]]), + paddle.reshape( + target, shape=[-1])) + + if new_mems is None: + return [loss.mean()] + else: + return [loss.mean()] + new_mems diff --git a/PaddleNLP/examples/language_model/transformer-xl/reader.py b/PaddleNLP/examples/language_model/transformer-xl/reader.py new file mode 100644 index 00000000..171743c6 --- /dev/null +++ b/PaddleNLP/examples/language_model/transformer-xl/reader.py @@ -0,0 +1,197 @@ +import os + +import numpy as np + +from paddlenlp.data import Vocab + +import paddle +from paddle.io import IterableDataset, DataLoader +import paddle.distributed as dist + + +class LMDataset(IterableDataset): + def __init__(self, mode, vocab, path, dataset_name, batch_size, bptt, + ext_len, nranks, rank): + assert (mode in ["train", "valid", "test"] + ), "Parameter mode must be one of [train, valid, test]." + + super(LMDataset, self).__init__() + self.vocab = vocab + self.dataset_name = dataset_name + + if self.dataset_name in ["wt103"]: + self.data = self.read_raw_data( + filename=os.path.join(path, mode + ".txt"), ordered=True) + elif self.dataset_name in ["enwik8", "text8"]: + self.data = self.read_raw_data( + filename=os.path.join(path, mode + ".txt"), + ordered=True, + add_eos=False) + else: + raise ValueError("Not supported dataset yet. 
") + self.rank = rank + self.batch_size = batch_size + batch_size *= nranks + + self.bptt = bptt + self.ext_len = ext_len if ext_len is not None else 0 + + self.num_step = len(self.data) // batch_size + data = self.data[:self.num_step * batch_size] + self.data = data.reshape([batch_size, -1]) + + # Number of samples + self.num_samples = (self.num_step + self.bptt - 1) // self.bptt + + def __len__(self): + return self.num_samples + + def __iter__(self): + for i in range(0, self.data.shape[1] - 1, self.bptt): + seq_len = min(self.bptt, self.data.shape[1] - 1 - i) + end_idx = i + seq_len + beg_idx = max(0, i - self.ext_len) + src = self.data[:, beg_idx:end_idx] + target = self.data[:, i + 1:i + 1 + seq_len] + + # NOTE: `seq_len` will be transfered to numpy immediately + # after returned by DataLoader. Hence, `seq_len` can be + # yield as `int`. And the returned tensor `seq_len`'s shape + # will be empty []. + # However, if it's necessary to use `seq_len` as input for some + # PaddlePaddle op, then it must be returned by `[seq_len]` whose + # shape is [1], cause some op cannot use shape [] as input. + yield [ + src[self.rank * self.batch_size:(self.rank + 1) * + self.batch_size], target[self.rank * self.batch_size:( + self.rank + 1) * self.batch_size], seq_len + ] + + def read_raw_data(self, + filename, + ordered=False, + lower_case=True, + delimiter=None, + add_eos=True, + add_double_eos=False): + assert os.path.exists(filename), "%s is not exist. " % filename + + data = [] + with open(filename, 'r', encoding='utf-8') as f: + for line in f: + tokens = LMDataset.tokenize( + line=line, delimiter=delimiter, lower_case=lower_case) + if add_double_eos: # for lm1b + tokens = [self.vocab._identifiers_to_tokens['bos_token'] + ] + tokens + [ + self.vocab._identifiers_to_tokens['bos_token'] + ] + elif add_eos: + tokens = tokens + [ + self.vocab._identifiers_to_tokens['eos_token'] + ] + data.append( + np.asarray(self.get_indices(tokens)).astype("int64")) + + if ordered: + data = np.concatenate(data) + + return data + + def get_indices(self, tokens): + return self.vocab.to_indices(tokens) + + @classmethod + def get_vocab(cls, + files, + max_size=None, + min_freq=0, + lower_case=True, + delimiter=None, + unk_token=None, + pad_token=None, + bos_token=None, + eos_token=None, + **kwargs): + return Vocab.build_vocab( + cls.data_iterator( + files=files, delimiter=delimiter, lower_case=lower_case), + max_size=max_size, + min_freq=min_freq, + unk_token=unk_token, + pad_token=pad_token, + bos_token=bos_token, + eos_token=eos_token) + + @classmethod + def tokenize(cls, line, delimiter=None, lower_case=True): + line = line.strip() + if lower_case: + line = line.lower() + tokens = list(line) if delimiter == "" else line.split(delimiter) + return tokens + + @classmethod + def data_iterator(cls, files, delimiter=None, lower_case=True): + if isinstance(files, str): + files = [files] + elif not isinstance(files, (list, tuple)): + raise ValueError( + "The parameter files must be a str or a list/tuple.") + + for fl in files: + assert os.path.exists(fl), "%s is not exist. 
" % fl + + with open(fl, 'r', encoding='utf-8') as f: + for line in f: + tokens = cls.tokenize( + line=line, delimiter=delimiter, lower_case=lower_case) + yield tokens + + +def get_lm_data_loader(args, vocab, mode="train"): + lm_dataset = LMDataset( + mode=mode, + vocab=vocab, + path=args.data, + dataset_name=args.dataset, + batch_size=args.batch_size if mode == "train" else args.eval_batch_size, + bptt=args.tgt_len, + ext_len=args.ext_len, + nranks=dist.get_world_size() if mode == "train" else 1, + rank=dist.get_rank() if mode == "train" else 0) + + data_loader = DataLoader( + dataset=lm_dataset, batch_size=None, num_workers=0, return_list=True) + + return data_loader + + +def get_lm_vocab(args): + kwargs = {"unk_token": ""} + if args.token_delimiter == "None": + kwargs["delimiter"] = None + else: + kwargs["delimiter"] = args.token_delimiter + + if args.dataset == "wt103": + kwargs["eos_token"] = "" + kwargs["lower_case"] = False + + if args.dataset in ["enwik8", "text8"]: + files = [ + os.path.join(args.data, "train.txt"), + os.path.join(args.data, "valid.txt"), + os.path.join(args.data, "test.txt") + ] + elif args.dataset == "wt103": + files = [os.path.join(args.data, "train.txt")] + else: + raise ValueError("Not supported dataset yet. ") + + vocab = LMDataset.get_vocab(files, **kwargs) + args.ntokens = len(vocab) + print("Finish processing vocabulary, and the size of vocabulary is {}". + format(args.ntokens)) + + return vocab diff --git a/PaddleNLP/examples/language_model/transformer-xl/train.py b/PaddleNLP/examples/language_model/transformer-xl/train.py new file mode 100644 index 00000000..116a9352 --- /dev/null +++ b/PaddleNLP/examples/language_model/transformer-xl/train.py @@ -0,0 +1,308 @@ +import os +import time +import yaml +import logging +import argparse +import numpy as np +from pprint import pprint +from attrdict import AttrDict + +import paddle +import paddle.nn as nn +import paddle.distributed as dist + +from mem_transformer import MemTransformerLM +from reader import get_lm_vocab, get_lm_data_loader + +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + default="./configs/enwik8.yaml", + type=str, + help="Path of the config file. 
") + args = parser.parse_args() + return args + + +def do_train(args): + if args.use_gpu: + rank = dist.get_rank() + trainer_count = dist.get_world_size() + else: + rank = 0 + trainer_count = 1 + + if trainer_count > 1: + dist.init_parallel_env() + + random_seed = eval(str(args.random_seed)) + if random_seed is not None: + paddle.seed(random_seed) + + vocab = get_lm_vocab(args) + train_loader = get_lm_data_loader(args, vocab, "train") + eval_loader = get_lm_data_loader(args, vocab, "valid") + + cutoffs, tie_projs = [], [False] + if args.adaptive: + assert args.dataset in ['wt103', 'lm1b'] + if args.dataset == 'wt103': + cutoffs = [20000, 40000, 200000] + tie_projs += [True] * len(cutoffs) + elif args.dataset == 'lm1b': + cutoffs = [60000, 100000, 640000] + tie_projs += [False] * len(cutoffs) + + mem_transformer = MemTransformerLM( + args.ntokens, + args.n_layer, + args.n_head, + args.d_model, + args.d_head, + args.d_inner_hid, + args.dropout, + args.attn_dropout, + tie_weight=args.tie_weight, + d_embed=args.d_model, + div_val=args.div_val, + tie_projs=tie_projs, + normalize_before=args.normalize_before, + tgt_len=args.tgt_len, + ext_len=args.ext_len, + mem_len=args.mem_len, + cutoffs=cutoffs, + same_length=args.same_length, + attn_type=args.attn_type, + clamp_len=args.clamp_len, + sample_softmax=args.sample_softmax) + + if args.scheduler == 'cosine': + scheduler = paddle.optimizer.lr.CosineAnnealingDecay( + learning_rate=args.learning_rate, + T_max=args.max_step, + eta_min=args.eta_min) + elif args.scheduler == 'noam': + scheduler = paddle.optimizer.lr.NoamDecay( + d_model=args.d_model, + warmup_steps=args.warmup_steps, + learning_rate=args.learning_rate) + elif args.scheduler == 'dev_perf': + # fluid api + scheduler = paddle.fluid.dygraph.ReduceLROnPlateau( + learning_rate=args.learning_rate, + decay_rate=args.decay_rate, + patience=args.patience, + min_lr=args.lr_min) + elif args.scheduler == 'constant': + scheduler = args.learning_rate + + clip = paddle.nn.ClipGradByGlobalNorm(args.clip) + if args.optim.lower() == 'momentum': + optimizer = paddle.optimizer.Momentum( + learning_rate=scheduler, + parameters=mem_transformer.parameters(), + momentum=args.mom, + grad_clip=clip) + elif args.optim.lower() == 'adam': + optimizer = paddle.optimizer.Adam( + learning_rate=scheduler, + parameters=mem_transformer.parameters(), + beta1=args.beta1, + beta2=args.beta2, + epsilon=eval(args.eps), + grad_clip=clip) + elif args.optim.lower() == 'adagrad': + optimizer = paddle.optimizer.Adagrad( + learning_rate=scheduler, + parameters=mem_transformer.parameters(), + grad_clip=clip) + + # Init from some checkpoint, to resume the previous training + if args.init_from_checkpoint: + model_dict = paddle.load( + os.path.join(args.init_from_checkpoint, "mem_transformer.pdparams")) + opt_dict = paddle.load( + os.path.join(args.init_from_checkpoint, "mem_transformer.pdopt")) + mem_transformer.set_state_dict(model_dict) + optimizer.set_state_dict(opt_dict) + print("loaded from checkpoint.") + # Init from some pretrain models, to better solve the current task + if args.init_from_pretrain_model: + model_dict = paddle.load( + os.path.join(args.init_from_pretrain_model, + "mem_transformer.pdparams")) + mem_transformer.set_state_dict(model_dict) + print("loaded from pre-trained model.") + + if trainer_count > 1: + mem_transformer = paddle.DataParallel(mem_transformer) + + step_idx = 0 + train_loss = 0.0 + + log_start_time = time.time() + + for pass_id in range(args.epoch): + batch_id = 0 + + mems = tuple() + for 
input_data in train_loader: + (src, target, seq_len) = input_data + ret = mem_transformer(src, target, *mems) + loss = ret[0] + mems = ret[1:] + train_loss += loss.numpy() + + loss.backward() + optimizer.step() + optimizer.clear_grad() + + if step_idx > 0 and step_idx % args.print_step == 0 and rank == 0: + cur_loss = train_loss / args.print_step + elapsed = time.time() - log_start_time + if args.scheduler == "constant": + lr = optimizer.get_lr() + else: + lr = scheduler.get_lr() + logger_info = "step_idx: %d, epoch: %d, batch: %d, learning rate: %.8f, " \ + "speed: %f ms/batch, loss: %f" % \ + (step_idx, pass_id, batch_id, lr, + elapsed * 1000.0 / args.print_step, cur_loss) + if args.dataset in ["enwik8", "text8"]: + logger_info = logger_info + ", bpc: %f" % (cur_loss / + np.log(2)) + else: + logger_info = logger_info + ", ppl: %f" % (np.exp(cur_loss)) + + logger.info(logger_info) + train_loss = 0.0 + log_start_time = time.time() + + if step_idx % args.save_step == 0 and step_idx != 0: + # Do validation. + mem_transformer.eval() + + # TODO(FrostML): simplify this. + if args.mem_len == 0: + if dist.get_world_size() == 1: + mem_transformer.reset_length( + tgt_len=args.eval_tgt_len, + ext_len=args.ext_len + args.tgt_len - + args.eval_tgt_len, + mem_len=args.mem_len) + else: + mem_transformer._layers.reset_length( + tgt_len=args.eval_tgt_len, + ext_len=args.ext_len + args.tgt_len - + args.eval_tgt_len, + mem_len=args.mem_len) + else: + if dist.get_world_size() == 1: + mem_transformer.reset_length( + tgt_len=args.eval_tgt_len, + ext_len=args.ext_len, + mem_len=args.mem_len + args.tgt_len - + args.eval_tgt_len) + else: + mem_transformer._layers.reset_length( + tgt_len=args.eval_tgt_len, + ext_len=args.ext_len, + mem_len=args.mem_len + args.tgt_len - + args.eval_tgt_len) + + total_len, total_loss = 0, 0. + + eval_mems = tuple() + with paddle.no_grad(): + for i, (src, target, seq_len) in enumerate(eval_loader): + if args.max_eval_steps > 0 and i >= args.max_eval_steps: + break + ret = mem_transformer(src, target, *eval_mems) + loss, eval_mems = ret[0], ret[1:] + seq_len = seq_len.numpy() + eval_cur_loss = seq_len * loss.numpy() + total_loss += eval_cur_loss + total_len += seq_len + eval_loss = total_loss / total_len + + logger_info = "Validation, step_idx: %d, validation loss: %f" % \ + (step_idx, eval_loss) + if args.dataset in ['enwik8', 'text8']: + logger_info = logger_info + ", bpc: %f" % (eval_loss / + np.log(2)) + else: + logger_info = logger_info + ", ppl: %f" % (np.exp(eval_loss) + ) + logger.info(logger_info) + + if args.save_model and rank == 0: + model_dir = os.path.join(args.save_model, + "step_" + str(step_idx)) + if not os.path.exists(model_dir): + os.makedirs(model_dir) + paddle.save( + mem_transformer.state_dict(), + os.path.join(model_dir, "mem_transformer.pdparams")) + paddle.save( + optimizer.state_dict(), + os.path.join(model_dir, "mem_transformer.pdopt")) + + if args.scheduler == 'dev_perf': + scheduler.step(eval_loss) + + # TODO(FrostML): simplify this. 
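+                # Restore the training-time lengths. The eval-time
+                # reset_length above shifted ext_len (when mem_len == 0) or
+                # mem_len (otherwise) so that tgt_len + ext_len + mem_len
+                # stayed constant while decoding eval_tgt_len tokens per
+                # step; switch back before resuming training.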
+ if dist.get_world_size() == 1: + mem_transformer.reset_length( + tgt_len=args.tgt_len, + ext_len=args.ext_len, + mem_len=args.mem_len) + else: + mem_transformer._layers.reset_length( + tgt_len=args.tgt_len, + ext_len=args.ext_len, + mem_len=args.mem_len) + + mem_transformer.train() + + step_idx += 1 + batch_id += 1 + if args.scheduler in ['cosine', 'dev_perf']: + if step_idx < args.warmup_steps: + curr_lr = args.learning_rate * step_idx / args.warmup_steps + scheduler.base_lr = curr_lr + else: + if args.scheduler == 'cosine': + scheduler.step() + elif args.scheduler == 'constant': + if step_idx < args.warmup_steps: + curr_lr = args.learning_rate * step_idx / args.warmup_steps + optimizer.set_lr(curr_lr) + elif args.scheduler == 'noam': + scheduler.step() + if step_idx >= args.max_step: + break + + if args.save_model and rank == 0: + model_dir = os.path.join(args.save_model, "step_final") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + paddle.save(mem_transformer.state_dict(), + os.path.join(model_dir, "mem_transformer.pdparams")) + paddle.save(optimizer.state_dict(), + os.path.join(model_dir, "mem_transformer.pdopt")) + + +if __name__ == "__main__": + ARGS = parse_args() + yaml_file = ARGS.config + with open(yaml_file, 'rt') as f: + args = AttrDict(yaml.safe_load(f)) + pprint(args) + + do_train(args) diff --git a/PaddleNLP/examples/language_model/transformer-xl/utils/preprocess_text8.py b/PaddleNLP/examples/language_model/transformer-xl/utils/preprocess_text8.py new file mode 100644 index 00000000..b6b4d7df --- /dev/null +++ b/PaddleNLP/examples/language_model/transformer-xl/utils/preprocess_text8.py @@ -0,0 +1,21 @@ +import sys +import zipfile +import argparse + +if __name__ == "__main__": + data = zipfile.ZipFile("text8.zip").extractall() + data = open("text8", "r", encoding="utf-8").read() + + num_test_char = int(sys.argv[1]) + + train_data = data[:-2 * num_test_char] + valid_data = data[-2 * num_test_char:-num_test_char] + test_data = data[-num_test_char:] + + for files, data in [("train.txt", train_data), ("valid.txt", valid_data), + ("test.txt", test_data)]: + data_str = " ".join(["_" if c == " " else c for c in data.strip()]) + with open(files, "w") as f: + f.write(data_str) + with open(files + ".raw", "w", encoding="utf-8") as fw: + fw.write(data) -- GitLab
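A note on the metrics logged above: train.py reports the running cross-entropy loss and, for the character-level corpora (enwik8, text8), the derived bits-per-character `loss / ln 2`; for word-level data such as wt103 it reports perplexity `exp(loss)`. Below is a minimal standalone sketch of that conversion; the function name and the sample loss values are illustrative only and are not part of the patch.

```python
import numpy as np


def loss_to_metric(cur_loss, dataset):
    """Convert a mean cross-entropy loss (natural log base) into the
    metric train.py logs: bpc for character-level data, ppl otherwise."""
    if dataset in ("enwik8", "text8"):
        return "bpc", cur_loss / np.log(2)
    return "ppl", float(np.exp(cur_loss))


if __name__ == "__main__":
    # Illustrative loss values, not measured results.
    print(loss_to_metric(0.75, "enwik8"))  # ('bpc', ~1.08)
    print(loss_to_metric(3.20, "wt103"))   # ('ppl', ~24.5)
```

The same conversion is applied to the averaged validation loss before it is appended to the "Validation, step_idx: ..." log line.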