# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

import sys
import io
import re
import argparse
import logging
import json

import numpy as np
from pathlib import Path
from collections import namedtuple

import paddle as P
from paddle.nn import functional as F

from ernie.modeling_ernie import ErnieModel, ErnieModelForPretraining, ErnieModelForGeneration
from ernie.modeling_ernie import _build_linear, _build_ln, append_name
from ernie.tokenizing_ernie import ErnieTokenizer

from propeller import log
import propeller.paddle as propeller


@np.vectorize
def rev_lookup(i):
    # `rev_dict` (id -> token) is built in __main__ before this is called
    return rev_dict[i]


def gen_bias(encoder_inputs, decoder_inputs, step):
    """Build the attention bias for one decode step: full visibility of the
    encoder, causal (lower-triangular) visibility of the decoder."""
    decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2]
    attn_bias = P.reshape(
        P.arange(
            0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1])
    decoder_bias = P.cast(
        (P.matmul(
            attn_bias, 1. / attn_bias, transpose_y=True) >= 1.),
        'float32')  #[1, decoderlen, decoderlen]
    encoder_bias = P.unsqueeze(
        P.cast(P.ones_like(encoder_inputs), 'float32'),
        [1])  #[bsz, 1, encoderlen]
    encoder_bias = P.tile(
        encoder_bias, [1, decoder_seqlen, 1])  #[bsz, decoderlen, encoderlen]
    decoder_bias = P.tile(
        decoder_bias, [decoder_bsz, 1, 1])  #[bsz, decoderlen, decoderlen]
    if step > 0:
        bias = P.concat([
            encoder_bias, P.ones([decoder_bsz, decoder_seqlen, step],
                                 'float32'), decoder_bias
        ], -1)
    else:
        bias = P.concat([encoder_bias, decoder_bias], -1)
    return bias
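
# Illustrative sketch only (not part of the original script): what gen_bias
# returns at step 0, assuming paddle 2.x semantics for the P.* calls above.
# Each decoder position sees the whole encoder plus a causal view of the
# decoder:
#
#   enc = P.zeros([2, 4], dtype='int64')    # bsz=2, encoder len=4
#   dec = P.zeros([2, 3], dtype='int64')    # bsz=2, decoder len=3
#   bias = gen_bias(enc, dec, step=0)       # shape [2, 3, 4 + 3]
#   # bias[b, i, :4] == 1.  (full visibility of the encoder)
#   # bias[b, i, 4:] is lower-triangular with ones on the diagonal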

#def make_data(tokenizer, inputs, max_encode_len):
#    all_ids, all_sids = [], []
#    for i in inputs:
#        q_ids, q_sids = tokenizer.build_for_ernie(
#            np.array(
#                tokenizer.convert_tokens_to_ids(i.split(' '))[: max_encode_len-2],
#                dtype=np.int64
#            )
#        )
#        all_ids.append(q_ids)
#        all_sids.append(q_sids)
#    ml = max(map(len, all_ids))
#    all_ids = [np.pad(i, [0, ml-len(i)], mode='constant') for i in all_ids]
#    all_sids = [np.pad(i, [0, ml-len(i)], mode='constant') for i in all_sids]
#    all_ids = np.stack(all_ids, 0)
#    all_sids = np.stack(all_sids, 0)
#    return all_ids, all_sids


def greedy_search_infilling(model,
                            q_ids,
                            q_sids,
                            sos_id,
                            eos_id,
                            attn_id,
                            max_encode_len=640,
                            max_decode_len=100,
                            tgt_type_id=3):
    model.eval()
    with P.no_grad():
        #log.debug(q_ids.numpy().tolist())
        _, logits, info = model(q_ids, q_sids)
        gen_ids = P.argmax(logits, -1)
        d_batch, d_seqlen = q_ids.shape
        seqlen = P.cast(q_ids != 0, 'int64').sum(1, keepdim=True)
        log.debug(seqlen.numpy())
        log.debug(d_seqlen)
        has_stopped = np.zeros([d_batch], dtype=bool)
        gen_seq_len = np.zeros([d_batch], dtype=np.int64)
        output_ids = []

        past_cache = info['caches']

        cls_ids = P.ones([d_batch], dtype='int64') * sos_id
        attn_ids = P.ones([d_batch], dtype='int64') * attn_id
        ids = P.stack([cls_ids, attn_ids], -1)
        for step in range(max_decode_len):
            log.debug('decode step %d' % step)
            bias = gen_bias(q_ids, ids, step)
            pos_ids = P.to_tensor(
                np.tile(
                    np.array(
                        [[step, step + 1]], dtype=np.int64), [d_batch, 1]))
            pos_ids += seqlen
            _, logits, info = model(
                ids,
                P.ones_like(ids) * tgt_type_id,
                pos_ids=pos_ids,
                attn_bias=bias,
                past_cache=past_cache)
            gen_ids = P.argmax(logits, -1)

            past_cached_k, past_cached_v = past_cache
            cached_k, cached_v = info['caches']
            cached_k = [
                P.concat([pk, k[:, :1, :]], 1)
                for pk, k in zip(past_cached_k, cached_k)
            ]  # concat cached
            cached_v = [
                P.concat([pv, v[:, :1, :]], 1)
                for pv, v in zip(past_cached_v, cached_v)
            ]
            past_cache = (cached_k, cached_v)

            gen_ids = gen_ids[:, 1]
            ids = P.stack([gen_ids, attn_ids], 1)

            gen_ids = gen_ids.numpy()
            has_stopped |= (gen_ids == eos_id).astype(bool)
            gen_seq_len += (1 - has_stopped.astype(np.int64))
            output_ids.append(gen_ids.tolist())
            if has_stopped.all():
                #log.debug('exit because all done')
                break
            #if step == 1: break
        output_ids = np.array(output_ids).transpose([1, 0])
    return output_ids


BeamSearchState = namedtuple('BeamSearchState',
                             ['log_probs', 'lengths', 'finished'])
BeamSearchOutput = namedtuple('BeamSearchOutput',
                              ['scores', 'predicted_ids', 'beam_parent_ids'])


def log_softmax(x):
    e_x = np.exp(x - np.max(x))
    return np.log(e_x / e_x.sum())


def mask_prob(p, onehot_eos, finished):
    # finished beams may only emit EOS: everything else gets a large negative score
    is_finished = P.cast(P.reshape(finished, [-1, 1]) != 0, 'float32')
    p = is_finished * (1. - P.cast(onehot_eos, 'float32')) * -9999. + (
        1. - is_finished) * p
    return p


def hyp_score(log_probs, length, length_penalty):
    lp = P.pow((5. + P.cast(length, 'float32')) / 6., length_penalty)
    return log_probs / lp
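
# hyp_score above has the same form as the GNMT-style length normalisation:
# score = log_prob / ((5 + length) / 6) ** length_penalty.
# A worked example with made-up numbers: log_prob = -6.0, length = 7,
# length_penalty = 1.0 gives lp = (12 / 6) ** 1.0 = 2.0 and a score of -3.0,
# so longer finished hypotheses are not unfairly dominated by short ones
# during the top-k selection in beam_search_step below.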


def beam_search_step(state, logits, eos_id, beam_width, is_first_step,
                     length_penalty):
    """logits.shape == [B*W, V]"""
    _, vocab_size = logits.shape

    bsz, beam_width = state.log_probs.shape
    onehot_eos = P.cast(
        F.one_hot(P.ones([1], 'int64') * eos_id, vocab_size), 'int64')  #[1, V]

    probs = P.log(F.softmax(logits))  #[B*W, V]
    probs = mask_prob(probs, onehot_eos, state.finished)  #[B*W, V]
    allprobs = P.reshape(state.log_probs, [-1, 1]) + probs  #[B*W, V]

    not_finished = 1 - P.reshape(state.finished, [-1, 1])  #[B*W, 1]
    not_eos = 1 - onehot_eos
    length_to_add = not_finished * not_eos  #[B*W, V]
    alllen = P.reshape(state.lengths, [-1, 1]) + length_to_add

    allprobs = P.reshape(allprobs, [-1, beam_width * vocab_size])
    alllen = P.reshape(alllen, [-1, beam_width * vocab_size])
    allscore = hyp_score(allprobs, alllen, length_penalty)
    if is_first_step:
        allscore = P.reshape(
            allscore,
            [bsz, beam_width, -1])[:, 0, :]  # on the first step only consider beam 0
    scores, idx = P.topk(allscore, k=beam_width)  #[B, W]
    next_beam_id = idx // vocab_size  #[B, W]
    next_word_id = idx % vocab_size

    gather_idx = P.concat(
        [P.nonzero(idx != -1)[:, :1], P.reshape(idx, [-1, 1])], 1)
    next_probs = P.reshape(P.gather_nd(allprobs, gather_idx), idx.shape)
    next_len = P.reshape(P.gather_nd(alllen, gather_idx), idx.shape)

    gather_idx = P.concat([
        P.nonzero(next_beam_id != -1)[:, :1], P.reshape(next_beam_id, [-1, 1])
    ], 1)
    next_finished = P.reshape(
        P.gather_nd(state.finished, gather_idx),
        state.finished.shape)  # gather new beam state according to new beam id
    #log.debug(gather_idx.numpy())
    #log.debug(state.finished.numpy())
    #log.debug(next_finished.numpy())

    next_finished += P.cast(next_word_id == eos_id, 'int64')
    next_finished = P.cast(next_finished > 0, 'int64')

    #log.debug(next_word_id.numpy())
    #log.debug(next_beam_id.numpy())

    next_state = BeamSearchState(
        log_probs=next_probs, lengths=next_len, finished=next_finished)
    output = BeamSearchOutput(
        scores=scores,
        predicted_ids=next_word_id,
        beam_parent_ids=next_beam_id)

    return output, next_state


def beam_search_infilling(model,
                          q_ids,
                          q_sids,
                          sos_id,
                          eos_id,
                          attn_id,
                          max_encode_len=640,
                          max_decode_len=100,
                          beam_width=5,
                          tgt_type_id=3,
                          length_penalty=1.0):
    model.eval()
    with P.no_grad():
        #log.debug(q_ids.numpy().tolist())
        _, __, info = model(q_ids, q_sids)
        d_batch, d_seqlen = q_ids.shape

        state = BeamSearchState(
            log_probs=P.zeros([d_batch, beam_width], 'float32'),
            lengths=P.zeros([d_batch, beam_width], 'int64'),
            finished=P.zeros([d_batch, beam_width], 'int64'))
        outputs = []

        def reorder_(t, parent_id):
            """reorder cache according to parent beam id"""
            gather_idx = P.nonzero(
                parent_id != -1)[:, 0] * beam_width + P.reshape(parent_id, [-1])
            t = P.gather(t, gather_idx)
            return t

        def tile_(t, times):
            _shapes = list(t.shape[1:])
            ret = P.reshape(
                P.tile(
                    P.unsqueeze(t, [1]),
                    [1, times] + [1] * len(_shapes)),
                [-1] + _shapes)
            return ret

        cached_k, cached_v = info['caches']
        cached_k = [tile_(k, beam_width) for k in cached_k]
        cached_v = [tile_(v, beam_width) for v in cached_v]
        past_cache = (cached_k, cached_v)

        q_ids = tile_(q_ids, beam_width)
        seqlen = P.cast(q_ids != 0, 'int64').sum(1, keepdim=True)
        #log.debug(q_ids.shape)

        cls_ids = P.ones([d_batch * beam_width], dtype='int64') * sos_id
        attn_ids = P.ones(
            [d_batch * beam_width], dtype='int64') * attn_id  # SOS
        ids = P.stack([cls_ids, attn_ids], -1)
        for step in range(max_decode_len):
            #log.debug('decode step %d' % step)
            bias = gen_bias(q_ids, ids, step)
            pos_ids = P.to_tensor(
                np.tile(
                    np.array(
                        [[step, step + 1]], dtype=np.int64),
                    [d_batch * beam_width, 1]))
            pos_ids += seqlen
            _, logits, info = model(
                ids,
                P.ones_like(ids) * tgt_type_id,
                pos_ids=pos_ids,
                attn_bias=bias,
                past_cache=past_cache)

            output, state = beam_search_step(
                state,
                logits[:, 1],
                eos_id=eos_id,
                beam_width=beam_width,
                is_first_step=(step == 0),
                length_penalty=length_penalty)
            outputs.append(output)

            past_cached_k, past_cached_v = past_cache
            cached_k, cached_v = info['caches']
            cached_k = [
                reorder_(
                    P.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids)
                for pk, k in zip(past_cached_k, cached_k)
            ]  # concat cached
            cached_v = [
                reorder_(
                    P.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids)
                for pv, v in zip(past_cached_v, cached_v)
            ]
            past_cache = (cached_k, cached_v)

            pred_ids_flatten = P.reshape(output.predicted_ids,
                                         [d_batch * beam_width])
            ids = P.stack([pred_ids_flatten, attn_ids], 1)

            if state.finished.numpy().all():
                #log.debug('exit because all done')
                break
            #if step == 1: break

        final_ids = P.stack([o.predicted_ids for o in outputs], 0)
        final_parent_ids = P.stack([o.beam_parent_ids for o in outputs], 0)
        final_ids = P.fluid.layers.gather_tree(
            final_ids, final_parent_ids)[:, :, 0]  # pick best beam
        final_ids = P.transpose(
            P.reshape(final_ids, [-1, d_batch * 1]), [1, 0])
    return final_ids


en_pattern = re.compile(r'^[a-zA-Z0-9]*$')


def post_process(token):
    """Re-join WordPiece output: strip '##' prefixes, add a leading space
    before English/numeric tokens, keep other (e.g. CJK) tokens as-is."""
    if token.startswith('##'):
        ret = token[2:]
    else:
        if en_pattern.match(token):
            ret = ' ' + token
        else:
            ret = token
    return ret
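
# Example of how post_process rebuilds text from WordPiece tokens
# (made-up tokens, shown for illustration only):
#   ''.join(map(post_process, ['play', '##ing', 'a', 'game'])).strip()
#   -> 'playing a game'
# Non-alphanumeric tokens (e.g. CJK characters) are concatenated without spaces.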

if __name__ == '__main__':
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')

    parser = argparse.ArgumentParser('seq2seq model with ERNIE')
    parser.add_argument(
        '--from_pretrained',
        type=Path,
        required=True,
        help='pretrained model directory or tag')
    parser.add_argument('--bsz', type=int, default=8, help='batch size')
    parser.add_argument('--max_encode_len', type=int, default=640)
    parser.add_argument('--max_decode_len', type=int, default=120)
    parser.add_argument('--tgt_type_id', type=int, default=3)
    parser.add_argument('--beam_width', type=int, default=5)
    parser.add_argument(
        '--attn_token',
        type=str,
        default='[ATTN]',
        help='if [ATTN] is not in the vocab, you can specify e.g. [MASK] as the attn token')
    parser.add_argument('--length_penalty', type=float, default=1.0)
    parser.add_argument(
        '--save_dir',
        type=str,
        required=True,
        help='fine-tuned model checkpoint to be loaded')
    args = parser.parse_args()

    env = P.distributed.ParallelEnv()
    ernie = ErnieModelForGeneration.from_pretrained(
        args.from_pretrained, name='')
    tokenizer = ErnieTokenizer.from_pretrained(
        args.from_pretrained, mask_token=None)
    rev_dict = {v: k for k, v in tokenizer.vocab.items()}
    rev_dict[tokenizer.pad_id] = ''  # replace [PAD]
    rev_dict[tokenizer.unk_id] = ''  # replace [UNK]

    sd = P.load(args.save_dir)
    ernie.set_state_dict(sd)

    def map_fn(src_ids):
        src_ids = src_ids[:args.max_encode_len]
        src_ids, src_sids = tokenizer.build_for_ernie(src_ids)
        return (src_ids, src_sids)

    feature_column = propeller.data.FeatureColumns([
        propeller.data.TextColumn(
            'seg_a',
            unk_id=tokenizer.unk_id,
            vocab_dict=tokenizer.vocab,
            tokenizer=tokenizer.tokenize),
    ])

    dataset = feature_column.build_dataset_from_stdin('predict').map(
        map_fn).padded_batch(args.bsz)

    for step, (encoder_ids, encoder_sids) in enumerate(dataset):
        #result_ids = greedy_search_infilling(
        #    ernie,
        #    P.to_tensor(encoder_ids),
        #    P.to_tensor(encoder_sids),
        #    eos_id=tokenizer.sep_id,
        #    sos_id=tokenizer.cls_id,
        #    attn_id=tokenizer.vocab[args.attn_token],
        #    max_decode_len=args.max_decode_len,
        #    max_encode_len=args.max_encode_len,
        #    beam_width=args.beam_width,
        #    tgt_type_id=args.tgt_type_id)
        result_ids = beam_search_infilling(
            ernie,
            P.to_tensor(encoder_ids),
            P.to_tensor(encoder_sids),
            eos_id=tokenizer.sep_id,
            sos_id=tokenizer.cls_id,
            attn_id=tokenizer.vocab[args.attn_token],
            max_decode_len=args.max_decode_len,
            max_encode_len=args.max_encode_len,
            beam_width=args.beam_width,
            length_penalty=args.length_penalty,
            tgt_type_id=args.tgt_type_id)

        output_str = rev_lookup(result_ids.numpy())
        for ostr in output_str.tolist():
            if '[SEP]' in ostr:
                ostr = ostr[:ostr.index('[SEP]')]
            ostr = ''.join(map(post_process, ostr))
            ostr = ostr.strip()
            print(ostr)
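
# Hypothetical invocation (script and checkpoint names are placeholders; the
# script reads one source sentence per line from stdin and prints one decoded
# sentence per line to stdout):
#
#   cat input.txt | python decode.py \
#       --from_pretrained <pretrained-dir-or-tag> \
#       --save_dir <finetuned-checkpoint> \
#       --bsz 8 --beam_width 5 --length_penalty 1.0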