diff --git a/hub_module/modules/text/text_generation/ernie_gen_couplet/README.md b/hub_module/modules/text/text_generation/ernie_gen_couplet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b6a8d230efa022c7207b9d95c371310968b23122 --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_couplet/README.md @@ -0,0 +1,99 @@ +## 概述 + +ERNIE-GEN 是面向生成任务的预训练-微调框架,首次在预训练阶段加入span-by-span 生成任务,让模型每次能够生成一个语义完整的片段。在预训练和微调中通过填充式生成机制和噪声感知机制来缓解曝光偏差问题。此外, ERNIE-GEN 采样多片段-多粒度目标文本采样策略, 增强源文本和目标文本的关联性,加强了编码器和解码器的交互。 +

+
+

+ +更多详情参考论文[ERNIE-GEN:An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation](https://arxiv.org/abs/2001.11314) + +## 命令行预测 + +```shell +$ hub run ernie_gen_couplet --input_text="人增福寿年增岁" --use_gpu True --beam_width 5 +``` + +## API + +```python +def generate(texts, use_gpu=False, beam_width=5): +``` + +预测API,由上联生成下联。 + +**参数** + +* texts (list\[str\]): 上联文本; +* use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA\_VISIBLE\_DEVICES环境变量**; +* beam_width: beam search宽度,决定每个上联输出的下联数量。 + +**返回** + +* results (list[list][str]): 下联文本,每个上联会生成beam_width个下联。 + +**代码示例** + +```python +import paddlehub as hub + +module = hub.Module(name="ernie_gen_couplet") + +test_texts = ["人增福寿年增岁", "风吹云乱天垂泪"] +results = module.genrate(texts=test_texts, use_gpu=True, beam_width=5) +for result in results: + print(result) +``` + +## 服务部署 + +PaddleHub Serving 可以部署在线服务。 + +### 第一步:启动PaddleHub Serving + +运行启动命令: +```shell +$ hub serving start -m ernie_gen_couplet -p 8866 +``` + +这样就完成了一个服务化API的部署,默认端口号为8866。 + +**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。 + +### 第二步:发送预测请求 + +配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果 + +```python +import requests +import json + +# 发送HTTP请求 + +data = {'texts':["人增福寿年增岁", "风吹云乱天垂泪"], + 'use_gpu':False, 'beam_width':5} +headers = {"Content-type": "application/json"} +url = "http://127.0.0.1:8866/predict/ernie_gen_couplet" +r = requests.post(url=url, headers=headers, data=json.dumps(data)) + +# 保存结果 +results = r.json()["results"] +for result in results: + print(result) +``` + +## 查看代码 + +https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/ + +### 依赖 + +paddlepaddle >= 1.8.2 + +paddlehub >= 1.7.0 + + +## 更新历史 + +* 1.0.0 + + 初始发布 diff --git a/hub_module/modules/text/text_generation/ernie_gen_couplet/__init__.py b/hub_module/modules/text/text_generation/ernie_gen_couplet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hub_module/modules/text/text_generation/ernie_gen_couplet/model/decode.py b/hub_module/modules/text/text_generation/ernie_gen_couplet/model/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..c58fdbe2e8902346162f8733ef0cd94ba65757a2 --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_couplet/model/decode.py @@ -0,0 +1,301 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import numpy as np +from collections import namedtuple + +import paddle.fluid as F +import paddle.fluid.layers as L +import paddle.fluid.dygraph as D + + +def gen_bias(encoder_inputs, decoder_inputs, step): + decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2] + attn_bias = L.reshape( + L.range(0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1]) + decoder_bias = L.cast( + (L.matmul(attn_bias, 1. / attn_bias, transpose_y=True) >= 1.), + 'float32') #[1, 1, decoderlen, decoderlen] + encoder_bias = L.unsqueeze( + L.cast(L.ones_like(encoder_inputs), 'float32'), + [1]) #[bsz, 1, encoderlen] + encoder_bias = L.expand( + encoder_bias, [1, decoder_seqlen, 1]) #[bsz,decoderlen, encoderlen] + decoder_bias = L.expand(decoder_bias, + [decoder_bsz, 1, 1]) #[bsz, decoderlen, decoderlen] + if step > 0: + bias = L.concat([ + encoder_bias, + L.ones([decoder_bsz, decoder_seqlen, step], 'float32'), decoder_bias + ], -1) + else: + bias = L.concat([encoder_bias, decoder_bias], -1) + return bias + + +@D.no_grad +def greedy_search_infilling(model, + q_ids, + q_sids, + sos_id, + eos_id, + attn_id, + max_encode_len=640, + max_decode_len=100, + tgt_type_id=3): + model.eval() + _, logits, info = model(q_ids, q_sids) + gen_ids = L.argmax(logits, -1) + d_batch, d_seqlen = q_ids.shape + seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True) + has_stopped = np.zeros([d_batch], dtype=np.bool) + gen_seq_len = np.zeros([d_batch], dtype=np.int64) + output_ids = [] + + past_cache = info['caches'] + + cls_ids = L.ones([d_batch], dtype='int64') * sos_id + attn_ids = L.ones([d_batch], dtype='int64') * attn_id + ids = L.stack([cls_ids, attn_ids], -1) + for step in range(max_decode_len): + bias = gen_bias(q_ids, ids, step) + pos_ids = D.to_variable( + np.tile(np.array([[step, step + 1]], dtype=np.int64), [d_batch, 1])) + pos_ids += seqlen + _, logits, info = model( + ids, + L.ones_like(ids) * tgt_type_id, + pos_ids=pos_ids, + attn_bias=bias, + past_cache=past_cache) + gen_ids = L.argmax(logits, -1) + + past_cached_k, past_cached_v = past_cache + cached_k, cached_v = info['caches'] + cached_k = [ + L.concat([pk, k[:, :1, :]], 1) + for pk, k in zip(past_cached_k, cached_k) + ] # concat cached + cached_v = [ + L.concat([pv, v[:, :1, :]], 1) + for pv, v in zip(past_cached_v, cached_v) + ] + past_cache = (cached_k, cached_v) + + gen_ids = gen_ids[:, 1] + ids = L.stack([gen_ids, attn_ids], 1) + + gen_ids = gen_ids.numpy() + has_stopped |= (gen_ids == eos_id).astype(np.bool) + gen_seq_len += (1 - has_stopped.astype(np.int64)) + output_ids.append(gen_ids.tolist()) + if has_stopped.all(): + break + output_ids = np.array(output_ids).transpose([1, 0]) + return output_ids + + +BeamSearchState = namedtuple('BeamSearchState', + ['log_probs', 'lengths', 'finished']) +BeamSearchOutput = namedtuple('BeamSearchOutput', + ['scores', 'predicted_ids', 'beam_parent_ids']) + + +def log_softmax(x): + e_x = np.exp(x - np.max(x)) + return np.log(e_x / e_x.sum()) + + +def mask_prob(p, onehot_eos, finished): + is_finished = L.cast(L.reshape(finished, [-1, 1]) != 0, 'float32') + p = is_finished * (1. - L.cast(onehot_eos, 'float32')) * -9999. + ( + 1. - is_finished) * p + return p + + +def hyp_score(log_probs, length, length_penalty): + lp = L.pow((5. + L.cast(length, 'float32')) / 6., length_penalty) + return log_probs / lp + + +def beam_search_step(state, logits, eos_id, beam_width, is_first_step, + length_penalty): + """logits.shape == [B*W, V]""" + _, vocab_size = logits.shape + + bsz, beam_width = state.log_probs.shape + onehot_eos = L.cast( + F.one_hot(L.ones([1], 'int64') * eos_id, vocab_size), 'int64') #[1, V] + + probs = L.log(L.softmax(logits)) #[B*W, V] + probs = mask_prob(probs, onehot_eos, state.finished) #[B*W, V] + allprobs = L.reshape(state.log_probs, [-1, 1]) + probs #[B*W, V] + + not_finished = 1 - L.reshape(state.finished, [-1, 1]) #[B*W,1] + not_eos = 1 - onehot_eos + length_to_add = not_finished * not_eos #[B*W,V] + alllen = L.reshape(state.lengths, [-1, 1]) + length_to_add + + allprobs = L.reshape(allprobs, [-1, beam_width * vocab_size]) + alllen = L.reshape(alllen, [-1, beam_width * vocab_size]) + allscore = hyp_score(allprobs, alllen, length_penalty) + if is_first_step: + allscore = L.reshape( + allscore, + [bsz, beam_width, -1])[:, 0, :] # first step only consiter beam 0 + scores, idx = L.topk(allscore, k=beam_width) #[B, W] + next_beam_id = idx // vocab_size #[B, W] + next_word_id = idx % vocab_size + + gather_idx = L.concat([L.where(idx != -1)[:, :1], + L.reshape(idx, [-1, 1])], 1) + next_probs = L.reshape(L.gather_nd(allprobs, gather_idx), idx.shape) + next_len = L.reshape(L.gather_nd(alllen, gather_idx), idx.shape) + + gather_idx = L.concat( + [L.where(next_beam_id != -1)[:, :1], + L.reshape(next_beam_id, [-1, 1])], 1) + next_finished = L.reshape( + L.gather_nd(state.finished, gather_idx), + state.finished.shape) #[gather new beam state according to new beam id] + + next_finished += L.cast(next_word_id == eos_id, 'int64') + next_finished = L.cast(next_finished > 0, 'int64') + + next_state = BeamSearchState( + log_probs=next_probs, lengths=next_len, finished=next_finished) + output = BeamSearchOutput( + scores=scores, predicted_ids=next_word_id, beam_parent_ids=next_beam_id) + + return output, next_state + + +@D.no_grad +def beam_search_infilling(model, + q_ids, + q_sids, + sos_id, + eos_id, + attn_id, + max_encode_len=640, + max_decode_len=100, + beam_width=5, + tgt_type_id=3, + length_penalty=1.0): + model.eval() + _, __, info = model(q_ids, q_sids) + d_batch, d_seqlen = q_ids.shape + + state = BeamSearchState( + log_probs=L.zeros([d_batch, beam_width], 'float32'), + lengths=L.zeros([d_batch, beam_width], 'int64'), + finished=L.zeros([d_batch, beam_width], 'int64')) + outputs = [] + + def reorder_(t, parent_id): + """reorder cache according to parent beam id""" + gather_idx = L.where(parent_id != -1)[:, 0] * beam_width + L.reshape( + parent_id, [-1]) + t = L.gather(t, gather_idx) + return t + + def tile_(t, times): + _shapes = list(t.shape[1:]) + ret = L.reshape( + L.expand(L.unsqueeze(t, [1]), [ + 1, + times, + ] + [ + 1, + ] * len(_shapes)), [ + -1, + ] + _shapes) + return ret + + cached_k, cached_v = info['caches'] + cached_k = [tile_(k, beam_width) for k in cached_k] + cached_v = [tile_(v, beam_width) for v in cached_v] + past_cache = (cached_k, cached_v) + + q_ids = tile_(q_ids, beam_width) + seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True) + + cls_ids = L.ones([d_batch * beam_width], dtype='int64') * sos_id + attn_ids = L.ones([d_batch * beam_width], dtype='int64') * attn_id # SOS + ids = L.stack([cls_ids, attn_ids], -1) + for step in range(max_decode_len): + bias = gen_bias(q_ids, ids, step) + pos_ids = D.to_variable( + np.tile( + np.array([[step, step + 1]], dtype=np.int64), + [d_batch * beam_width, 1])) + pos_ids += seqlen + + _, logits, info = model( + ids, + L.ones_like(ids) * tgt_type_id, + pos_ids=pos_ids, + attn_bias=bias, + past_cache=past_cache) + + output, state = beam_search_step( + state, + logits[:, 1], + eos_id=eos_id, + beam_width=beam_width, + is_first_step=(step == 0), + length_penalty=length_penalty) + outputs.append(output) + + past_cached_k, past_cached_v = past_cache + cached_k, cached_v = info['caches'] + cached_k = [ + reorder_(L.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids) + for pk, k in zip(past_cached_k, cached_k) + ] # concat cached + cached_v = [ + reorder_(L.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids) + for pv, v in zip(past_cached_v, cached_v) + ] + past_cache = (cached_k, cached_v) + + pred_ids_flatten = L.reshape(output.predicted_ids, + [d_batch * beam_width]) + ids = L.stack([pred_ids_flatten, attn_ids], 1) + + if state.finished.numpy().all(): + break + + final_ids = L.stack([o.predicted_ids for o in outputs], 0) + final_parent_ids = L.stack([o.beam_parent_ids for o in outputs], 0) + final_ids = L.gather_tree(final_ids, final_parent_ids) #[:, :, + #0] #pick best beam + final_ids = L.transpose( + L.reshape(final_ids, [-1, d_batch * 1, beam_width]), [1, 2, 0]) + return final_ids + + +en_patten = re.compile(r'^[a-zA-Z0-9]*$') + + +def post_process(token): + if token.startswith('##'): + ret = token[2:] + else: + if en_patten.match(token): + ret = ' ' + token + else: + ret = token + return ret diff --git a/hub_module/modules/text/text_generation/ernie_gen_couplet/model/file_utils.py b/hub_module/modules/text/text_generation/ernie_gen_couplet/model/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..613a5213a83e7fbd2a126cdb49b12eb62d4de41f --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_couplet/model/file_utils.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from tqdm import tqdm +from paddlehub.common.logger import logger +from paddlehub.common.dir import MODULE_HOME + + +def _fetch_from_remote(url, force_download=False): + import tempfile, requests, tarfile + cached_dir = os.path.join(MODULE_HOME, "ernie_for_gen") + if force_download or not os.path.exists(cached_dir): + with tempfile.NamedTemporaryFile() as f: + #url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz' + r = requests.get(url, stream=True) + total_len = int(r.headers.get('content-length')) + for chunk in tqdm( + r.iter_content(chunk_size=1024), + total=total_len // 1024, + desc='downloading %s' % url, + unit='KB'): + if chunk: + f.write(chunk) + f.flush() + logger.debug('extacting... to %s' % f.name) + with tarfile.open(f.name) as tf: + tf.extractall(path=cached_dir) + logger.debug('%s cached in %s' % (url, cached_dir)) + return cached_dir + + +def add_docstring(doc): + def func(f): + f.__doc__ += ('\n======other docs from supper class ======\n%s' % doc) + return f + + return func diff --git a/hub_module/modules/text/text_generation/ernie_gen_couplet/model/modeling_ernie.py b/hub_module/modules/text/text_generation/ernie_gen_couplet/model/modeling_ernie.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2304f67d7347e584c244ab8384eff0720f7cc2 --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_couplet/model/modeling_ernie.py @@ -0,0 +1,379 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +from __future__ import unicode_literals + +import logging + +import paddle.fluid.dygraph as D +import paddle.fluid as F +import paddle.fluid.layers as L + +log = logging.getLogger(__name__) + + +def _build_linear(n_in, n_out, name, init, act=None): + return D.Linear( + n_in, + n_out, + param_attr=F.ParamAttr( + name='%s.w_0' % name if name is not None else None, + initializer=init), + bias_attr='%s.b_0' % name if name is not None else None, + act=act) + + +def _build_ln(n_in, name): + return D.LayerNorm( + normalized_shape=n_in, + param_attr=F.ParamAttr( + name='%s_layer_norm_scale' % name if name is not None else None, + initializer=F.initializer.Constant(1.)), + bias_attr=F.ParamAttr( + name='%s_layer_norm_bias' % name if name is not None else None, + initializer=F.initializer.Constant(1.)), + ) + + +def append_name(name, postfix): + if name is None: + return None + elif name == '': + return postfix + else: + return '%s_%s' % (name, postfix) + + +class AttentionLayer(D.Layer): + def __init__(self, cfg, name=None): + super(AttentionLayer, self).__init__() + initializer = F.initializer.TruncatedNormal( + scale=cfg['initializer_range']) + d_model = cfg['hidden_size'] + n_head = cfg['num_attention_heads'] + assert d_model % n_head == 0 + d_model_q = cfg.get('query_hidden_size_per_head', + d_model // n_head) * n_head + d_model_v = cfg.get('value_hidden_size_per_head', + d_model // n_head) * n_head + self.n_head = n_head + self.d_key = d_model_q // n_head + self.q = _build_linear(d_model, d_model_q, append_name( + name, 'query_fc'), initializer) + self.k = _build_linear(d_model, d_model_q, append_name(name, 'key_fc'), + initializer) + self.v = _build_linear(d_model, d_model_v, append_name( + name, 'value_fc'), initializer) + self.o = _build_linear(d_model_v, d_model, append_name( + name, 'output_fc'), initializer) + self.dropout = lambda i: L.dropout( + i, + dropout_prob=cfg['attention_probs_dropout_prob'], + dropout_implementation="upscale_in_train", + ) if self.training else i + + def forward(self, queries, keys, values, attn_bias, past_cache): + assert len(queries.shape) == len(keys.shape) == len(values.shape) == 3 + + q = self.q(queries) + k = self.k(keys) + v = self.v(values) + + cache = (k, v) + if past_cache is not None: + cached_k, cached_v = past_cache + k = L.concat([cached_k, k], 1) + v = L.concat([cached_v, v], 1) + + q = L.transpose( + L.reshape(q, [0, 0, self.n_head, q.shape[-1] // self.n_head]), + [0, 2, 1, 3]) #[batch, head, seq, dim] + k = L.transpose( + L.reshape(k, [0, 0, self.n_head, k.shape[-1] // self.n_head]), + [0, 2, 1, 3]) #[batch, head, seq, dim] + v = L.transpose( + L.reshape(v, [0, 0, self.n_head, v.shape[-1] // self.n_head]), + [0, 2, 1, 3]) #[batch, head, seq, dim] + + q = L.scale(q, scale=self.d_key**-0.5) + score = L.matmul(q, k, transpose_y=True) + if attn_bias is not None: + score += attn_bias + score = L.softmax(score, use_cudnn=True) + score = self.dropout(score) + + out = L.matmul(score, v) + out = L.transpose(out, [0, 2, 1, 3]) + out = L.reshape(out, [0, 0, out.shape[2] * out.shape[3]]) + + out = self.o(out) + return out, cache + + +class PositionwiseFeedForwardLayer(D.Layer): + def __init__(self, cfg, name=None): + super(PositionwiseFeedForwardLayer, self).__init__() + initializer = F.initializer.TruncatedNormal( + scale=cfg['initializer_range']) + d_model = cfg['hidden_size'] + d_ffn = cfg.get('intermediate_size', 4 * d_model) + assert cfg['hidden_act'] in ['relu', 'gelu'] + self.i = _build_linear( + d_model, + d_ffn, + append_name(name, 'fc_0'), + initializer, + act=cfg['hidden_act']) + self.o = _build_linear(d_ffn, d_model, append_name(name, 'fc_1'), + initializer) + prob = cfg.get('intermediate_dropout_prob', 0.) + self.dropout = lambda i: L.dropout( + i, + dropout_prob=prob, + dropout_implementation="upscale_in_train", + ) if self.training else i + + def forward(self, inputs): + hidden = self.i(inputs) + hidden = self.dropout(hidden) + out = self.o(hidden) + return out + + +class ErnieBlock(D.Layer): + def __init__(self, cfg, name=None): + super(ErnieBlock, self).__init__() + d_model = cfg['hidden_size'] + initializer = F.initializer.TruncatedNormal( + scale=cfg['initializer_range']) + + self.attn = AttentionLayer( + cfg, name=append_name(name, 'multi_head_att')) + self.ln1 = _build_ln(d_model, name=append_name(name, 'post_att')) + self.ffn = PositionwiseFeedForwardLayer( + cfg, name=append_name(name, 'ffn')) + self.ln2 = _build_ln(d_model, name=append_name(name, 'post_ffn')) + prob = cfg.get('intermediate_dropout_prob', cfg['hidden_dropout_prob']) + self.dropout = lambda i: L.dropout( + i, + dropout_prob=prob, + dropout_implementation="upscale_in_train", + ) if self.training else i + + def forward(self, inputs, attn_bias=None, past_cache=None): + attn_out, cache = self.attn( + inputs, inputs, inputs, attn_bias, + past_cache=past_cache) #self attn + attn_out = self.dropout(attn_out) + hidden = attn_out + inputs + hidden = self.ln1(hidden) # dropout/ add/ norm + + ffn_out = self.ffn(hidden) + ffn_out = self.dropout(ffn_out) + hidden = ffn_out + hidden + hidden = self.ln2(hidden) + return hidden, cache + + +class ErnieEncoderStack(D.Layer): + def __init__(self, cfg, name=None): + super(ErnieEncoderStack, self).__init__() + n_layers = cfg['num_hidden_layers'] + self.block = D.LayerList([ + ErnieBlock(cfg, append_name(name, 'layer_%d' % i)) + for i in range(n_layers) + ]) + + def forward(self, inputs, attn_bias=None, past_cache=None): + if past_cache is not None: + assert isinstance( + past_cache, tuple + ), 'unknown type of `past_cache`, expect tuple or list. got %s' % repr( + type(past_cache)) + past_cache = list(zip(*past_cache)) + else: + past_cache = [None] * len(self.block) + cache_list_k, cache_list_v, hidden_list = [], [], [inputs] + + for b, p in zip(self.block, past_cache): + inputs, cache = b(inputs, attn_bias=attn_bias, past_cache=p) + cache_k, cache_v = cache + cache_list_k.append(cache_k) + cache_list_v.append(cache_v) + hidden_list.append(inputs) + + return inputs, hidden_list, (cache_list_k, cache_list_v) + + +class ErnieModel(D.Layer): + def __init__(self, cfg, name=None): + """ + Fundamental pretrained Ernie model + """ + log.debug('init ErnieModel with config: %s' % repr(cfg)) + D.Layer.__init__(self) + d_model = cfg['hidden_size'] + d_emb = cfg.get('emb_size', cfg['hidden_size']) + d_vocab = cfg['vocab_size'] + d_pos = cfg['max_position_embeddings'] + d_sent = cfg.get("sent_type_vocab_size") or cfg['type_vocab_size'] + self.n_head = cfg['num_attention_heads'] + self.return_additional_info = cfg.get('return_additional_info', False) + initializer = F.initializer.TruncatedNormal( + scale=cfg['initializer_range']) + + self.ln = _build_ln(d_model, name=append_name(name, 'pre_encoder')) + self.word_emb = D.Embedding([d_vocab, d_emb], + param_attr=F.ParamAttr( + name=append_name( + name, 'word_embedding'), + initializer=initializer)) + self.pos_emb = D.Embedding([d_pos, d_emb], + param_attr=F.ParamAttr( + name=append_name(name, 'pos_embedding'), + initializer=initializer)) + self.sent_emb = D.Embedding([d_sent, d_emb], + param_attr=F.ParamAttr( + name=append_name( + name, 'sent_embedding'), + initializer=initializer)) + prob = cfg['hidden_dropout_prob'] + self.dropout = lambda i: L.dropout( + i, + dropout_prob=prob, + dropout_implementation="upscale_in_train", + ) if self.training else i + + self.encoder_stack = ErnieEncoderStack(cfg, append_name( + name, 'encoder')) + if cfg.get('has_pooler', True): + self.pooler = _build_linear( + cfg['hidden_size'], + cfg['hidden_size'], + append_name(name, 'pooled_fc'), + initializer, + act='tanh') + else: + self.pooler = None + self.train() + + def eval(self): + if F.in_dygraph_mode(): + super(ErnieModel, self).eval() + self.training = False + for l in self.sublayers(): + l.training = False + + def train(self): + if F.in_dygraph_mode(): + super(ErnieModel, self).train() + self.training = True + for l in self.sublayers(): + l.training = True + + def forward(self, + src_ids, + sent_ids=None, + pos_ids=None, + input_mask=None, + attn_bias=None, + past_cache=None, + use_causal_mask=False): + """ + Args: + src_ids (`Variable` of shape `[batch_size, seq_len]`): + Indices of input sequence tokens in the vocabulary. + sent_ids (optional, `Variable` of shape `[batch_size, seq_len]`): + aka token_type_ids, Segment token indices to indicate first and second portions of the inputs. + if None, assume all tokens come from `segment_a` + pos_ids(optional, `Variable` of shape `[batch_size, seq_len]`): + Indices of positions of each input sequence tokens in the position embeddings. + input_mask(optional `Variable` of shape `[batch_size, seq_len]`): + Mask to avoid performing attention on the padding token indices of the encoder input. + attn_bias(optional, `Variable` of shape `[batch_size, seq_len, seq_len] or False`): + 3D version of `input_mask`, if set, overrides `input_mask`; if set not False, will not apply attention mask + past_cache(optional, tuple of two lists: cached key and cached value, + each is a list of `Variable`s of shape `[batch_size, seq_len, hidden_size]`): + cached key/value tensor that will be concated to generated key/value when performing self attention. + if set, `attn_bias` should not be None. + + Returns: + pooled (`Variable` of shape `[batch_size, hidden_size]`): + output logits of pooler classifier + encoded(`Variable` of shape `[batch_size, seq_len, hidden_size]`): + output logits of transformer stack + """ + assert len( + src_ids.shape + ) == 2, 'expect src_ids.shape = [batch, sequecen], got %s' % (repr( + src_ids.shape)) + assert attn_bias is not None if past_cache else True, 'if `past_cache` is specified; attn_bias should not be None' + d_batch = L.shape(src_ids)[0] + d_seqlen = L.shape(src_ids)[1] + if pos_ids is None: + pos_ids = L.reshape(L.range(0, d_seqlen, 1, dtype='int32'), [1, -1]) + pos_ids = L.cast(pos_ids, 'int64') + if attn_bias is None: + if input_mask is None: + input_mask = L.cast(src_ids != 0, 'float32') + assert len(input_mask.shape) == 2 + input_mask = L.unsqueeze(input_mask, axes=[-1]) + attn_bias = L.matmul(input_mask, input_mask, transpose_y=True) + if use_causal_mask: + sequence = L.reshape( + L.range(0, d_seqlen, 1, dtype='float32') + 1., + [1, 1, -1, 1]) + causal_mask = L.cast( + (L.matmul(sequence, 1. / sequence, transpose_y=True) >= 1.), + 'float32') + attn_bias *= causal_mask + else: + assert len( + attn_bias.shape + ) == 3, 'expect attn_bias tobe rank 3, got %r' % attn_bias.shape + attn_bias = (1. - attn_bias) * -10000.0 + attn_bias = L.unsqueeze(attn_bias, [1]) + attn_bias = L.expand(attn_bias, + [1, self.n_head, 1, 1]) # avoid broadcast =_= + attn_bias.stop_gradient = True + + if sent_ids is None: + sent_ids = L.zeros_like(src_ids) + + src_embedded = self.word_emb(src_ids) + pos_embedded = self.pos_emb(pos_ids) + sent_embedded = self.sent_emb(sent_ids) + embedded = src_embedded + pos_embedded + sent_embedded + + embedded = self.dropout(self.ln(embedded)) + + encoded, hidden_list, cache_list = self.encoder_stack( + embedded, attn_bias, past_cache=past_cache) + if self.pooler is not None: + pooled = self.pooler(encoded[:, 0, :]) + else: + pooled = None + + additional_info = { + 'hiddens': hidden_list, + 'caches': cache_list, + } + + if self.return_additional_info: + return pooled, encoded, additional_info + else: + return pooled, encoded diff --git a/hub_module/modules/text/text_generation/ernie_gen_couplet/model/modeling_ernie_gen.py b/hub_module/modules/text/text_generation/ernie_gen_couplet/model/modeling_ernie_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..c2245ec3f03c4bf75ece5c5856e7074d4ab28b68 --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_couplet/model/modeling_ernie_gen.py @@ -0,0 +1,78 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as F +import paddle.fluid.layers as L + +from ernie_gen_couplet.model.modeling_ernie import ErnieModel +from ernie_gen_couplet.model.modeling_ernie import _build_linear, _build_ln, append_name + + +class ErnieModelForGeneration(ErnieModel): + def __init__(self, cfg, name=None): + cfg['return_additional_info'] = True + cfg['has_pooler'] = False + super(ErnieModelForGeneration, self).__init__(cfg, name=name) + initializer = F.initializer.TruncatedNormal( + scale=cfg['initializer_range']) + d_model = cfg['hidden_size'] + d_vocab = cfg['vocab_size'] + + self.mlm = _build_linear( + d_model, + d_model, + append_name(name, 'mask_lm_trans_fc'), + initializer, + act=cfg['hidden_act']) + self.mlm_ln = _build_ln( + d_model, name=append_name(name, 'mask_lm_trans')) + self.mlm_bias = L.create_parameter( + dtype='float32', + shape=[d_vocab], + attr=F.ParamAttr( + name=append_name(name, 'mask_lm_out_fc.b_0'), + initializer=F.initializer.Constant(value=0.0)), + is_bias=True, + ) + + def forward(self, src_ids, *args, **kwargs): + tgt_labels = kwargs.pop('tgt_labels', None) + tgt_pos = kwargs.pop('tgt_pos', None) + encode_only = kwargs.pop('encode_only', False) + _, encoded, info = ErnieModel.forward(self, src_ids, *args, **kwargs) + if encode_only: + return None, None, info + elif tgt_labels is None: + encoded = self.mlm(encoded) + encoded = self.mlm_ln(encoded) + logits = L.matmul( + encoded, self.word_emb.weight, transpose_y=True) + self.mlm_bias + output_ids = L.argmax(logits, -1) + return output_ids, logits, info + else: + encoded_2d = L.gather_nd(encoded, tgt_pos) + encoded_2d = self.mlm(encoded_2d) + encoded_2d = self.mlm_ln(encoded_2d) + logits_2d = L.matmul( + encoded_2d, self.word_emb.weight, + transpose_y=True) + self.mlm_bias + if len(tgt_labels.shape) == 1: + tgt_labels = L.reshape(tgt_labels, [-1, 1]) + + loss = L.reduce_mean( + L.softmax_with_cross_entropy( + logits_2d, + tgt_labels, + soft_label=(tgt_labels.shape[-1] != 1))) + return loss, logits_2d, info diff --git a/hub_module/modules/text/text_generation/ernie_gen_couplet/model/tokenizing_ernie.py b/hub_module/modules/text/text_generation/ernie_gen_couplet/model/tokenizing_ernie.py new file mode 100644 index 0000000000000000000000000000000000000000..3039b7028f5da991189527b8145b05c952dafbbd --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_couplet/model/tokenizing_ernie.py @@ -0,0 +1,171 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six +import re +import logging +from functools import partial + +import numpy as np + +import io + +open = partial(io.open, encoding='utf8') + +log = logging.getLogger(__name__) + +_max_input_chars_per_word = 100 + + +def _wordpiece(token, vocab, unk_token, prefix='##', sentencepiece_prefix=''): + """ wordpiece: helloworld => [hello, ##world] """ + chars = list(token) + if len(chars) > _max_input_chars_per_word: + return [unk_token], [(0, len(chars))] + + is_bad = False + start = 0 + sub_tokens = [] + sub_pos = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start == 0: + substr = sentencepiece_prefix + substr + if start > 0: + substr = prefix + substr + if substr in vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + sub_pos.append((start, end)) + start = end + if is_bad: + return [unk_token], [(0, len(chars))] + else: + return sub_tokens, sub_pos + + +class ErnieTokenizer(object): + def __init__(self, + vocab, + unk_token='[UNK]', + sep_token='[SEP]', + cls_token='[CLS]', + pad_token='[PAD]', + mask_token='[MASK]', + wordpiece_prefix='##', + sentencepiece_prefix='', + lower=True, + encoding='utf8', + special_token_list=[]): + if not isinstance(vocab, dict): + raise ValueError( + 'expect `vocab` to be instance of dict, got %s' % type(vocab)) + self.vocab = vocab + self.lower = lower + self.prefix = wordpiece_prefix + self.sentencepiece_prefix = sentencepiece_prefix + self.pad_id = self.vocab[pad_token] + self.cls_id = cls_token and self.vocab[cls_token] + self.sep_id = sep_token and self.vocab[sep_token] + self.unk_id = unk_token and self.vocab[unk_token] + self.mask_id = mask_token and self.vocab[mask_token] + self.unk_token = unk_token + special_tokens = { + pad_token, cls_token, sep_token, unk_token, mask_token + } | set(special_token_list) + pat_str = '' + for t in special_tokens: + if t is None: + continue + pat_str += '(%s)|' % re.escape(t) + pat_str += r'([a-zA-Z0-9]+|\S)' + log.debug('regex: %s' % pat_str) + self.pat = re.compile(pat_str) + self.encoding = encoding + + def tokenize(self, text): + if len(text) == 0: + return [] + if six.PY3 and not isinstance(text, six.string_types): + text = text.decode(self.encoding) + if six.PY2 and isinstance(text, str): + text = text.decode(self.encoding) + + res = [] + for match in self.pat.finditer(text): + match_group = match.group(0) + if match.groups()[-1]: + if self.lower: + match_group = match_group.lower() + words, _ = _wordpiece( + match_group, + vocab=self.vocab, + unk_token=self.unk_token, + prefix=self.prefix, + sentencepiece_prefix=self.sentencepiece_prefix) + else: + words = [match_group] + res += words + return res + + def convert_tokens_to_ids(self, tokens): + return [self.vocab.get(t, self.unk_id) for t in tokens] + + def truncate(self, id1, id2, seqlen): + len1 = len(id1) + len2 = len(id2) + half = seqlen // 2 + if len1 > len2: + len1_truncated, len2_truncated = max(half, seqlen - len2), min( + half, len2) + else: + len1_truncated, len2_truncated = min(half, seqlen - len1), max( + half, seqlen - len1) + return id1[:len1_truncated], id2[:len2_truncated] + + def build_for_ernie(self, text_id, pair_id=[]): + """build sentence type id, add [CLS] [SEP]""" + text_id_type = np.zeros_like(text_id, dtype=np.int64) + ret_id = np.concatenate([[self.cls_id], text_id, [self.sep_id]], 0) + ret_id_type = np.concatenate([[0], text_id_type, [0]], 0) + + if len(pair_id): + pair_id_type = np.ones_like(pair_id, dtype=np.int64) + ret_id = np.concatenate([ret_id, pair_id, [self.sep_id]], 0) + ret_id_type = np.concatenate([ret_id_type, pair_id_type, [1]], 0) + return ret_id, ret_id_type + + def encode(self, text, pair=None, truncate_to=None): + text_id = np.array( + self.convert_tokens_to_ids(self.tokenize(text)), dtype=np.int64) + text_id_type = np.zeros_like(text_id, dtype=np.int64) + if pair is not None: + pair_id = np.array( + self.convert_tokens_to_ids(self.tokenize(pair)), dtype=np.int64) + else: + pair_id = [] + if truncate_to is not None: + text_id, pair_id = self.truncate( + text_id, [] if pair_id is None else pair_id, truncate_to) + + ret_id, ret_id_type = self.build_for_ernie(text_id, pair_id) + return ret_id, ret_id_type diff --git a/hub_module/modules/text/text_generation/ernie_gen_couplet/module.py b/hub_module/modules/text/text_generation/ernie_gen_couplet/module.py new file mode 100644 index 0000000000000000000000000000000000000000..8640a3559c9c5b78c6d3f40a72d66364e39c7738 --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_couplet/module.py @@ -0,0 +1,187 @@ +# coding:utf-8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import ast +import json + +import paddle.fluid as fluid +import paddlehub as hub +from paddlehub.module.module import runnable +from paddlehub.module.nlp_module import DataFormatError +from paddlehub.common.logger import logger +from paddlehub.module.module import moduleinfo, serving + +import argparse +import os +import numpy as np + +import paddle.fluid.dygraph as D + +from ernie_gen_couplet.model.tokenizing_ernie import ErnieTokenizer +from ernie_gen_couplet.model.decode import beam_search_infilling +from ernie_gen_couplet.model.modeling_ernie_gen import ErnieModelForGeneration + + +@moduleinfo( + name="ernie_gen_couplet", + version="1.0.0", + summary= + "ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for couplet generation task.", + author="baidu-nlp", + author_email="", + type="nlp/text_generation", +) +class ErnieGen(hub.NLPPredictionModule): + def _initialize(self): + """ + initialize with the necessary elements + """ + assets_path = os.path.join(self.directory, "assets") + gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_couplet") + ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json') + ernie_cfg = dict(json.loads(open(ernie_cfg_path).read())) + ernie_vocab_path = os.path.join(assets_path, 'vocab.txt') + ernie_vocab = { + j.strip().split('\t')[0]: i + for i, j in enumerate(open(ernie_vocab_path).readlines()) + } + + with fluid.dygraph.guard(fluid.CPUPlace()): + with fluid.unique_name.guard(): + self.model = ErnieModelForGeneration(ernie_cfg) + finetuned_states, _ = D.load_dygraph(gen_checkpoint_path) + self.model.set_dict(finetuned_states) + + self.tokenizer = ErnieTokenizer(ernie_vocab) + self.rev_dict = {v: k for k, v in self.tokenizer.vocab.items()} + self.rev_dict[self.tokenizer.pad_id] = '' # replace [PAD] + self.rev_dict[self.tokenizer.unk_id] = '' # replace [PAD] + self.rev_lookup = np.vectorize(lambda i: self.rev_dict[i]) + + @serving + def generate(self, texts, use_gpu=False, beam_width=5): + """ + Get the right rolls from the left rolls. + + Args: + texts(list): the left rolls. + use_gpu(bool): whether use gpu to predict or not + beam_width(int): the beam search width. + + Returns: + results(list): the right rolls. + """ + if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ: + use_gpu = False + logger.warning( + "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True" + ) + if use_gpu: + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + if texts and isinstance(texts, list): + predicted_data = texts + else: + raise ValueError( + "The input data is inconsistent with expectations.") + + with fluid.dygraph.guard(place): + self.model.eval() + results = [] + for text in predicted_data: + sample_results = [] + ids, sids = self.tokenizer.encode(text) + src_ids = D.to_variable(np.expand_dims(ids, 0)) + src_sids = D.to_variable(np.expand_dims(sids, 0)) + output_ids = beam_search_infilling( + self.model, + src_ids, + src_sids, + eos_id=self.tokenizer.sep_id, + sos_id=self.tokenizer.cls_id, + attn_id=self.tokenizer.vocab['[MASK]'], + max_decode_len=20, + max_encode_len=20, + beam_width=beam_width, + tgt_type_id=1) + output_str = self.rev_lookup(output_ids[0].numpy()) + + for ostr in output_str.tolist(): + if '[SEP]' in ostr: + ostr = ostr[:ostr.index('[SEP]')] + sample_results.append("".join(ostr)) + results.append(sample_results) + return results + + def add_module_config_arg(self): + """ + Add the command config options + """ + self.arg_config_group.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=False, + help="whether use GPU for prediction") + + self.arg_config_group.add_argument( + '--beam_width', type=int, default=5, help="the beam search width") + + @runnable + def run_cmd(self, argvs): + """ + Run as a command + """ + self.parser = argparse.ArgumentParser( + description='Run the %s module.' % self.name, + prog='hub run %s' % self.name, + usage='%(prog)s', + add_help=True) + + self.arg_input_group = self.parser.add_argument_group( + title="Input options", description="Input data. Required") + self.arg_config_group = self.parser.add_argument_group( + title="Config options", + description= + "Run configuration for controlling module behavior, optional.") + + self.add_module_config_arg() + self.add_module_input_arg() + + args = self.parser.parse_args(argvs) + + try: + input_data = self.check_input_data(args) + except DataFormatError and RuntimeError: + self.parser.print_help() + return None + + results = self.generate( + texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width) + + return results + + @serving + def serving_method(self, texts, use_gpu=False): + """ + Run as a service. + """ + return self.generate(texts, use_gpu) + + +if __name__ == "__main__": + module = ErnieGen() + for result in module.generate(['人增福寿年增岁', '风吹云乱天垂泪'], beam_width=5): + print(result) diff --git a/hub_module/modules/text/text_generation/ernie_gen_poetry/README.md b/hub_module/modules/text/text_generation/ernie_gen_poetry/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e66239b27733c1a56a53985caa243057a0ff5f6a --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_poetry/README.md @@ -0,0 +1,99 @@ +## 概述 + +ERNIE-GEN 是面向生成任务的预训练-微调框架,首次在预训练阶段加入span-by-span 生成任务,让模型每次能够生成一个语义完整的片段。在预训练和微调中通过填充式生成机制和噪声感知机制来缓解曝光偏差问题。此外, ERNIE-GEN 采样多片段-多粒度目标文本采样策略, 增强源文本和目标文本的关联性,加强了编码器和解码器的交互。 +

+
+

+ +更多详情参考论文[ERNIE-GEN:An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation](https://arxiv.org/abs/2001.11314) + +## 命令行预测 + +```shell +$ hub run ernie_gen_poetry --input_text="宝积峰前露术香,使君行旆照晴阳。" --use_gpu True --beam_width 5 +``` + +## API + +```python +def generate(texts, use_gpu=False, beam_width=5): +``` + +预测API,输入诗歌开头,输出诗歌下文。 + +**参数** + +* texts (list\[str\]): 诗歌的开头; +* use\_gpu (bool): 是否使用 GPU;**若使用GPU,请先设置CUDA\_VISIBLE\_DEVICES环境变量**; +* beam_width: beam search宽度,决定每个诗歌开头输出的下文数目。 + +**返回** + +* results (list[list][str]): 诗歌下文,每个诗歌开头会生成beam_width个下文。 + +**代码示例** + +```python +import paddlehub as hub + +module = hub.Module(name="ernie_gen_poetry") + +test_texts = ["宝积峰前露术香,使君行旆照晴阳。"] +results = module.genrate(texts=test_texts, use_gpu=True, beam_width=5) +for result in results: + print(result) +``` + +## 服务部署 + +PaddleHub Serving 可以部署在线服务。 + +### 第一步:启动PaddleHub Serving + +运行启动命令: +```shell +$ hub serving start -m ernie_gen_poetry -p 8866 +``` + +这样就完成了一个服务化API的部署,默认端口号为8866。 + +**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。 + +### 第二步:发送预测请求 + +配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果 + +```python +import requests +import json + +# 发送HTTP请求 + +data = {'texts':["宝积峰前露术香,使君行旆照晴阳。"], + 'use_gpu':False, 'beam_width':5} +headers = {"Content-type": "application/json"} +url = "http://127.0.0.1:8866/predict/ernie_gen_poetry" +r = requests.post(url=url, headers=headers, data=json.dumps(data)) + +# 保存结果 +results = r.json()["results"] +for result in results: + print(result) +``` + +## 查看代码 + +https://github.com/PaddlePaddle/ERNIE/blob/repro/ernie-gen/ + +### 依赖 + +paddlepaddle >= 1.8.2 + +paddlehub >= 1.7.0 + + +## 更新历史 + +* 1.0.0 + + 初始发布 diff --git a/hub_module/modules/text/text_generation/ernie_gen_poetry/__init__.py b/hub_module/modules/text/text_generation/ernie_gen_poetry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hub_module/modules/text/text_generation/ernie_gen_poetry/model/decode.py b/hub_module/modules/text/text_generation/ernie_gen_poetry/model/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..c58fdbe2e8902346162f8733ef0cd94ba65757a2 --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_poetry/model/decode.py @@ -0,0 +1,301 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import numpy as np +from collections import namedtuple + +import paddle.fluid as F +import paddle.fluid.layers as L +import paddle.fluid.dygraph as D + + +def gen_bias(encoder_inputs, decoder_inputs, step): + decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2] + attn_bias = L.reshape( + L.range(0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1]) + decoder_bias = L.cast( + (L.matmul(attn_bias, 1. / attn_bias, transpose_y=True) >= 1.), + 'float32') #[1, 1, decoderlen, decoderlen] + encoder_bias = L.unsqueeze( + L.cast(L.ones_like(encoder_inputs), 'float32'), + [1]) #[bsz, 1, encoderlen] + encoder_bias = L.expand( + encoder_bias, [1, decoder_seqlen, 1]) #[bsz,decoderlen, encoderlen] + decoder_bias = L.expand(decoder_bias, + [decoder_bsz, 1, 1]) #[bsz, decoderlen, decoderlen] + if step > 0: + bias = L.concat([ + encoder_bias, + L.ones([decoder_bsz, decoder_seqlen, step], 'float32'), decoder_bias + ], -1) + else: + bias = L.concat([encoder_bias, decoder_bias], -1) + return bias + + +@D.no_grad +def greedy_search_infilling(model, + q_ids, + q_sids, + sos_id, + eos_id, + attn_id, + max_encode_len=640, + max_decode_len=100, + tgt_type_id=3): + model.eval() + _, logits, info = model(q_ids, q_sids) + gen_ids = L.argmax(logits, -1) + d_batch, d_seqlen = q_ids.shape + seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True) + has_stopped = np.zeros([d_batch], dtype=np.bool) + gen_seq_len = np.zeros([d_batch], dtype=np.int64) + output_ids = [] + + past_cache = info['caches'] + + cls_ids = L.ones([d_batch], dtype='int64') * sos_id + attn_ids = L.ones([d_batch], dtype='int64') * attn_id + ids = L.stack([cls_ids, attn_ids], -1) + for step in range(max_decode_len): + bias = gen_bias(q_ids, ids, step) + pos_ids = D.to_variable( + np.tile(np.array([[step, step + 1]], dtype=np.int64), [d_batch, 1])) + pos_ids += seqlen + _, logits, info = model( + ids, + L.ones_like(ids) * tgt_type_id, + pos_ids=pos_ids, + attn_bias=bias, + past_cache=past_cache) + gen_ids = L.argmax(logits, -1) + + past_cached_k, past_cached_v = past_cache + cached_k, cached_v = info['caches'] + cached_k = [ + L.concat([pk, k[:, :1, :]], 1) + for pk, k in zip(past_cached_k, cached_k) + ] # concat cached + cached_v = [ + L.concat([pv, v[:, :1, :]], 1) + for pv, v in zip(past_cached_v, cached_v) + ] + past_cache = (cached_k, cached_v) + + gen_ids = gen_ids[:, 1] + ids = L.stack([gen_ids, attn_ids], 1) + + gen_ids = gen_ids.numpy() + has_stopped |= (gen_ids == eos_id).astype(np.bool) + gen_seq_len += (1 - has_stopped.astype(np.int64)) + output_ids.append(gen_ids.tolist()) + if has_stopped.all(): + break + output_ids = np.array(output_ids).transpose([1, 0]) + return output_ids + + +BeamSearchState = namedtuple('BeamSearchState', + ['log_probs', 'lengths', 'finished']) +BeamSearchOutput = namedtuple('BeamSearchOutput', + ['scores', 'predicted_ids', 'beam_parent_ids']) + + +def log_softmax(x): + e_x = np.exp(x - np.max(x)) + return np.log(e_x / e_x.sum()) + + +def mask_prob(p, onehot_eos, finished): + is_finished = L.cast(L.reshape(finished, [-1, 1]) != 0, 'float32') + p = is_finished * (1. - L.cast(onehot_eos, 'float32')) * -9999. + ( + 1. - is_finished) * p + return p + + +def hyp_score(log_probs, length, length_penalty): + lp = L.pow((5. + L.cast(length, 'float32')) / 6., length_penalty) + return log_probs / lp + + +def beam_search_step(state, logits, eos_id, beam_width, is_first_step, + length_penalty): + """logits.shape == [B*W, V]""" + _, vocab_size = logits.shape + + bsz, beam_width = state.log_probs.shape + onehot_eos = L.cast( + F.one_hot(L.ones([1], 'int64') * eos_id, vocab_size), 'int64') #[1, V] + + probs = L.log(L.softmax(logits)) #[B*W, V] + probs = mask_prob(probs, onehot_eos, state.finished) #[B*W, V] + allprobs = L.reshape(state.log_probs, [-1, 1]) + probs #[B*W, V] + + not_finished = 1 - L.reshape(state.finished, [-1, 1]) #[B*W,1] + not_eos = 1 - onehot_eos + length_to_add = not_finished * not_eos #[B*W,V] + alllen = L.reshape(state.lengths, [-1, 1]) + length_to_add + + allprobs = L.reshape(allprobs, [-1, beam_width * vocab_size]) + alllen = L.reshape(alllen, [-1, beam_width * vocab_size]) + allscore = hyp_score(allprobs, alllen, length_penalty) + if is_first_step: + allscore = L.reshape( + allscore, + [bsz, beam_width, -1])[:, 0, :] # first step only consiter beam 0 + scores, idx = L.topk(allscore, k=beam_width) #[B, W] + next_beam_id = idx // vocab_size #[B, W] + next_word_id = idx % vocab_size + + gather_idx = L.concat([L.where(idx != -1)[:, :1], + L.reshape(idx, [-1, 1])], 1) + next_probs = L.reshape(L.gather_nd(allprobs, gather_idx), idx.shape) + next_len = L.reshape(L.gather_nd(alllen, gather_idx), idx.shape) + + gather_idx = L.concat( + [L.where(next_beam_id != -1)[:, :1], + L.reshape(next_beam_id, [-1, 1])], 1) + next_finished = L.reshape( + L.gather_nd(state.finished, gather_idx), + state.finished.shape) #[gather new beam state according to new beam id] + + next_finished += L.cast(next_word_id == eos_id, 'int64') + next_finished = L.cast(next_finished > 0, 'int64') + + next_state = BeamSearchState( + log_probs=next_probs, lengths=next_len, finished=next_finished) + output = BeamSearchOutput( + scores=scores, predicted_ids=next_word_id, beam_parent_ids=next_beam_id) + + return output, next_state + + +@D.no_grad +def beam_search_infilling(model, + q_ids, + q_sids, + sos_id, + eos_id, + attn_id, + max_encode_len=640, + max_decode_len=100, + beam_width=5, + tgt_type_id=3, + length_penalty=1.0): + model.eval() + _, __, info = model(q_ids, q_sids) + d_batch, d_seqlen = q_ids.shape + + state = BeamSearchState( + log_probs=L.zeros([d_batch, beam_width], 'float32'), + lengths=L.zeros([d_batch, beam_width], 'int64'), + finished=L.zeros([d_batch, beam_width], 'int64')) + outputs = [] + + def reorder_(t, parent_id): + """reorder cache according to parent beam id""" + gather_idx = L.where(parent_id != -1)[:, 0] * beam_width + L.reshape( + parent_id, [-1]) + t = L.gather(t, gather_idx) + return t + + def tile_(t, times): + _shapes = list(t.shape[1:]) + ret = L.reshape( + L.expand(L.unsqueeze(t, [1]), [ + 1, + times, + ] + [ + 1, + ] * len(_shapes)), [ + -1, + ] + _shapes) + return ret + + cached_k, cached_v = info['caches'] + cached_k = [tile_(k, beam_width) for k in cached_k] + cached_v = [tile_(v, beam_width) for v in cached_v] + past_cache = (cached_k, cached_v) + + q_ids = tile_(q_ids, beam_width) + seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True) + + cls_ids = L.ones([d_batch * beam_width], dtype='int64') * sos_id + attn_ids = L.ones([d_batch * beam_width], dtype='int64') * attn_id # SOS + ids = L.stack([cls_ids, attn_ids], -1) + for step in range(max_decode_len): + bias = gen_bias(q_ids, ids, step) + pos_ids = D.to_variable( + np.tile( + np.array([[step, step + 1]], dtype=np.int64), + [d_batch * beam_width, 1])) + pos_ids += seqlen + + _, logits, info = model( + ids, + L.ones_like(ids) * tgt_type_id, + pos_ids=pos_ids, + attn_bias=bias, + past_cache=past_cache) + + output, state = beam_search_step( + state, + logits[:, 1], + eos_id=eos_id, + beam_width=beam_width, + is_first_step=(step == 0), + length_penalty=length_penalty) + outputs.append(output) + + past_cached_k, past_cached_v = past_cache + cached_k, cached_v = info['caches'] + cached_k = [ + reorder_(L.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids) + for pk, k in zip(past_cached_k, cached_k) + ] # concat cached + cached_v = [ + reorder_(L.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids) + for pv, v in zip(past_cached_v, cached_v) + ] + past_cache = (cached_k, cached_v) + + pred_ids_flatten = L.reshape(output.predicted_ids, + [d_batch * beam_width]) + ids = L.stack([pred_ids_flatten, attn_ids], 1) + + if state.finished.numpy().all(): + break + + final_ids = L.stack([o.predicted_ids for o in outputs], 0) + final_parent_ids = L.stack([o.beam_parent_ids for o in outputs], 0) + final_ids = L.gather_tree(final_ids, final_parent_ids) #[:, :, + #0] #pick best beam + final_ids = L.transpose( + L.reshape(final_ids, [-1, d_batch * 1, beam_width]), [1, 2, 0]) + return final_ids + + +en_patten = re.compile(r'^[a-zA-Z0-9]*$') + + +def post_process(token): + if token.startswith('##'): + ret = token[2:] + else: + if en_patten.match(token): + ret = ' ' + token + else: + ret = token + return ret diff --git a/hub_module/modules/text/text_generation/ernie_gen_poetry/model/file_utils.py b/hub_module/modules/text/text_generation/ernie_gen_poetry/model/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..613a5213a83e7fbd2a126cdb49b12eb62d4de41f --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_poetry/model/file_utils.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from tqdm import tqdm +from paddlehub.common.logger import logger +from paddlehub.common.dir import MODULE_HOME + + +def _fetch_from_remote(url, force_download=False): + import tempfile, requests, tarfile + cached_dir = os.path.join(MODULE_HOME, "ernie_for_gen") + if force_download or not os.path.exists(cached_dir): + with tempfile.NamedTemporaryFile() as f: + #url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz' + r = requests.get(url, stream=True) + total_len = int(r.headers.get('content-length')) + for chunk in tqdm( + r.iter_content(chunk_size=1024), + total=total_len // 1024, + desc='downloading %s' % url, + unit='KB'): + if chunk: + f.write(chunk) + f.flush() + logger.debug('extacting... to %s' % f.name) + with tarfile.open(f.name) as tf: + tf.extractall(path=cached_dir) + logger.debug('%s cached in %s' % (url, cached_dir)) + return cached_dir + + +def add_docstring(doc): + def func(f): + f.__doc__ += ('\n======other docs from supper class ======\n%s' % doc) + return f + + return func diff --git a/hub_module/modules/text/text_generation/ernie_gen_poetry/model/modeling_ernie.py b/hub_module/modules/text/text_generation/ernie_gen_poetry/model/modeling_ernie.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2304f67d7347e584c244ab8384eff0720f7cc2 --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_poetry/model/modeling_ernie.py @@ -0,0 +1,379 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +from __future__ import unicode_literals + +import logging + +import paddle.fluid.dygraph as D +import paddle.fluid as F +import paddle.fluid.layers as L + +log = logging.getLogger(__name__) + + +def _build_linear(n_in, n_out, name, init, act=None): + return D.Linear( + n_in, + n_out, + param_attr=F.ParamAttr( + name='%s.w_0' % name if name is not None else None, + initializer=init), + bias_attr='%s.b_0' % name if name is not None else None, + act=act) + + +def _build_ln(n_in, name): + return D.LayerNorm( + normalized_shape=n_in, + param_attr=F.ParamAttr( + name='%s_layer_norm_scale' % name if name is not None else None, + initializer=F.initializer.Constant(1.)), + bias_attr=F.ParamAttr( + name='%s_layer_norm_bias' % name if name is not None else None, + initializer=F.initializer.Constant(1.)), + ) + + +def append_name(name, postfix): + if name is None: + return None + elif name == '': + return postfix + else: + return '%s_%s' % (name, postfix) + + +class AttentionLayer(D.Layer): + def __init__(self, cfg, name=None): + super(AttentionLayer, self).__init__() + initializer = F.initializer.TruncatedNormal( + scale=cfg['initializer_range']) + d_model = cfg['hidden_size'] + n_head = cfg['num_attention_heads'] + assert d_model % n_head == 0 + d_model_q = cfg.get('query_hidden_size_per_head', + d_model // n_head) * n_head + d_model_v = cfg.get('value_hidden_size_per_head', + d_model // n_head) * n_head + self.n_head = n_head + self.d_key = d_model_q // n_head + self.q = _build_linear(d_model, d_model_q, append_name( + name, 'query_fc'), initializer) + self.k = _build_linear(d_model, d_model_q, append_name(name, 'key_fc'), + initializer) + self.v = _build_linear(d_model, d_model_v, append_name( + name, 'value_fc'), initializer) + self.o = _build_linear(d_model_v, d_model, append_name( + name, 'output_fc'), initializer) + self.dropout = lambda i: L.dropout( + i, + dropout_prob=cfg['attention_probs_dropout_prob'], + dropout_implementation="upscale_in_train", + ) if self.training else i + + def forward(self, queries, keys, values, attn_bias, past_cache): + assert len(queries.shape) == len(keys.shape) == len(values.shape) == 3 + + q = self.q(queries) + k = self.k(keys) + v = self.v(values) + + cache = (k, v) + if past_cache is not None: + cached_k, cached_v = past_cache + k = L.concat([cached_k, k], 1) + v = L.concat([cached_v, v], 1) + + q = L.transpose( + L.reshape(q, [0, 0, self.n_head, q.shape[-1] // self.n_head]), + [0, 2, 1, 3]) #[batch, head, seq, dim] + k = L.transpose( + L.reshape(k, [0, 0, self.n_head, k.shape[-1] // self.n_head]), + [0, 2, 1, 3]) #[batch, head, seq, dim] + v = L.transpose( + L.reshape(v, [0, 0, self.n_head, v.shape[-1] // self.n_head]), + [0, 2, 1, 3]) #[batch, head, seq, dim] + + q = L.scale(q, scale=self.d_key**-0.5) + score = L.matmul(q, k, transpose_y=True) + if attn_bias is not None: + score += attn_bias + score = L.softmax(score, use_cudnn=True) + score = self.dropout(score) + + out = L.matmul(score, v) + out = L.transpose(out, [0, 2, 1, 3]) + out = L.reshape(out, [0, 0, out.shape[2] * out.shape[3]]) + + out = self.o(out) + return out, cache + + +class PositionwiseFeedForwardLayer(D.Layer): + def __init__(self, cfg, name=None): + super(PositionwiseFeedForwardLayer, self).__init__() + initializer = F.initializer.TruncatedNormal( + scale=cfg['initializer_range']) + d_model = cfg['hidden_size'] + d_ffn = cfg.get('intermediate_size', 4 * d_model) + assert cfg['hidden_act'] in ['relu', 'gelu'] + self.i = _build_linear( + d_model, + d_ffn, + append_name(name, 'fc_0'), + initializer, + act=cfg['hidden_act']) + self.o = _build_linear(d_ffn, d_model, append_name(name, 'fc_1'), + initializer) + prob = cfg.get('intermediate_dropout_prob', 0.) + self.dropout = lambda i: L.dropout( + i, + dropout_prob=prob, + dropout_implementation="upscale_in_train", + ) if self.training else i + + def forward(self, inputs): + hidden = self.i(inputs) + hidden = self.dropout(hidden) + out = self.o(hidden) + return out + + +class ErnieBlock(D.Layer): + def __init__(self, cfg, name=None): + super(ErnieBlock, self).__init__() + d_model = cfg['hidden_size'] + initializer = F.initializer.TruncatedNormal( + scale=cfg['initializer_range']) + + self.attn = AttentionLayer( + cfg, name=append_name(name, 'multi_head_att')) + self.ln1 = _build_ln(d_model, name=append_name(name, 'post_att')) + self.ffn = PositionwiseFeedForwardLayer( + cfg, name=append_name(name, 'ffn')) + self.ln2 = _build_ln(d_model, name=append_name(name, 'post_ffn')) + prob = cfg.get('intermediate_dropout_prob', cfg['hidden_dropout_prob']) + self.dropout = lambda i: L.dropout( + i, + dropout_prob=prob, + dropout_implementation="upscale_in_train", + ) if self.training else i + + def forward(self, inputs, attn_bias=None, past_cache=None): + attn_out, cache = self.attn( + inputs, inputs, inputs, attn_bias, + past_cache=past_cache) #self attn + attn_out = self.dropout(attn_out) + hidden = attn_out + inputs + hidden = self.ln1(hidden) # dropout/ add/ norm + + ffn_out = self.ffn(hidden) + ffn_out = self.dropout(ffn_out) + hidden = ffn_out + hidden + hidden = self.ln2(hidden) + return hidden, cache + + +class ErnieEncoderStack(D.Layer): + def __init__(self, cfg, name=None): + super(ErnieEncoderStack, self).__init__() + n_layers = cfg['num_hidden_layers'] + self.block = D.LayerList([ + ErnieBlock(cfg, append_name(name, 'layer_%d' % i)) + for i in range(n_layers) + ]) + + def forward(self, inputs, attn_bias=None, past_cache=None): + if past_cache is not None: + assert isinstance( + past_cache, tuple + ), 'unknown type of `past_cache`, expect tuple or list. got %s' % repr( + type(past_cache)) + past_cache = list(zip(*past_cache)) + else: + past_cache = [None] * len(self.block) + cache_list_k, cache_list_v, hidden_list = [], [], [inputs] + + for b, p in zip(self.block, past_cache): + inputs, cache = b(inputs, attn_bias=attn_bias, past_cache=p) + cache_k, cache_v = cache + cache_list_k.append(cache_k) + cache_list_v.append(cache_v) + hidden_list.append(inputs) + + return inputs, hidden_list, (cache_list_k, cache_list_v) + + +class ErnieModel(D.Layer): + def __init__(self, cfg, name=None): + """ + Fundamental pretrained Ernie model + """ + log.debug('init ErnieModel with config: %s' % repr(cfg)) + D.Layer.__init__(self) + d_model = cfg['hidden_size'] + d_emb = cfg.get('emb_size', cfg['hidden_size']) + d_vocab = cfg['vocab_size'] + d_pos = cfg['max_position_embeddings'] + d_sent = cfg.get("sent_type_vocab_size") or cfg['type_vocab_size'] + self.n_head = cfg['num_attention_heads'] + self.return_additional_info = cfg.get('return_additional_info', False) + initializer = F.initializer.TruncatedNormal( + scale=cfg['initializer_range']) + + self.ln = _build_ln(d_model, name=append_name(name, 'pre_encoder')) + self.word_emb = D.Embedding([d_vocab, d_emb], + param_attr=F.ParamAttr( + name=append_name( + name, 'word_embedding'), + initializer=initializer)) + self.pos_emb = D.Embedding([d_pos, d_emb], + param_attr=F.ParamAttr( + name=append_name(name, 'pos_embedding'), + initializer=initializer)) + self.sent_emb = D.Embedding([d_sent, d_emb], + param_attr=F.ParamAttr( + name=append_name( + name, 'sent_embedding'), + initializer=initializer)) + prob = cfg['hidden_dropout_prob'] + self.dropout = lambda i: L.dropout( + i, + dropout_prob=prob, + dropout_implementation="upscale_in_train", + ) if self.training else i + + self.encoder_stack = ErnieEncoderStack(cfg, append_name( + name, 'encoder')) + if cfg.get('has_pooler', True): + self.pooler = _build_linear( + cfg['hidden_size'], + cfg['hidden_size'], + append_name(name, 'pooled_fc'), + initializer, + act='tanh') + else: + self.pooler = None + self.train() + + def eval(self): + if F.in_dygraph_mode(): + super(ErnieModel, self).eval() + self.training = False + for l in self.sublayers(): + l.training = False + + def train(self): + if F.in_dygraph_mode(): + super(ErnieModel, self).train() + self.training = True + for l in self.sublayers(): + l.training = True + + def forward(self, + src_ids, + sent_ids=None, + pos_ids=None, + input_mask=None, + attn_bias=None, + past_cache=None, + use_causal_mask=False): + """ + Args: + src_ids (`Variable` of shape `[batch_size, seq_len]`): + Indices of input sequence tokens in the vocabulary. + sent_ids (optional, `Variable` of shape `[batch_size, seq_len]`): + aka token_type_ids, Segment token indices to indicate first and second portions of the inputs. + if None, assume all tokens come from `segment_a` + pos_ids(optional, `Variable` of shape `[batch_size, seq_len]`): + Indices of positions of each input sequence tokens in the position embeddings. + input_mask(optional `Variable` of shape `[batch_size, seq_len]`): + Mask to avoid performing attention on the padding token indices of the encoder input. + attn_bias(optional, `Variable` of shape `[batch_size, seq_len, seq_len] or False`): + 3D version of `input_mask`, if set, overrides `input_mask`; if set not False, will not apply attention mask + past_cache(optional, tuple of two lists: cached key and cached value, + each is a list of `Variable`s of shape `[batch_size, seq_len, hidden_size]`): + cached key/value tensor that will be concated to generated key/value when performing self attention. + if set, `attn_bias` should not be None. + + Returns: + pooled (`Variable` of shape `[batch_size, hidden_size]`): + output logits of pooler classifier + encoded(`Variable` of shape `[batch_size, seq_len, hidden_size]`): + output logits of transformer stack + """ + assert len( + src_ids.shape + ) == 2, 'expect src_ids.shape = [batch, sequecen], got %s' % (repr( + src_ids.shape)) + assert attn_bias is not None if past_cache else True, 'if `past_cache` is specified; attn_bias should not be None' + d_batch = L.shape(src_ids)[0] + d_seqlen = L.shape(src_ids)[1] + if pos_ids is None: + pos_ids = L.reshape(L.range(0, d_seqlen, 1, dtype='int32'), [1, -1]) + pos_ids = L.cast(pos_ids, 'int64') + if attn_bias is None: + if input_mask is None: + input_mask = L.cast(src_ids != 0, 'float32') + assert len(input_mask.shape) == 2 + input_mask = L.unsqueeze(input_mask, axes=[-1]) + attn_bias = L.matmul(input_mask, input_mask, transpose_y=True) + if use_causal_mask: + sequence = L.reshape( + L.range(0, d_seqlen, 1, dtype='float32') + 1., + [1, 1, -1, 1]) + causal_mask = L.cast( + (L.matmul(sequence, 1. / sequence, transpose_y=True) >= 1.), + 'float32') + attn_bias *= causal_mask + else: + assert len( + attn_bias.shape + ) == 3, 'expect attn_bias tobe rank 3, got %r' % attn_bias.shape + attn_bias = (1. - attn_bias) * -10000.0 + attn_bias = L.unsqueeze(attn_bias, [1]) + attn_bias = L.expand(attn_bias, + [1, self.n_head, 1, 1]) # avoid broadcast =_= + attn_bias.stop_gradient = True + + if sent_ids is None: + sent_ids = L.zeros_like(src_ids) + + src_embedded = self.word_emb(src_ids) + pos_embedded = self.pos_emb(pos_ids) + sent_embedded = self.sent_emb(sent_ids) + embedded = src_embedded + pos_embedded + sent_embedded + + embedded = self.dropout(self.ln(embedded)) + + encoded, hidden_list, cache_list = self.encoder_stack( + embedded, attn_bias, past_cache=past_cache) + if self.pooler is not None: + pooled = self.pooler(encoded[:, 0, :]) + else: + pooled = None + + additional_info = { + 'hiddens': hidden_list, + 'caches': cache_list, + } + + if self.return_additional_info: + return pooled, encoded, additional_info + else: + return pooled, encoded diff --git a/hub_module/modules/text/text_generation/ernie_gen_poetry/model/modeling_ernie_gen.py b/hub_module/modules/text/text_generation/ernie_gen_poetry/model/modeling_ernie_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..c2245ec3f03c4bf75ece5c5856e7074d4ab28b68 --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_poetry/model/modeling_ernie_gen.py @@ -0,0 +1,78 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as F +import paddle.fluid.layers as L + +from ernie_gen_couplet.model.modeling_ernie import ErnieModel +from ernie_gen_couplet.model.modeling_ernie import _build_linear, _build_ln, append_name + + +class ErnieModelForGeneration(ErnieModel): + def __init__(self, cfg, name=None): + cfg['return_additional_info'] = True + cfg['has_pooler'] = False + super(ErnieModelForGeneration, self).__init__(cfg, name=name) + initializer = F.initializer.TruncatedNormal( + scale=cfg['initializer_range']) + d_model = cfg['hidden_size'] + d_vocab = cfg['vocab_size'] + + self.mlm = _build_linear( + d_model, + d_model, + append_name(name, 'mask_lm_trans_fc'), + initializer, + act=cfg['hidden_act']) + self.mlm_ln = _build_ln( + d_model, name=append_name(name, 'mask_lm_trans')) + self.mlm_bias = L.create_parameter( + dtype='float32', + shape=[d_vocab], + attr=F.ParamAttr( + name=append_name(name, 'mask_lm_out_fc.b_0'), + initializer=F.initializer.Constant(value=0.0)), + is_bias=True, + ) + + def forward(self, src_ids, *args, **kwargs): + tgt_labels = kwargs.pop('tgt_labels', None) + tgt_pos = kwargs.pop('tgt_pos', None) + encode_only = kwargs.pop('encode_only', False) + _, encoded, info = ErnieModel.forward(self, src_ids, *args, **kwargs) + if encode_only: + return None, None, info + elif tgt_labels is None: + encoded = self.mlm(encoded) + encoded = self.mlm_ln(encoded) + logits = L.matmul( + encoded, self.word_emb.weight, transpose_y=True) + self.mlm_bias + output_ids = L.argmax(logits, -1) + return output_ids, logits, info + else: + encoded_2d = L.gather_nd(encoded, tgt_pos) + encoded_2d = self.mlm(encoded_2d) + encoded_2d = self.mlm_ln(encoded_2d) + logits_2d = L.matmul( + encoded_2d, self.word_emb.weight, + transpose_y=True) + self.mlm_bias + if len(tgt_labels.shape) == 1: + tgt_labels = L.reshape(tgt_labels, [-1, 1]) + + loss = L.reduce_mean( + L.softmax_with_cross_entropy( + logits_2d, + tgt_labels, + soft_label=(tgt_labels.shape[-1] != 1))) + return loss, logits_2d, info diff --git a/hub_module/modules/text/text_generation/ernie_gen_poetry/model/tokenizing_ernie.py b/hub_module/modules/text/text_generation/ernie_gen_poetry/model/tokenizing_ernie.py new file mode 100644 index 0000000000000000000000000000000000000000..3039b7028f5da991189527b8145b05c952dafbbd --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_poetry/model/tokenizing_ernie.py @@ -0,0 +1,171 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six +import re +import logging +from functools import partial + +import numpy as np + +import io + +open = partial(io.open, encoding='utf8') + +log = logging.getLogger(__name__) + +_max_input_chars_per_word = 100 + + +def _wordpiece(token, vocab, unk_token, prefix='##', sentencepiece_prefix=''): + """ wordpiece: helloworld => [hello, ##world] """ + chars = list(token) + if len(chars) > _max_input_chars_per_word: + return [unk_token], [(0, len(chars))] + + is_bad = False + start = 0 + sub_tokens = [] + sub_pos = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start == 0: + substr = sentencepiece_prefix + substr + if start > 0: + substr = prefix + substr + if substr in vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + sub_pos.append((start, end)) + start = end + if is_bad: + return [unk_token], [(0, len(chars))] + else: + return sub_tokens, sub_pos + + +class ErnieTokenizer(object): + def __init__(self, + vocab, + unk_token='[UNK]', + sep_token='[SEP]', + cls_token='[CLS]', + pad_token='[PAD]', + mask_token='[MASK]', + wordpiece_prefix='##', + sentencepiece_prefix='', + lower=True, + encoding='utf8', + special_token_list=[]): + if not isinstance(vocab, dict): + raise ValueError( + 'expect `vocab` to be instance of dict, got %s' % type(vocab)) + self.vocab = vocab + self.lower = lower + self.prefix = wordpiece_prefix + self.sentencepiece_prefix = sentencepiece_prefix + self.pad_id = self.vocab[pad_token] + self.cls_id = cls_token and self.vocab[cls_token] + self.sep_id = sep_token and self.vocab[sep_token] + self.unk_id = unk_token and self.vocab[unk_token] + self.mask_id = mask_token and self.vocab[mask_token] + self.unk_token = unk_token + special_tokens = { + pad_token, cls_token, sep_token, unk_token, mask_token + } | set(special_token_list) + pat_str = '' + for t in special_tokens: + if t is None: + continue + pat_str += '(%s)|' % re.escape(t) + pat_str += r'([a-zA-Z0-9]+|\S)' + log.debug('regex: %s' % pat_str) + self.pat = re.compile(pat_str) + self.encoding = encoding + + def tokenize(self, text): + if len(text) == 0: + return [] + if six.PY3 and not isinstance(text, six.string_types): + text = text.decode(self.encoding) + if six.PY2 and isinstance(text, str): + text = text.decode(self.encoding) + + res = [] + for match in self.pat.finditer(text): + match_group = match.group(0) + if match.groups()[-1]: + if self.lower: + match_group = match_group.lower() + words, _ = _wordpiece( + match_group, + vocab=self.vocab, + unk_token=self.unk_token, + prefix=self.prefix, + sentencepiece_prefix=self.sentencepiece_prefix) + else: + words = [match_group] + res += words + return res + + def convert_tokens_to_ids(self, tokens): + return [self.vocab.get(t, self.unk_id) for t in tokens] + + def truncate(self, id1, id2, seqlen): + len1 = len(id1) + len2 = len(id2) + half = seqlen // 2 + if len1 > len2: + len1_truncated, len2_truncated = max(half, seqlen - len2), min( + half, len2) + else: + len1_truncated, len2_truncated = min(half, seqlen - len1), max( + half, seqlen - len1) + return id1[:len1_truncated], id2[:len2_truncated] + + def build_for_ernie(self, text_id, pair_id=[]): + """build sentence type id, add [CLS] [SEP]""" + text_id_type = np.zeros_like(text_id, dtype=np.int64) + ret_id = np.concatenate([[self.cls_id], text_id, [self.sep_id]], 0) + ret_id_type = np.concatenate([[0], text_id_type, [0]], 0) + + if len(pair_id): + pair_id_type = np.ones_like(pair_id, dtype=np.int64) + ret_id = np.concatenate([ret_id, pair_id, [self.sep_id]], 0) + ret_id_type = np.concatenate([ret_id_type, pair_id_type, [1]], 0) + return ret_id, ret_id_type + + def encode(self, text, pair=None, truncate_to=None): + text_id = np.array( + self.convert_tokens_to_ids(self.tokenize(text)), dtype=np.int64) + text_id_type = np.zeros_like(text_id, dtype=np.int64) + if pair is not None: + pair_id = np.array( + self.convert_tokens_to_ids(self.tokenize(pair)), dtype=np.int64) + else: + pair_id = [] + if truncate_to is not None: + text_id, pair_id = self.truncate( + text_id, [] if pair_id is None else pair_id, truncate_to) + + ret_id, ret_id_type = self.build_for_ernie(text_id, pair_id) + return ret_id, ret_id_type diff --git a/hub_module/modules/text/text_generation/ernie_gen_poetry/module.py b/hub_module/modules/text/text_generation/ernie_gen_poetry/module.py new file mode 100644 index 0000000000000000000000000000000000000000..cfa86632ce99f2ced3305ed506663da4fcdcdb5c --- /dev/null +++ b/hub_module/modules/text/text_generation/ernie_gen_poetry/module.py @@ -0,0 +1,187 @@ +# coding:utf-8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import ast +import json + +import paddle.fluid as fluid +import paddlehub as hub +from paddlehub.module.module import runnable +from paddlehub.module.nlp_module import DataFormatError +from paddlehub.common.logger import logger +from paddlehub.module.module import moduleinfo, serving + +import argparse +import os +import numpy as np + +import paddle.fluid.dygraph as D + +from ernie_gen_poetry.model.tokenizing_ernie import ErnieTokenizer +from ernie_gen_poetry.model.decode import beam_search_infilling +from ernie_gen_poetry.model.modeling_ernie_gen import ErnieModelForGeneration + + +@moduleinfo( + name="ernie_gen_poetry", + version="1.0.0", + summary= + "ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for poetry generation task.", + author="baidu-nlp", + author_email="", + type="nlp/text_generation", +) +class ErnieGen(hub.NLPPredictionModule): + def _initialize(self): + """ + initialize with the necessary elements + """ + assets_path = os.path.join(self.directory, "assets") + gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_poetry") + ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json') + ernie_cfg = dict(json.loads(open(ernie_cfg_path).read())) + ernie_vocab_path = os.path.join(assets_path, 'vocab.txt') + ernie_vocab = { + j.strip().split('\t')[0]: i + for i, j in enumerate(open(ernie_vocab_path).readlines()) + } + + with fluid.dygraph.guard(fluid.CPUPlace()): + with fluid.unique_name.guard(): + self.model = ErnieModelForGeneration(ernie_cfg) + finetuned_states, _ = D.load_dygraph(gen_checkpoint_path) + self.model.set_dict(finetuned_states) + + self.tokenizer = ErnieTokenizer(ernie_vocab) + self.rev_dict = {v: k for k, v in self.tokenizer.vocab.items()} + self.rev_dict[self.tokenizer.pad_id] = '' # replace [PAD] + self.rev_dict[self.tokenizer.unk_id] = '' # replace [PAD] + self.rev_lookup = np.vectorize(lambda i: self.rev_dict[i]) + + @serving + def generate(self, texts, use_gpu=False, beam_width=5): + """ + Get the continuation of the input poetry. + + Args: + texts(list): the front part of a poetry. + use_gpu(bool): whether use gpu to predict or not + beam_width(int): the beam search width. + + Returns: + results(list): the poetry continuations. + """ + if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ: + use_gpu = False + logger.warning( + "use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True" + ) + if use_gpu: + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + if texts and isinstance(texts, list): + predicted_data = texts + else: + raise ValueError( + "The input data is inconsistent with expectations.") + + with fluid.dygraph.guard(place): + self.model.eval() + results = [] + for text in predicted_data: + sample_results = [] + ids, sids = self.tokenizer.encode(text) + src_ids = D.to_variable(np.expand_dims(ids, 0)) + src_sids = D.to_variable(np.expand_dims(sids, 0)) + output_ids = beam_search_infilling( + self.model, + src_ids, + src_sids, + eos_id=self.tokenizer.sep_id, + sos_id=self.tokenizer.cls_id, + attn_id=self.tokenizer.vocab['[MASK]'], + max_decode_len=80, + max_encode_len=20, + beam_width=beam_width, + tgt_type_id=1) + output_str = self.rev_lookup(output_ids[0].numpy()) + + for ostr in output_str.tolist(): + if '[SEP]' in ostr: + ostr = ostr[:ostr.index('[SEP]')] + sample_results.append("".join(ostr)) + results.append(sample_results) + return results + + def add_module_config_arg(self): + """ + Add the command config options + """ + self.arg_config_group.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=False, + help="whether use GPU for prediction") + + self.arg_config_group.add_argument( + '--beam_width', type=int, default=5, help="the beam search width") + + @runnable + def run_cmd(self, argvs): + """ + Run as a command + """ + self.parser = argparse.ArgumentParser( + description='Run the %s module.' % self.name, + prog='hub run %s' % self.name, + usage='%(prog)s', + add_help=True) + + self.arg_input_group = self.parser.add_argument_group( + title="Input options", description="Input data. Required") + self.arg_config_group = self.parser.add_argument_group( + title="Config options", + description= + "Run configuration for controlling module behavior, optional.") + + self.add_module_config_arg() + self.add_module_input_arg() + + args = self.parser.parse_args(argvs) + + try: + input_data = self.check_input_data(args) + except DataFormatError and RuntimeError: + self.parser.print_help() + return None + + results = self.generate( + texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width) + + return results + + @serving + def serving_method(self, texts, use_gpu=False): + """ + Run as a service. + """ + return self.generate(texts, use_gpu) + + +if __name__ == "__main__": + module = ErnieGen() + for result in module.generate(['宝积峰前露术香,使君行旆照晴阳。'], beam_width=5): + print(result)