From 33df1dc6832667ad333b76f354818013c002fe10 Mon Sep 17 00:00:00 2001 From: zhanghan17 Date: Thu, 20 May 2021 01:11:38 +0800 Subject: [PATCH] ernie-gram release demos --- ernie/modeling_ernie.py | 54 +- ernie/tokenizing_ernie.py | 2 + .../.meta/ernie-gram.jpeg | Bin {ernie-gram => ernie_gram}/README.en.md | 0 {ernie-gram => ernie_gram}/README.md | 0 {ernie-gram => ernie_gram}/README.zh.md | 0 ernie_gram/__init__.py | 0 ernie_gram/finetune_classifier_distributed.py | 239 +++++++ ernie_gram/finetune_mrc.py | 250 ++++++++ ernie_gram/finetune_ner.py | 261 ++++++++ ernie_gram/mrc/__init__.py | 0 ernie_gram/mrc/mrc_metrics.py | 602 ++++++++++++++++++ ernie_gram/mrc/mrc_reader.py | 303 +++++++++ ernie_gram/optimization.py | 136 ++++ ernie_gram/run_cls.sh | 13 + ernie_gram/run_mrc.sh | 9 + ernie_gram/run_ner.sh | 9 + ernie_gram/task_configs/cmrc_conf | 5 + ernie_gram/task_configs/msra_ner_conf | 5 + ernie_gram/task_configs/xnli_conf | 9 + ernie_gram/utils.py | 47 ++ 21 files changed, 1941 insertions(+), 3 deletions(-) rename {ernie-gram => ernie_gram}/.meta/ernie-gram.jpeg (100%) rename {ernie-gram => ernie_gram}/README.en.md (100%) rename {ernie-gram => ernie_gram}/README.md (100%) rename {ernie-gram => ernie_gram}/README.zh.md (100%) create mode 100644 ernie_gram/__init__.py create mode 100644 ernie_gram/finetune_classifier_distributed.py create mode 100644 ernie_gram/finetune_mrc.py create mode 100644 ernie_gram/finetune_ner.py create mode 100644 ernie_gram/mrc/__init__.py create mode 100644 ernie_gram/mrc/mrc_metrics.py create mode 100644 ernie_gram/mrc/mrc_reader.py create mode 100644 ernie_gram/optimization.py create mode 100644 ernie_gram/run_cls.sh create mode 100644 ernie_gram/run_mrc.sh create mode 100644 ernie_gram/run_ner.sh create mode 100644 ernie_gram/task_configs/cmrc_conf create mode 100644 ernie_gram/task_configs/msra_ner_conf create mode 100644 ernie_gram/task_configs/xnli_conf create mode 100644 ernie_gram/utils.py diff --git a/ernie/modeling_ernie.py b/ernie/modeling_ernie.py index 2d0637c..65e76b5 100644 --- a/ernie/modeling_ernie.py +++ b/ernie/modeling_ernie.py @@ -19,12 +19,13 @@ from __future__ import unicode_literals import json import logging +import math import six if six.PY2: from pathlib2 import Path else: from pathlib import Path - +import numpy as np import paddle as P from paddle import nn from paddle.nn import functional as F @@ -36,7 +37,35 @@ ACT_DICT = { 'relu': nn.ReLU, 'gelu': nn.GELU, } +def _get_rel_pos_bias(seq_len, max_len=128, num_buckets=32, bidirectional=True, reset=True): + #max_len = 520 + pos = np.array(range(seq_len)) + rel_pos = pos[:, None] - pos[None, :] + ret = 0 + n = -rel_pos + if bidirectional: + num_buckets //= 2 + ret += (n < 0).astype('int32') * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets + n = np.abs(n) + else: + n = np.max(n, np.zeros_like(n)) + # now n is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = n < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + val_if_large = max_exact + (np.log(n.astype('float32') / max_exact) / math.log(max_len / max_exact) * (num_buckets - max_exact)).astype('int32') + tmp = np.full_like(val_if_large, num_buckets-1) + val_if_large = np.where(val_if_large < tmp, val_if_large, tmp) + + ret += np.where(is_small, n, val_if_large) + if reset: + num_buckets *= 2 + ret[:, 0] = num_buckets + ret[0, :] = num_buckets // 2 + return np.array(ret).reshape([seq_len, seq_len]).astype("int64") def _build_linear(n_in, n_out, name, init): return nn.Linear( @@ -223,6 +252,8 @@ class PretrainedModel(object): 'ernie-2.0-en': bce + 'model-ernie2.0-en.1.tar.gz', 'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz', 'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz', + 'ernie-gram-zh': bce + 'model-ernie-gram-zh.1.tar.gz', + 'ernie-gram-en': bce + 'model-ernie-gram-en.1.tar.gz', } @classmethod @@ -283,10 +314,14 @@ class ErnieModel(nn.Layer, PretrainedModel): d_vocab = cfg['vocab_size'] d_pos = cfg['max_position_embeddings'] d_sent = cfg.get("sent_type_vocab_size") or cfg['type_vocab_size'] + self.d_rel_pos = cfg.get('rel_pos_size', None) + max_seq_len = cfg.get("max_seq_len", 512) self.n_head = cfg['num_attention_heads'] self.return_additional_info = cfg.get('return_additional_info', False) initializer = nn.initializer.TruncatedNormal( std=cfg['initializer_range']) + if self.d_rel_pos: + self.rel_pos_bias = _get_rel_pos_bias(max_seq_len) self.ln = _build_ln(d_model, name=append_name(name, 'pre_encoder')) self.word_emb = nn.Embedding( @@ -307,6 +342,13 @@ class ErnieModel(nn.Layer, PretrainedModel): weight_attr=P.ParamAttr( name=append_name(name, 'sent_embedding'), initializer=initializer)) + if self.d_rel_pos: + self.rel_pos_bias_emb = nn.Embedding( + self.d_rel_pos, + self.n_head, + weight_attr=P.ParamAttr( + name=append_name(name, 'rel_pos_embedding'), + initializer=initializer)) prob = cfg['hidden_dropout_prob'] self.dropout = nn.Dropout(p=prob) @@ -347,6 +389,7 @@ class ErnieModel(nn.Layer, PretrainedModel): attn_bias=None, past_cache=None, use_causal_mask=False): + """ Args: src_ids (`Variable` of shape `[batch_size, seq_len]`): @@ -402,15 +445,20 @@ class ErnieModel(nn.Layer, PretrainedModel): attn_bias = (1. - attn_bias) * -10000.0 attn_bias = attn_bias.unsqueeze(1).tile( [1, self.n_head, 1, 1]) # avoid broadcast =_= - + attn_bias.stop_gradient=True if sent_ids is None: sent_ids = P.zeros_like(src_ids) - + if self.d_rel_pos: + rel_pos_ids = self.rel_pos_bias[:d_seqlen, :d_seqlen] + rel_pos_ids = P.to_tensor(rel_pos_ids, dtype='int64') + rel_pos_bias = self.rel_pos_bias_emb(rel_pos_ids).transpose([2, 0, 1]) + attn_bias += rel_pos_bias src_embedded = self.word_emb(src_ids) pos_embedded = self.pos_emb(pos_ids) sent_embedded = self.sent_emb(sent_ids) embedded = src_embedded + pos_embedded + sent_embedded + embedded = self.dropout(self.ln(embedded)) encoded, hidden_list, cache_list = self.encoder_stack( diff --git a/ernie/tokenizing_ernie.py b/ernie/tokenizing_ernie.py index b4cc247..7b866d8 100644 --- a/ernie/tokenizing_ernie.py +++ b/ernie/tokenizing_ernie.py @@ -87,6 +87,8 @@ class ErnieTokenizer(object): 'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz', 'ernie-gen-base-en': bce + 'model-ernie-gen-base-en.1.tar.gz', 'ernie-gen-large-en': bce + 'model-ernie-gen-large-en.1.tar.gz', + 'ernie-gram-zh': bce + 'model-ernie-gram-zh.1.tar.gz', + 'ernie-gram-en': bce + 'model-ernie-gram-en.1.tar.gz', } @classmethod diff --git a/ernie-gram/.meta/ernie-gram.jpeg b/ernie_gram/.meta/ernie-gram.jpeg similarity index 100% rename from ernie-gram/.meta/ernie-gram.jpeg rename to ernie_gram/.meta/ernie-gram.jpeg diff --git a/ernie-gram/README.en.md b/ernie_gram/README.en.md similarity index 100% rename from ernie-gram/README.en.md rename to ernie_gram/README.en.md diff --git a/ernie-gram/README.md b/ernie_gram/README.md similarity index 100% rename from ernie-gram/README.md rename to ernie_gram/README.md diff --git a/ernie-gram/README.zh.md b/ernie_gram/README.zh.md similarity index 100% rename from ernie-gram/README.zh.md rename to ernie_gram/README.zh.md diff --git a/ernie_gram/__init__.py b/ernie_gram/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ernie_gram/finetune_classifier_distributed.py b/ernie_gram/finetune_classifier_distributed.py new file mode 100644 index 0000000..4ddd7c5 --- /dev/null +++ b/ernie_gram/finetune_classifier_distributed.py @@ -0,0 +1,239 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import logging +import json +import re +from random import random +from functools import reduce, partial + +import numpy as np +import logging +#from visualdl import LogWriter + +from pathlib import Path +import paddle as P +from propeller import log +import propeller.paddle as propeller + +#from model.bert import BertConfig, BertModelLayer +from ernie.modeling_ernie import ErnieModel, ErnieModelForSequenceClassification +from ernie.tokenizing_ernie import ErnieTokenizer, ErnieTinyTokenizer +from ernie_gram.optimization import AdamW +from ernie_gram.utils import create_if_not_exists, get_warmup_and_linear_decay + +log.setLevel(logging.DEBUG) +logging.getLogger().setLevel(logging.DEBUG) + +parser = propeller.ArgumentParser('classify model with ERNIE') +parser.add_argument( + '--from_pretrained', + type=Path, + required=True, + help='pretrained model directory or tag') +parser.add_argument( + '--max_seqlen', + type=int, + default=128, + help='max sentence length, should not greater than 512') +parser.add_argument('--bsz', type=int, default=32, help='batchsize') +parser.add_argument( + '--data_dir', + type=str, + required=True, + help='data directory includes train / develop data') +parser.add_argument( + '--max_steps', + type=int, + required=True, + help='max_train_steps, set this to EPOCH * NUM_SAMPLES / BATCH_SIZE') +parser.add_argument('--warmup_proportion', type=float, default=0.1) +parser.add_argument('--lr', type=float, default=5e-5, help='learning rate') +parser.add_argument('--lr_decay', type=float, default=0.8, help='layerwise learning decay rate') +parser.add_argument('--decay_layers', type=float, default=12, help='number of layers for layerwise learning decay') +parser.add_argument('--label_map', type=str, default="", help='str to int') +parser.add_argument('--num_labels', type=int, default=2, help='number of labels') +parser.add_argument('--valid_steps', type=int, default=100, help='The steps interval to evaluate model performance.') +parser.add_argument('--pair_input', type=int, default=0, help='is sentence pair task or not') + +parser.add_argument( + '--save_dir', type=Path, required=True, help='model output directory') +parser.add_argument( + '--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer') +parser.add_argument( + '--init_checkpoint', + type=str, + default=None, + help='checkpoint to warm start from') + +parser.add_argument( + '--use_amp', + action='store_true', + help='only activate AMP(auto mixed precision accelatoin) on TensorCore compatible devices' +) + +args = parser.parse_args() +env = P.distributed.ParallelEnv() + +tokenizer = ErnieTokenizer.from_pretrained(args.from_pretrained) +#tokenizer = ErnieTinyTokenizer.from_pretrained(args.from_pretrained) + +if args.label_map: + label_map = {k.encode(): v for k, v in json.loads(args.label_map).items()} +else: + label_map = {str(l).encode(): l for l in range(args.num_labels)} + +text_col_names = ["seg_a", "seg_b"] if args.pair_input else ["seg_a"] +feature_column = propeller.data.FeatureColumns([ + propeller.data.TextColumn( + col_name, + unk_id=tokenizer.unk_id, + vocab_dict=tokenizer.vocab, + tokenizer=tokenizer.tokenize) for col_name in text_col_names] + [ + propeller.data.LabelColumn( + 'label', + vocab_dict=label_map), +]) + + +def map_fn_pair(seg_a, seg_b, label): + seg_a, seg_b = tokenizer.truncate(seg_a, seg_b, seqlen=args.max_seqlen) + sentence, segments = tokenizer.build_for_ernie(seg_a, seg_b) + return sentence, segments, label + +def map_fn_single(seg_a, label): + seg_a, _ = tokenizer.truncate(seg_a, [], seqlen=args.max_seqlen) + sentence, segments = tokenizer.build_for_ernie(seg_a, []) + return sentence, segments, label + +map_fn = map_fn_pair if args.pair_input else map_fn_single + +train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), + shuffle=True, repeat=True, use_gz=False, shard=True) \ + .map(map_fn) \ + .padded_batch(args.bsz, (0, 0, 0)) + +dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), + shuffle=False, repeat=False, use_gz=False) \ + .map(map_fn) \ + .padded_batch(args.bsz, (0, 0, 0)) +test_ds = feature_column.build_dataset('test', data_dir=os.path.join(args.data_dir, 'test'), + shuffle=False, repeat=False, use_gz=False) \ + .map(map_fn) \ + .padded_batch(args.bsz, (0, 0, 0)) + +shapes = ([-1, args.max_seqlen], [-1, args.max_seqlen], [-1]) +types = ('int64', 'int64', 'int64') + +P.distributed.init_parallel_env() +model = ErnieModelForSequenceClassification.from_pretrained( + args.from_pretrained, num_labels=args.num_labels, name='') + +if args.init_checkpoint is not None: + log.info('loading checkpoint from %s' % args.init_checkpoint) + sd = P.load(args.init_checkpoint) + model.set_state_dict(sd) + +model = P.DataParallel(model) + +g_clip = P.nn.ClipGradByGlobalNorm(1.0) #experimental +param_name_to_exclue_from_weight_decay = re.compile( + r'.*layer_norm_scale|.*layer_norm_bias|.*b_0') + +lr_scheduler = P.optimizer.lr.LambdaDecay( + args.lr, + get_warmup_and_linear_decay(args.max_steps, + int(args.warmup_proportion * args.max_steps))) + +opt = AdamW( + learning_rate=lr_scheduler, + parameters=model.parameters(), + apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n), + weight_decay=args.wd, + grad_clip=g_clip, + layerwise_lr_decay_rate=args.lr_decay, + n_layers=args.decay_layers) +scaler = P.amp.GradScaler(enable=args.use_amp) +step = 0 +create_if_not_exists(args.save_dir) + +#with LogWriter(logdir=str(create_if_not_exists(args.save_dir / 'vdl-%d' % env.dev_id))) as log_writer: +with P.amp.auto_cast(enable=args.use_amp): + for ids, sids, label in P.io.DataLoader( + train_ds, places=P.CUDAPlace(env.dev_id), batch_size=None): + step += 1 + loss, _ = model(ids, sids, labels=label) + loss = scaler.scale(loss) + loss.backward() + scaler.minimize(opt, loss) + model.clear_gradients() + lr_scheduler.step() + + # do logging + if step % 10 == 0: + _lr = lr_scheduler.get_lr() + if args.use_amp: + _l = (loss / scaler._scale).numpy() + msg = '[rank-%d][step-%d] train loss %.5f lr %.3e scaling %.3e' % ( + env.dev_id, step, _l, _lr, scaler._scale.numpy()) + else: + _l = loss.numpy() + msg = '[rank-%d][step-%d] train loss %.5f lr %.3e' % ( + env.dev_id, step, _l, _lr) + log.debug(msg) + #log_writer.add_scalar('loss', _l, step=step) + #log_writer.add_scalar('lr', _lr, step=step) + + # do saving + if step % args.valid_steps == 0 and env.dev_id == 0: + acc = [] + with P.no_grad(): + model.eval() + for d in P.io.DataLoader( + dev_ds, places=P.CUDAPlace(env.dev_id), + batch_size=None): + ids, sids, label = d + loss, logits = model(ids, sids, labels=label) + a = (logits.argmax(-1) == label) + acc.append(a.numpy()) + model.train() + acc = np.concatenate(acc).mean() + #log_writer.add_scalar('eval/acc', acc, step=step) + log.debug('dev acc %.5f' % acc) + acc = [] + with P.no_grad(): + model.eval() + for d in P.io.DataLoader( + test_ds, places=P.CUDAPlace(env.dev_id), + batch_size=None): + ids, sids, label = d + loss, logits = model(ids, sids, labels=label) + a = (logits.argmax(-1) == label) + acc.append(a.numpy()) + model.train() + acc = np.concatenate(acc).mean() + #log_writer.add_scalar('eval/acc', acc, step=step) + log.debug('test acc %.5f' % acc) + + if args.save_dir is not None: + P.save(model.state_dict(), args.save_dir / 'ckpt.bin') + # exit + if step > args.max_steps: + break + +if args.save_dir is not None and env.dev_id == 0: + P.save(model.state_dict(), args.save_dir / 'ckpt.bin') +log.debug('done') diff --git a/ernie_gram/finetune_mrc.py b/ernie_gram/finetune_mrc.py new file mode 100644 index 0000000..9d23697 --- /dev/null +++ b/ernie_gram/finetune_mrc.py @@ -0,0 +1,250 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +from __future__ import unicode_literals + +import os +import re +import time +import logging +import json +from pathlib import Path +from random import random +from tqdm import tqdm +from functools import reduce, partial +import pickle +import argparse +from functools import partial +from io import open + +import numpy as np +import logging + +import paddle as P + +from propeller import log +import propeller.paddle as propeller +from ernie_gram.optimization import AdamW + +from ernie.modeling_ernie import ErnieModel, ErnieModelForQuestionAnswering +from ernie.tokenizing_ernie import ErnieTokenizer, ErnieTinyTokenizer +#from ernie.optimization import AdamW, LinearDecay + +from ernie_gram.mrc import mrc_reader +from ernie_gram.mrc import mrc_metrics +from ernie_gram.utils import create_if_not_exists, get_warmup_and_linear_decay + +log.setLevel(logging.DEBUG) +logging.getLogger().setLevel(logging.DEBUG) + + +def evaluate(model, ds, all_examples, all_features, tokenizer, args): + dev_file = json.loads(open(args.dev_file, encoding='utf8').read()) + with P.no_grad(): + log.debug('start eval') + model.eval() + all_res = [] + for step, (uids, token_ids, token_type_ids, _, __) in enumerate( + P.io.DataLoader( + ds, places=P.CUDAPlace(env.dev_id), batch_size=None)): + _, start_logits, end_logits = model(token_ids, token_type_ids) + res = [ + mrc_metrics.RawResult( + unique_id=u, start_logits=s, end_logits=e) + for u, s, e in zip(uids.numpy(), + start_logits.numpy(), end_logits.numpy()) + ] + all_res += res + open('all_res', 'wb').write(pickle.dumps(all_res)) + all_pred, all_nbests = mrc_metrics.make_results( + tokenizer, + all_examples, + all_features, + all_res, + n_best_size=args.n_best_size, + max_answer_length=args.max_answer_length, + do_lower_case=tokenizer.lower) + f1, em, _, __ = mrc_metrics.evaluate(dev_file, all_pred) + model.train() + log.debug('done eval') + return f1, em + + +def train(model, train_dataset, dev_dataset, dev_examples, dev_features, + tokenizer, args): + model = P.DataParallel(model) + + max_steps = args.max_steps + + + g_clip = P.nn.ClipGradByGlobalNorm(1.0) #experimental + lr_scheduler = P.optimizer.lr.LambdaDecay( + args.lr, + get_warmup_and_linear_decay(max_steps, + int(args.warmup_proportion * max_steps))) + + opt = AdamW( + lr_scheduler, + parameters=model.parameters(), + weight_decay=args.wd, + grad_clip=g_clip) + + train_dataset = train_dataset \ + .cache_shuffle_shard(env.nranks, env.dev_id, drop_last=True) \ + .padded_batch(args.bsz) + + log.debug('init training with args: %s' % repr(args)) + scaler = P.amp.GradScaler(enable=args.use_amp) + create_if_not_exists(args.save_dir) + + with P.amp.auto_cast(enable=args.use_amp): + for step, (_, token_ids, token_type_ids, start_pos, + end_pos) in enumerate( + P.io.DataLoader( + train_dataset, + places=P.CUDAPlace(env.dev_id), + batch_size=None)): + loss, _, __ = model( + token_ids, + token_type_ids, + start_pos=start_pos, + end_pos=end_pos) + loss = scaler.scale(loss) + loss.backward() + scaler.minimize(opt, loss) + model.clear_gradients() + lr_scheduler.step() + + if env.dev_id == 0 and step % 10==0 and step: + _lr = lr_scheduler.get_lr() + if args.use_amp: + _l = (loss / scaler._scale).numpy() + msg = '[rank-%d][step-%d] train loss %.5f lr %.3e scaling %.3e' % ( + env.dev_id, step, _l, _lr, scaler._scale.numpy()) + else: + _l = loss.numpy() + msg = '[rank-%d][step-%d] train loss %.5f lr %.3e' % ( + env.dev_id, step, _l, _lr) + log.debug(msg) + + if env.dev_id == 0 and step % 100==0 and step: + print(step) + f1, em = evaluate(model, dev_dataset, dev_examples, + dev_features, tokenizer, args) + log.debug('[step %d] eval result: f1 %.5f em %.5f' % + (step, f1, em)) + if env.dev_id == 0 and args.save_dir is not None: + P.save(model.state_dict(), args.save_dir / 'ckpt.bin') + if step > max_steps: + break + + +if __name__ == "__main__": + parser = argparse.ArgumentParser('MRC model with ERNIE') + parser.add_argument( + '--from_pretrained', + type=Path, + required=True, + help='pretrained model directory or tag') + parser.add_argument( + '--max_seqlen', + type=int, + default=512, + help='max sentence length, should not greater than 512') + parser.add_argument('--bsz', type=int, default=16, help='batchsize') + parser.add_argument('--max_steps', type=int, required=True, help='max steps') + parser.add_argument( + '--train_file', + type=str, + required=True, + help='data directory includes train / develop data') + parser.add_argument( + '--dev_file', + type=str, + required=True, + help='data directory includes train / develop data') + parser.add_argument('--warmup_proportion', type=float, default=0.0) + parser.add_argument('--lr', type=float, default=3e-5, help='learning rate') + parser.add_argument( + '--save_dir', type=Path, required=True, help='model output directory') + parser.add_argument( + '--n_best_size', type=int, default=20, help='nbest prediction to keep') + parser.add_argument( + '--max_answer_length', type=int, default=100, help='max answer span') + parser.add_argument( + '--wd', + type=float, + default=0.01, + help='weight decay, aka L2 regularizer') + parser.add_argument( + '--use_amp', + action='store_true', + help='only activate AMP(auto mixed precision accelatoin) on TensorCore compatible devices' + ) + + args = parser.parse_args() + + env = P.distributed.ParallelEnv() + P.distributed.init_parallel_env() + + tokenizer = ErnieTokenizer.from_pretrained(args.from_pretrained) + + if not os.path.exists(args.train_file): + raise RuntimeError('input data not found at %s' % args.train_file) + if not os.path.exists(args.dev_file): + raise RuntimeError('input data not found at %s' % args.dev_file) + + log.info('making train/dev data...') + train_examples = mrc_reader.read_files(args.train_file, is_training=True) + train_features = mrc_reader.convert_example_to_features( + train_examples, args.max_seqlen, tokenizer, is_training=True) + + dev_examples = mrc_reader.read_files(args.dev_file, is_training=False) + dev_features = mrc_reader.convert_example_to_features( + dev_examples, args.max_seqlen, tokenizer, is_training=False) + + log.info('train examples: %d, features: %d' % + (len(train_examples), len(train_features))) + + def map_fn(unique_id, example_index, doc_span_index, tokens, + token_to_orig_map, token_is_max_context, token_ids, + position_ids, text_type_ids, start_position, end_position): + if start_position is None: + start_position = 0 + if end_position is None: + end_position = 0 + return np.array(unique_id), np.array(token_ids), np.array( + text_type_ids), np.array(start_position), np.array(end_position) + + train_dataset = propeller.data.Dataset.from_list(train_features).map( + map_fn) + + dev_dataset = propeller.data.Dataset.from_list(dev_features).map( + map_fn).padded_batch(args.bsz) + + model = ErnieModelForQuestionAnswering.from_pretrained( + args.from_pretrained, name='') + + train(model, train_dataset, dev_dataset, dev_examples, dev_features, + tokenizer, args) + + if env.dev_id == 0: + f1, em = evaluate(model, dev_dataset, dev_examples, dev_features, + tokenizer, args) + log.debug('final eval result: f1 %.5f em %.5f' % (f1, em)) + if env.dev_id == 0 and args.save_dir is not None: + P.save(model.state_dict(), args.save_dir / 'ckpt.bin') diff --git a/ernie_gram/finetune_ner.py b/ernie_gram/finetune_ner.py new file mode 100644 index 0000000..a59a815 --- /dev/null +++ b/ernie_gram/finetune_ner.py @@ -0,0 +1,261 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import time +import logging +import six +import json +from random import random +from tqdm import tqdm +from collections import OrderedDict +from functools import reduce, partial +from pathlib import Path +from visualdl import LogWriter + +import numpy as np +import multiprocessing +import pickle +import logging + +from sklearn.metrics import f1_score +import paddle as P + +from propeller import log +import propeller.paddle as propeller + +log.setLevel(logging.DEBUG) +logging.getLogger().setLevel(logging.DEBUG) + +from ernie_gram.utils import create_if_not_exists, get_warmup_and_linear_decay +from ernie.modeling_ernie import ErnieModel, ErnieModelForSequenceClassification, ErnieModelForTokenClassification +from ernie.tokenizing_ernie import ErnieTokenizer +from ernie_gram.optimization import AdamW + +parser = propeller.ArgumentParser('NER model with ERNIE') +parser.add_argument('--max_seqlen', type=int, default=256) +parser.add_argument('--bsz', type=int, default=16) +parser.add_argument('--data_dir', type=str, required=True) +parser.add_argument('--epoch', type=int, default=10) +parser.add_argument( + '--warmup_proportion', + type=float, + default=0.1, + help='if use_lr_decay is set, ' + 'learning rate will raise to `lr` at `warmup_proportion` * `max_steps` and decay to 0. at `max_steps`' +) +parser.add_argument( + '--max_steps', + type=int, + required=True, + help='max_train_steps, set this to EPOCH * NUM_SAMPLES / BATCH_SIZE, used in learning rate scheduler' +) +parser.add_argument( + '--use_amp', + action='store_true', + help='only activate AMP(auto mixed precision accelatoin) on TensorCore compatible devices' +) + +parser.add_argument('--from_pretrained', type=Path, required=True) +parser.add_argument('--lr', type=float, default=5e-5, help='learning rate') +parser.add_argument( + '--save_dir', type=Path, required=True, help='model output directory') +parser.add_argument( + '--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer') +args = parser.parse_args() + +tokenizer = ErnieTokenizer.from_pretrained(args.from_pretrained) + + +def tokenizer_func(inputs): + ret = inputs.split(b'\2') + tokens, orig_pos = [], [] + for i, r in enumerate(ret): + t = tokenizer.tokenize(r) + for tt in t: + tokens.append(tt) + orig_pos.append(i) + assert len(tokens) == len(orig_pos) + return tokens + orig_pos + + +def tokenizer_func_for_label(inputs): + return inputs.split(b'\2') + + +feature_map = { + b"B-PER": 0, + b"I-PER": 1, + b"B-ORG": 2, + b"I-ORG": 3, + b"B-LOC": 4, + b"I-LOC": 5, + b"O": 6, +} +other_tag_id = feature_map[b'O'] + +feature_column = propeller.data.FeatureColumns([ + propeller.data.TextColumn( + 'text_a', + unk_id=tokenizer.unk_id, + vocab_dict=tokenizer.vocab, + tokenizer=tokenizer_func), propeller.data.TextColumn( + 'label', + unk_id=other_tag_id, + vocab_dict=feature_map, + tokenizer=tokenizer_func_for_label, ) +]) + + +def before(seg, label): + seg, orig_pos = np.split(seg, 2) + aligned_label = label[orig_pos] + seg, _ = tokenizer.truncate(seg, [], args.max_seqlen) + aligned_label, _ = tokenizer.truncate(aligned_label, [], args.max_seqlen) + orig_pos, _ = tokenizer.truncate(orig_pos, [], args.max_seqlen) + + sentence, segments = tokenizer.build_for_ernie( + seg + ) #utils.data.build_1_pair(seg, max_seqlen=args.max_seqlen, cls_id=cls_id, sep_id=sep_id) + aligned_label = np.concatenate([[0], aligned_label, [0]], 0) + orig_pos = np.concatenate([[0], orig_pos, [0]]) + + assert len(aligned_label) == len(sentence) == len(orig_pos), ( + len(aligned_label), len(sentence), len(orig_pos)) # alinged + return sentence, segments, aligned_label, label, orig_pos + +train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=False, use_gz=False) \ + .map(before) \ + .padded_batch(args.bsz, (0,0,-100, other_tag_id + 1, 0)) \ + +dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \ + .map(before) \ + .padded_batch(args.bsz, (0,0,-100, other_tag_id + 1,0)) \ + +test_ds = feature_column.build_dataset('test', data_dir=os.path.join(args.data_dir, 'test'), shuffle=False, repeat=False, use_gz=False) \ + .map(before) \ + .padded_batch(args.bsz, (0,0,-100, other_tag_id + 1,0)) \ + + +def evaluate(model, dataset): + model.eval() + with P.no_grad(): + chunkf1 = propeller.metrics.ChunkF1(None, None, None, len(feature_map)) + for step, (ids, sids, aligned_label, label, orig_pos + ) in enumerate(P.io.DataLoader( + dataset, batch_size=None)): + loss, logits = model(ids, sids) + #print('\n'.join(map(str, logits.numpy().tolist()))) + + assert orig_pos.shape[0] == logits.shape[0] == ids.shape[ + 0] == label.shape[0] + for pos, lo, la, id in zip(orig_pos.numpy(), + logits.numpy(), + label.numpy(), ids.numpy()): + _dic = OrderedDict() + assert len(pos) == len(lo) == len(id) + for _pos, _lo, _id in zip(pos, lo, id): + if _id > tokenizer.mask_id: # [MASK] is the largest special token + _dic.setdefault(_pos, []).append(_lo) + merged_lo = np.array( + [np.array(l).mean(0) for _, l in six.iteritems(_dic)]) + merged_preds = np.argmax(merged_lo, -1) + la = la[np.where(la != (other_tag_id + 1))] #remove pad + if len(la) > len(merged_preds): + log.warn( + 'accuracy loss due to truncation: label len:%d, truncate to %d' + % (len(la), len(merged_preds))) + merged_preds = np.pad(merged_preds, + [0, len(la) - len(merged_preds)], + mode='constant', + constant_values=7) + else: + assert len(la) == len( + merged_preds + ), 'expect label == prediction, got %d vs %d' % ( + la.shape, merged_preds.shape) + chunkf1.update((merged_preds, la, np.array(len(la)))) + #f1 = f1_score(np.concatenate(all_label), np.concatenate(all_pred), average='macro') + f1 = chunkf1.eval() + model.train() + return f1 + + +model = ErnieModelForTokenClassification.from_pretrained( + args.from_pretrained, + num_labels=len(feature_map), + name='', + has_pooler=False) + +g_clip = P.nn.ClipGradByGlobalNorm(1.0) #experimental +param_name_to_exclue_from_weight_decay = re.compile( + r'.*layer_norm_scale|.*layer_norm_bias|.*b_0') +lr_scheduler = P.optimizer.lr.LambdaDecay( + args.lr, + get_warmup_and_linear_decay(args.max_steps, + int(args.warmup_proportion * args.max_steps))) +opt = AdamW( + lr_scheduler, + parameters=model.parameters(), + weight_decay=args.wd, + apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n), + grad_clip=g_clip) + +scaler = P.amp.GradScaler(enable=args.use_amp) +with LogWriter( + logdir=str(create_if_not_exists(args.save_dir / 'vdl'))) as log_writer: + with P.amp.auto_cast(enable=args.use_amp): + for epoch in range(args.epoch): + for step, ( + ids, sids, aligned_label, label, orig_pos + ) in enumerate(P.io.DataLoader( + train_ds, batch_size=None)): + loss, logits = model(ids, sids, labels=aligned_label) + #loss, logits = model(ids, sids, labels=aligned_label, loss_weights=P.cast(ids != 0, 'float32')) + loss = scaler.scale(loss) + loss.backward() + scaler.minimize(opt, loss) + model.clear_gradients() + lr_scheduler.step() + + if step % 10 == 0: + _lr = lr_scheduler.get_lr() + if args.use_amp: + _l = (loss / scaler._scale).numpy() + msg = '[step-%d] train loss %.5f lr %.3e scaling %.3e' % ( + step, _l, _lr, scaler._scale.numpy()) + else: + _l = loss.numpy() + msg = '[step-%d] train loss %.5f lr %.3e' % (step, _l, + _lr) + log.debug(msg) + log_writer.add_scalar('loss', _l, step=step) + log_writer.add_scalar('lr', _lr, step=step) + + if step % 100 == 0: + f1 = evaluate(model, dev_ds) + log.debug('dev eval f1: %.5f' % f1) + log_writer.add_scalar('dev eval/f1', f1, step=step) + f1 = evaluate(model, test_ds) + log.debug('test eval f1: %.5f' % f1) + log_writer.add_scalar('test eval/f1', f1, step=step) + if args.save_dir is not None: + P.save(model.state_dict(), args.save_dir / 'ckpt.bin') + +f1 = evaluate(model, dev_ds) +log.debug('final eval f1: %.5f' % f1) +log_writer.add_scalar('eval/f1', f1, step=step) +if args.save_dir is not None: + P.save(model.state_dict(), args.save_dir / 'ckpt.bin') diff --git a/ernie_gram/mrc/__init__.py b/ernie_gram/mrc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ernie_gram/mrc/mrc_metrics.py b/ernie_gram/mrc/mrc_metrics.py new file mode 100644 index 0000000..a94859c --- /dev/null +++ b/ernie_gram/mrc/mrc_metrics.py @@ -0,0 +1,602 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +from __future__ import unicode_literals + +import sys +import re +import six +import logging +import math +import collections +import nltk + +import unicodedata +from collections import namedtuple + +RawResult = namedtuple("RawResult", + ["unique_id", "start_logits", "end_logits"]) + +log = logging.getLogger(__name__) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a peice of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +class _BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +def _get_best_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + index_and_score = sorted( + enumerate(logits), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + + +def _compute_softmax(scores): + """Compute softmax probability over raw logits.""" + if not scores: + return [] + + max_score = None + for score in scores: + if max_score is None or score > max_score: + max_score = score + + exp_scores = [] + total_sum = 0.0 + for score in scores: + x = math.exp(score - max_score) + exp_scores.append(x) + total_sum += x + + probs = [] + for score in exp_scores: + probs.append(score / total_sum) + return probs + + +def _get_final_text(pred_text, orig_text, tokenizer): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heruistic between + # `pred_text` and `orig_text` to get a character-to-charcter alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + return orig_text + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + return orig_text + + # We then project the characters in `pred_text` back to `orig_text` using + # the character-to-character alignment. + tok_s_to_ns_map = {} + for (i, tok_index) in six.iteritems(tok_ns_to_s_map): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + return orig_text + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + return orig_text + + output_text = orig_text[orig_start_position:(orig_end_position + 1)] + return output_text + + +def make_results(vocab, all_examples, all_features, all_results, n_best_size, + max_answer_length, do_lower_case): + """Write final predictions to the json file and log-odds of null if needed.""" + tokenizer = _BasicTokenizer(do_lower_case) + example_index_to_features = collections.defaultdict(list) + for feature in all_features: + example_index_to_features[feature.example_index].append(feature) + + unique_id_to_result = {} + for result in all_results: + unique_id_to_result[result.unique_id] = result + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", [ + "feature_index", "start_index", "end_index", "start_logit", + "end_logit" + ]) + + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + + for (example_index, example) in enumerate(all_examples): + features = example_index_to_features[example_index] + + prelim_predictions = [] + # keep track of the minimum score of null start+end of position 0 + for (feature_index, feature) in enumerate(features): + result = unique_id_to_result[feature.unique_id] + start_indexes = _get_best_indexes(result.start_logits, n_best_size) + end_indexes = _get_best_indexes(result.end_logits, n_best_size) + #log.debug(start_indexes) + #log.debug(end_indexes) + for start_index in start_indexes: + for end_index in end_indexes: + # We could hypothetically create invalid predictions, e.g., predict + # that the start of the span is in the question. We throw out all + # invalid predictions. + if start_index >= len(feature.tokens): + continue + if end_index >= len(feature.tokens): + continue + if start_index not in feature.token_to_orig_map: + continue + if end_index not in feature.token_to_orig_map: + continue + if not feature.token_is_max_context.get(start_index, + False): + continue + if end_index < start_index: + continue + length = end_index - start_index + 1 + if length > max_answer_length: + continue + prelim_predictions.append( + _PrelimPrediction( + feature_index=feature_index, + start_index=start_index, + end_index=end_index, + start_logit=result.start_logits[start_index], + end_logit=result.end_logits[end_index])) + + prelim_predictions = sorted( + prelim_predictions, + key=lambda x: (x.start_logit + x.end_logit), + reverse=True) + + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit"]) + + seen_predictions = {} + nbest = [] + for pred in prelim_predictions: + if len(nbest) >= n_best_size: + break + feature = features[pred.feature_index] + if pred.start_index > 0: # this is a non-null prediction + tok_tokens = feature.tokens[pred.start_index:(pred.end_index + + 1)] + orig_doc_start = feature.token_to_orig_map[pred.start_index] + orig_doc_end = feature.token_to_orig_map[pred.end_index] + orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + + 1)] + tok_text = " ".join(tok_tokens) + + # De-tokenize WordPieces that have been split off. + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + # Clean whitespace + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = "".join(orig_tokens) + final_text = _get_final_text(tok_text, orig_text, tokenizer) + if final_text in seen_predictions: + continue + + seen_predictions[final_text] = True + else: + final_text = "" + seen_predictions[final_text] = True + + nbest.append( + _NbestPrediction( + text=final_text, + start_logit=pred.start_logit, + end_logit=pred.end_logit)) + + # In very rare edge cases we could have no valid predictions. So we + # just create a nonce prediction in this case to avoid failure. + if not nbest: + nbest.append( + _NbestPrediction( + text="empty", start_logit=0.0, end_logit=0.0)) + + total_scores = [] + best_non_null_entry = None + for entry in nbest: + total_scores.append(entry.start_logit + entry.end_logit) + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for (i, entry) in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_logit"] = entry.start_logit + output["end_logit"] = entry.end_logit + nbest_json.append(output) + #log.debug(nbest_json[0]) + #log.debug(example.qas_id) + + assert len(nbest_json) >= 1 + + all_predictions[example.qas_id] = nbest_json[0]["text"] + all_nbest_json[example.qas_id] = nbest_json + return all_predictions, all_nbest_json + + +# split Chinese with English +def mixed_segmentation(in_str, rm_punc=False): + """mix segmentation""" + in_str = in_str.lower().strip() + segs_out = [] + temp_str = "" + sp_char = [ + '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':', + '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(', + ')', '-', '~', '『', '』' + ] + for char in in_str: + if rm_punc and char in sp_char: + continue + if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char: + if temp_str != "": + ss = nltk.word_tokenize(temp_str) + segs_out.extend(ss) + temp_str = "" + segs_out.append(char) + else: + temp_str += char + + #handling last part + if temp_str != "": + ss = nltk.word_tokenize(temp_str) + segs_out.extend(ss) + + return segs_out + + +# remove punctuation +def remove_punctuation(in_str): + """remove punctuation""" + in_str = in_str.lower().strip() + sp_char = [ + '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':', + '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(', + ')', '-', '~', '『', '』' + ] + out_segs = [] + for char in in_str: + if char in sp_char: + continue + else: + out_segs.append(char) + return ''.join(out_segs) + + +# find longest common string +def find_lcs(s1, s2): + """find_lcs""" + m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)] + mmax = 0 + p = 0 + for i in range(len(s1)): + for j in range(len(s2)): + if s1[i] == s2[j]: + m[i + 1][j + 1] = m[i][j] + 1 + if m[i + 1][j + 1] > mmax: + mmax = m[i + 1][j + 1] + p = i + 1 + return s1[p - mmax:p], mmax + + +def calc_f1_score(answers, prediction): + """calc_f1_score""" + f1_scores = [] + for ans in answers: + ans_segs = mixed_segmentation(ans, rm_punc=True) + prediction_segs = mixed_segmentation(prediction, rm_punc=True) + lcs, lcs_len = find_lcs(ans_segs, prediction_segs) + if lcs_len == 0: + f1_scores.append(0) + continue + precision = 1.0 * lcs_len / len(prediction_segs) + recall = 1.0 * lcs_len / len(ans_segs) + f1 = (2 * precision * recall) / (precision + recall) + f1_scores.append(f1) + return max(f1_scores) + + +def calc_em_score(answers, prediction): + """calc_f1_score""" + em = 0 + for ans in answers: + ans_ = remove_punctuation(ans) + prediction_ = remove_punctuation(prediction) + if ans_ == prediction_: + em = 1 + break + return em + + +def evaluate(ground_truth_file, prediction_file): + """evaluate""" + f1 = 0 + em = 0 + total_count = 0 + skip_count = 0 + for instances in ground_truth_file["data"]: + for instance in instances["paragraphs"]: + context_text = instance['context'].strip() + for qas in instance['qas']: + total_count += 1 + query_id = qas['id'].strip() + query_text = qas['question'].strip() + answers = [ans["text"] for ans in qas["answers"]] + + if query_id not in prediction_file: + sys.stderr.write('Unanswered question: {}\n'.format( + query_id)) + skip_count += 1 + continue + + prediction = prediction_file[query_id] + f1 += calc_f1_score(answers, prediction) + em += calc_em_score(answers, prediction) + + f1_score = f1 / total_count + em_score = em / total_count + return [f1_score, em_score, total_count, skip_count] diff --git a/ernie_gram/mrc/mrc_reader.py b/ernie_gram/mrc/mrc_reader.py new file mode 100644 index 0000000..999925e --- /dev/null +++ b/ernie_gram/mrc/mrc_reader.py @@ -0,0 +1,303 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +from __future__ import unicode_literals + +import sys +import argparse +import logging +from functools import partial +from io import open + +open = partial(open, encoding='utf-8') + +import json +from collections import namedtuple + +log = logging.getLogger(__name__) + +Example = namedtuple('Example', [ + 'qas_id', 'question_text', 'doc_tokens', 'orig_answer_text', + 'start_position', 'end_position' +]) + +Feature = namedtuple("Feature", [ + "unique_id", "example_index", "doc_span_index", "tokens", + "token_to_orig_map", "token_is_max_context", "token_ids", "position_ids", + "text_type_ids", "start_position", "end_position" +]) + + +def _tokenize_chinese_chars(text): + """Adds whitespace around any CJK character.""" + + def _is_chinese_char(cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + output = [] + buff = "" + for char in text: + cp = ord(char) + if _is_chinese_char(cp): + if buff != "": + output.append(buff) + buff = "" + output.append(char) + else: + buff += char + + if buff != "": + output.append(buff) + + return output + + +def _check_is_max_context(doc_spans, cur_span_index, position): + """chech is max context""" + best_score = None + best_span_index = None + for (span_index, doc_span) in enumerate(doc_spans): + end = doc_span.start + doc_span.length - 1 + if position < doc_span.start: + continue + if position > end: + continue + num_left_context = position - doc_span.start + num_right_context = end - position + score = min(num_left_context, + num_right_context) + 0.01 * doc_span.length + if best_score is None or score > best_score: + best_score = score + best_span_index = span_index + + return cur_span_index == best_span_index + + +def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, + orig_answer_text): + """improve answer span""" + tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) + + for new_start in range(input_start, input_end + 1): + for new_end in range(input_end, new_start - 1, -1): + text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) + if text_span == tok_answer_text: + return (new_start, new_end) + + return (input_start, input_end) + + +def read_files(input_file, is_training): + """read file""" + examples = [] + with open(input_file, "r") as f: + input_data = json.load(f)["data"] + for entry in input_data: + for paragraph in entry["paragraphs"]: + paragraph_text = paragraph["context"] + for qa in paragraph["qas"]: + qas_id = qa["id"] + question_text = qa["question"] + start_pos = None + end_pos = None + orig_answer_text = None + + if is_training: + if len(qa["answers"]) != 1: + raise ValueError( + "For training, each question should have exactly 1 answer." + ) + + answer = qa["answers"][0] + orig_answer_text = answer["text"] + answer_offset = answer["answer_start"] + answer_length = len(orig_answer_text) + doc_tokens = [ + paragraph_text[:answer_offset], paragraph_text[ + answer_offset:answer_offset + answer_length], + paragraph_text[answer_offset + answer_length:] + ] + + start_pos = 1 + end_pos = 1 + + actual_text = " ".join(doc_tokens[start_pos:(end_pos + + 1)]) + if actual_text.find(orig_answer_text) == -1: + log.info("Could not find answer: '%s' vs. '%s'", + actual_text, orig_answer_text) + continue + else: + doc_tokens = _tokenize_chinese_chars(paragraph_text) + + example = Example( + qas_id=qas_id, + question_text=question_text, + doc_tokens=doc_tokens, + orig_answer_text=orig_answer_text, + start_position=start_pos, + end_position=end_pos) + examples.append(example) + + return examples + + +def convert_example_to_features(examples, + max_seq_length, + tokenizer, + is_training, + doc_stride=128, + max_query_length=64): + """convert example to feature""" + features = [] + unique_id = 1000000000 + + for (example_index, example) in enumerate(examples): + query_tokens = tokenizer.tokenize(example.question_text) + if len(query_tokens) > max_query_length: + query_tokens = query_tokens[0:max_query_length] + tok_to_orig_index = [] + orig_to_tok_index = [] + all_doc_tokens = [] + for (i, token) in enumerate(example.doc_tokens): + orig_to_tok_index.append(len(all_doc_tokens)) + sub_tokens = tokenizer.tokenize(token) + for sub_token in sub_tokens: + tok_to_orig_index.append(i) + all_doc_tokens.append(sub_token) + #log.info(orig_to_tok_index, example.start_position) + + tok_start_position = None + tok_end_position = None + if is_training: + tok_start_position = orig_to_tok_index[example.start_position] + if example.end_position < len(example.doc_tokens) - 1: + tok_end_position = orig_to_tok_index[example.end_position + + 1] - 1 + else: + tok_end_position = len(all_doc_tokens) - 1 + (tok_start_position, tok_end_position) = _improve_answer_span( + all_doc_tokens, tok_start_position, tok_end_position, + tokenizer, example.orig_answer_text) + + max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 + _DocSpan = namedtuple("DocSpan", ["start", "length"]) + doc_spans = [] + start_offset = 0 + while start_offset < len(all_doc_tokens): + length = len(all_doc_tokens) - start_offset + if length > max_tokens_for_doc: + length = max_tokens_for_doc + doc_spans.append(_DocSpan(start=start_offset, length=length)) + if start_offset + length == len(all_doc_tokens): + break + start_offset += min(length, doc_stride) + + for (doc_span_index, doc_span) in enumerate(doc_spans): + tokens = [] + token_to_orig_map = {} + token_is_max_context = {} + text_type_ids = [] + tokens.append("[CLS]") + text_type_ids.append(0) + for token in query_tokens: + tokens.append(token) + text_type_ids.append(0) + tokens.append("[SEP]") + text_type_ids.append(0) + + for i in range(doc_span.length): + split_token_index = doc_span.start + i + token_to_orig_map[len(tokens)] = tok_to_orig_index[ + split_token_index] + + is_max_context = _check_is_max_context( + doc_spans, doc_span_index, split_token_index) + token_is_max_context[len(tokens)] = is_max_context + tokens.append(all_doc_tokens[split_token_index]) + text_type_ids.append(1) + tokens.append("[SEP]") + text_type_ids.append(1) + + token_ids = tokenizer.convert_tokens_to_ids(tokens) + position_ids = list(range(len(token_ids))) + start_position = None + end_position = None + if is_training: + doc_start = doc_span.start + doc_end = doc_span.start + doc_span.length - 1 + out_of_span = False + if not (tok_start_position >= doc_start and + tok_end_position <= doc_end): + out_of_span = True + if out_of_span: + start_position = 0 + end_position = 0 + else: + doc_offset = len(query_tokens) + 2 + start_position = tok_start_position - doc_start + doc_offset + end_position = tok_end_position - doc_start + doc_offset + + feature = Feature( + unique_id=unique_id, + example_index=example_index, + doc_span_index=doc_span_index, + tokens=tokens, + token_to_orig_map=token_to_orig_map, + token_is_max_context=token_is_max_context, + token_ids=token_ids, + position_ids=position_ids, + text_type_ids=text_type_ids, + start_position=start_position, + end_position=end_position) + features.append(feature) + + unique_id += 1 + + return features + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='main') + parser.add_argument("--input", type=str, default=None) + args = parser.parse_args() + + from ernie.tokenizing_ernie import ErnieTokenizer + tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0') + examples = read_files(args.input, True) + features = convert_example_to_features(examples, 512, tokenizer, True) + log.debug(len(examples)) + log.debug(len(features)) diff --git a/ernie_gram/optimization.py b/ernie_gram/optimization.py new file mode 100644 index 0000000..fe5e361 --- /dev/null +++ b/ernie_gram/optimization.py @@ -0,0 +1,136 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import absolute_import + +import logging +import re +from paddle.fluid import framework +from paddle.fluid.framework import Variable, default_main_program +import numpy as np +import paddle as P +import paddle.distributed.fleet as fleet +from propeller.paddle.train.hooks import RunHook +import paddle.fluid as F +log = logging.getLogger(__name__) + +from ernie_gram.utils import create_if_not_exists, get_warmup_and_linear_decay + +class AdamW(P.optimizer.AdamW): + """AdamW object for dygraph""" + def __init__(self, *args, **kwargs): + layerwise_lr_decay = kwargs.pop('layerwise_lr_decay_rate', 0.8) + n_layers = kwargs.pop('n_layers', 12) + var_name_to_exclude = kwargs.pop('var_name_to_exclude', '.*layer_norm_scale|.*layer_norm_bias|.*b_0') + super(AdamW, self).__init__(*args, **kwargs) + self.ld = layerwise_lr_decay + self.pat = re.compile(var_name_to_exclude) + self.n_layers = n_layers + + def _get_layerwise_lr_decay_rate(self, param): + #if self.pat.match(param.name): + # return 1.0 + if param.name.startswith("encoder_layer"): + layer = int(param.name.split("_")[2]) + decay_rate = self.ld ** (self.n_layers - layer) + elif "embedding" in param.name: + decay_rate = self.ld ** (self.n_layers + 1) + else: + decay_rate = 1.0 + return decay_rate + + + def _create_param_lr(self, param_and_grad): + # create learning rate tensor for every parameter + param = param_and_grad[0] + param_lr = param.optimize_attr['learning_rate'] * self._get_layerwise_lr_decay_rate(param) + if type(param_lr) == Variable: + return param_lr + else: + if param_lr == 1.0: + return self._global_learning_rate() + else: + with default_main_program()._lr_schedule_guard( + is_with_opt=True), framework.name_scope( + 'scale_with_param_lr'): + return self._global_learning_rate() * param_lr + + def apply_optimize(self, loss, startup_program, params_grads): + super(AdamW, self).apply_optimize(loss, startup_program, params_grads) + for p, g in params_grads: + #log.debug(L.reduce_mean(p)) + if not self.pat.match(p.name): + L.assign(p * (1. - self.wd * self.current_step_lr()), p) + + +def optimization( + loss, + warmup_steps, + num_train_steps, + learning_rate, + train_program, + startup_prog, + weight_decay, + scheduler='linear_warmup_decay', + use_fp16=False, ): + """do backword for static""" + + def exclude_from_weight_decay(param): + name = param.rstrip('.master') + if name.find("layer_norm") > -1: + return True + bias_suffix = ["_bias", "_b", ".b_0"] + for suffix in bias_suffix: + if name.endswith(suffix): + return True + return False + + g_clip = P.nn.ClipGradByGlobalNorm(1.0) + lr_scheduler = P.optimizer.lr.LambdaDecay( + learning_rate, + get_warmup_and_linear_decay(num_train_steps, warmup_steps)) + + optimizer = AdamW( + learning_rate=lr_scheduler, + weight_decay=weight_decay, + grad_clip=g_clip, + apply_decay_param_fun=exclude_from_weight_decay) + + if use_fp16: + log.info('AMP activated') + if weight_decay > 0.: + raise ValueError( + 'paddle amp will ignore `weight_decay`, see https://github.com/PaddlePaddle/Paddle/issues/29794' + ) + #amp_list = P.fluid.contrib.mixed_precision.AutoMixedPrecisionLists( + # custom_white_list=['softmax', 'layer_norm', 'gelu']) + optimizer = P.fluid.contrib.mixed_precision.decorate( + optimizer, init_loss_scaling=2**15, use_dynamic_loss_scaling=True) + _, param_grads = optimizer.minimize(loss) + loss_scaling = P.static.default_main_program().global_block().var( + 'loss_scaling_0') + else: + _, param_grads = optimizer.minimize(loss) + loss_scaling = None + + class LRStepHook(RunHook): + def after_run(self, _, __): + lr_scheduler.step() + log.debug('lr step: %.5f' % lr_scheduler.get_lr()) + + return LRStepHook(), loss_scaling diff --git a/ernie_gram/run_cls.sh b/ernie_gram/run_cls.sh new file mode 100644 index 0000000..b8587e6 --- /dev/null +++ b/ernie_gram/run_cls.sh @@ -0,0 +1,13 @@ +source $1 + +python3 -m paddle.distributed.launch ./ernie_gram/finetune_classifier_distributed.py \ + --data_dir $data_dir \ + --max_steps $max_steps \ + --bsz $bsz \ + --lr $lr \ + --label_map ${label_map:-""} \ + --num_labels $num_labels \ + --pair_input $pair_input \ + --valid_steps $valid_steps \ + --from_pretrained $from_pretrained \ + --save_dir checkpoints diff --git a/ernie_gram/run_mrc.sh b/ernie_gram/run_mrc.sh new file mode 100644 index 0000000..0f3980a --- /dev/null +++ b/ernie_gram/run_mrc.sh @@ -0,0 +1,9 @@ +source $1 +export CUDA_VISIBLE_DEVICES=0 +python3 -m paddle.distributed.launch ./ernie_gram/finetune_mrc.py \ + --train_file $train_file \ + --dev_file $dev_file \ + --max_steps $max_steps \ + --lr $lr \ + --from_pretrained $from_pretrained \ + --save_dir checkpoints diff --git a/ernie_gram/run_ner.sh b/ernie_gram/run_ner.sh new file mode 100644 index 0000000..11604b6 --- /dev/null +++ b/ernie_gram/run_ner.sh @@ -0,0 +1,9 @@ +source $1 + +python3 -m paddle.distributed.launch ./ernie_gram/finetune_ner.py \ + --data_dir $data_dir \ + --max_steps $max_steps \ + --epoch $epoch \ + --lr $lr \ + --from_pretrained $from_pretrained \ + --save_dir checkpoints diff --git a/ernie_gram/task_configs/cmrc_conf b/ernie_gram/task_configs/cmrc_conf new file mode 100644 index 0000000..846e301 --- /dev/null +++ b/ernie_gram/task_configs/cmrc_conf @@ -0,0 +1,5 @@ +train_file="data/cmrc2018/train/train.json" +dev_file="data/cmrc2018/dev/dev.json" +max_steps=1320 +lr=1.5e-4 +from_pretrained="ernie-gram-zh" diff --git a/ernie_gram/task_configs/msra_ner_conf b/ernie_gram/task_configs/msra_ner_conf new file mode 100644 index 0000000..b0ef9bb --- /dev/null +++ b/ernie_gram/task_configs/msra_ner_conf @@ -0,0 +1,5 @@ +data_dir="data/msra_ner" +epoch=10 +max_steps=13040 +lr=5e-5 +from_pretrained="ernie-gram-zh" diff --git a/ernie_gram/task_configs/xnli_conf b/ernie_gram/task_configs/xnli_conf new file mode 100644 index 0000000..634cc4d --- /dev/null +++ b/ernie_gram/task_configs/xnli_conf @@ -0,0 +1,9 @@ +data_dir="data/xnli" +max_steps=4600 #3 epoch +lr=1.5e-4 +label_map='{"contradictory":0,"contradiction":0,"entailment":1,"neutral":2}' +num_labels=3 +valid_steps=25 +from_pretrained="ernie-gram-zh" +pair_input=1 +bsz=32 diff --git a/ernie_gram/utils.py b/ernie_gram/utils.py new file mode 100644 index 0000000..9fc3c3c --- /dev/null +++ b/ernie_gram/utils.py @@ -0,0 +1,47 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +from __future__ import unicode_literals + +import sys +import argparse +import logging +import paddle + + +class UnpackDataLoader(paddle.io.DataLoader): + def __init__(self, *args, **kwargs): + super(UnpackDataLoader, self).__init__(*args, batch_size=1, **kwargs) + + def __iter__(self): + return ([yy[0] for yy in y] + for y in super(UnpackDataLoader, self).__iter__()) + + +def create_if_not_exists(dir): + try: + dir.mkdir(parents=True) + except FileExistsError: + pass + return dir + + +def get_warmup_and_linear_decay(max_steps, warmup_steps): + if warmup_steps == 0: + return lambda step: 1.0 + else: + return lambda step: min(step / warmup_steps, 1. - (step - warmup_steps) / (max_steps - warmup_steps)) -- GitLab