Commit e180d15b authored by: X xixiaoyao

add bert

Parent 09ef8c62
import sys
sys.path.append('paddlepalm')
from paddlepalm.mtl_controller import Controller
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""v1.1
BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid import layers
from paddlepalm.backbone.utils.transformer import pre_process_layer, encoder
from paddlepalm.interface import backbone
class Model(backbone):
def __init__(self, config, phase):
        # self._is_training = phase == 'train'  # the backbone generally need not care about the phase; its outputs barely change across phases
self._emb_size = config["hidden_size"]
self._n_layer = config["num_hidden_layers"]
self._n_head = config["num_attention_heads"]
self._voc_size = config["vocab_size"]
self._max_position_seq_len = config["max_position_embeddings"]
self._sent_types = config["type_vocab_size"]
self._hidden_act = config["hidden_act"]
self._prepostprocess_dropout = config["hidden_dropout_prob"]
self._attention_dropout = config["attention_probs_dropout_prob"]
        # `model_name` was undefined here; assume an optional config key with an
        # empty-string default, so parameter names are unprefixed by default.
        self.model_name = config.get('model_name', '')
self._word_emb_name = self.model_name + "word_embedding"
self._pos_emb_name = self.model_name + "pos_embedding"
self._sent_emb_name = self.model_name + "sent_embedding"
        # Initialize all weights with a truncated normal initializer; all biases
        # are initialized to constant zero by default.
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=config["initializer_range"])
@property
def inputs_attr(self):
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32']}
@property
def outputs_attr(self):
return {"word_embedding": [[-1, -1, self._emb_size], 'float32'],
"encoder_outputs": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding": [[-1, self._emb_size], 'float32']}
def build(self, inputs):
src_ids = inputs['token_ids']
pos_ids = inputs['position_ids']
sent_ids = inputs['segment_ids']
input_mask = inputs['input_mask']
# padding id in vocabulary must be set to 0
emb_out = layers.embedding(
input=src_ids,
size=[self._voc_size, self._emb_size],
dtype="float32",
param_attr=fluid.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
self.emb_out = emb_out
position_emb_out = layers.embedding(
input=pos_ids,
size=[self._max_position_seq_len, self._emb_size],
dtype="float32",
param_attr=fluid.ParamAttr(
name=self._pos_emb_name, initializer=self._param_initializer))
self.position_emb_out = position_emb_out
sent_emb_out = layers.embedding(
sent_ids,
size=[self._sent_types, self._emb_size],
dtype="float32"
param_attr=fluid.ParamAttr(
name=self._sent_emb_name, initializer=self._param_initializer))
self.sent_emb_out = sent_emb_out
emb_out = emb_out + position_emb_out + sent_emb_out
emb_out = pre_process_layer(
emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')
        self_attn_mask = layers.matmul(
            x=input_mask, y=input_mask, transpose_y=True)
        self_attn_mask = layers.scale(
            x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
        n_head_self_attn_mask = layers.stack(
            x=[self_attn_mask] * self._n_head, axis=1)
        n_head_self_attn_mask.stop_gradient = True
        enc_out = encoder(
            enc_input=emb_out,
            attn_bias=n_head_self_attn_mask,
            n_layer=self._n_layer,
            n_head=self._n_head,
            d_key=self._emb_size // self._n_head,
            d_value=self._emb_size // self._n_head,
            d_model=self._emb_size,
            d_inner_hid=self._emb_size * 4,
            prepostprocess_dropout=self._prepostprocess_dropout,
            attention_dropout=self._attention_dropout,
            relu_dropout=0,
            hidden_act=self._hidden_act,
            preprocess_cmd="",
            postprocess_cmd="dan",
            param_initializer=self._param_initializer,
            name=self.model_name + 'encoder')
        next_sent_feat = layers.slice(
            input=enc_out, axes=[1], starts=[0], ends=[1])
        next_sent_feat = layers.fc(
            input=next_sent_feat,
            size=self._emb_size,
            act="tanh",
            param_attr=fluid.ParamAttr(
                name=self.model_name + "pooled_fc.w_0",
                initializer=self._param_initializer),
            # prefix the bias like the other parameters so it is saved/loaded consistently
            bias_attr=self.model_name + "pooled_fc.b_0")
return {'word_embedding': emb_out,
'encoder_outputs': enc_out,
'sentence_embedding': next_sent_feat,
'sentence_pair_embedding': next_sent_feat}
def postprocess(self, rt_outputs):
pass
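# --- Illustrative sketch (not part of this commit): wiring this backbone into
# --- a Paddle program. The config values below follow a BERT-base setup and
# --- are assumptions; `model_name` is optional (see __init__ above).
#
# import paddle.fluid as fluid
# config = {"hidden_size": 768, "num_hidden_layers": 12,
#           "num_attention_heads": 12, "vocab_size": 30522,
#           "max_position_embeddings": 512, "type_vocab_size": 2,
#           "hidden_act": "gelu", "hidden_dropout_prob": 0.1,
#           "attention_probs_dropout_prob": 0.1, "initializer_range": 0.02}
# model = Model(config, phase='train')
# inputs = {name: fluid.layers.data(name=name, shape=attr[0], dtype=attr[1],
#                                   append_batch_size=False)
#           for name, attr in model.inputs_attr.items()}
# outputs = model.build(inputs)   # dict of Variables per outputs_attr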
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ernie model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from paddle import fluid
from paddle.fluid import layers
from paddlepalm.backbone.utils.transformer import pre_process_layer, encoder
from paddlepalm.interface import backbone
class Model(backbone):
def __init__(self,
config,
phase):
        # self._is_training = phase == 'train'  # the backbone generally need not care about the phase; its outputs barely change across phases
self._emb_size = config['hidden_size']
self._n_layer = config['num_hidden_layers']
self._n_head = config['num_attention_heads']
self._voc_size = config['vocab_size']
self._max_position_seq_len = config['max_position_embeddings']
        if config.get('sent_type_vocab_size'):
self._sent_types = config['sent_type_vocab_size']
else:
self._sent_types = config['type_vocab_size']
self._task_types = config['task_type_vocab_size']
self._hidden_act = config['hidden_act']
self._prepostprocess_dropout = config['hidden_dropout_prob']
self._attention_dropout = config['attention_probs_dropout_prob']
self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding"
self._sent_emb_name = "sent_embedding"
self._task_emb_name = "task_embedding"
self._emb_dtype = "float32"
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=config['initializer_range'])
@property
def inputs_attr(self):
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1,-1, 1], 'int64']}
@property
def outputs_attr(self):
return {"word_embedding": [[-1, -1, self._emb_size], 'float32'],
"embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'],
"encoder_outputs": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding": [[-1, self._emb_size], 'float32']}
def build(self, inputs, scope_name=""):
src_ids = inputs['token_ids']
pos_ids = inputs['position_ids']
sent_ids = inputs['segment_ids']
input_mask = inputs['input_mask']
task_ids = inputs['task_ids']
# padding id in vocabulary must be set to 0
emb_out = fluid.layers.embedding(
input=src_ids,
size=[self._voc_size, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
# fluid.global_scope().find_var('backbone-word_embedding').get_tensor()
embedding_table = fluid.default_main_program().global_block().var(scope_name+self._word_emb_name)
position_emb_out = fluid.layers.embedding(
input=pos_ids,
size=[self._max_position_seq_len, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._pos_emb_name, initializer=self._param_initializer))
sent_emb_out = fluid.layers.embedding(
sent_ids,
size=[self._sent_types, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._sent_emb_name, initializer=self._param_initializer))
emb_out = emb_out + position_emb_out
emb_out = emb_out + sent_emb_out
task_emb_out = fluid.layers.embedding(
task_ids,
size=[self._task_types, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._task_emb_name,
initializer=self._param_initializer))
emb_out = emb_out + task_emb_out
emb_out = pre_process_layer(
emb_out, 'nd', self._prepostprocess_dropout, name=scope_name+'pre_encoder')
self_attn_mask = fluid.layers.matmul(
x=input_mask, y=input_mask, transpose_y=True)
self_attn_mask = fluid.layers.scale(
x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
n_head_self_attn_mask = fluid.layers.stack(
x=[self_attn_mask] * self._n_head, axis=1)
n_head_self_attn_mask.stop_gradient = True
enc_out = encoder(
enc_input=emb_out,
attn_bias=n_head_self_attn_mask,
n_layer=self._n_layer,
n_head=self._n_head,
d_key=self._emb_size // self._n_head,
d_value=self._emb_size // self._n_head,
d_model=self._emb_size,
d_inner_hid=self._emb_size * 4,
prepostprocess_dropout=self._prepostprocess_dropout,
attention_dropout=self._attention_dropout,
relu_dropout=0,
hidden_act=self._hidden_act,
preprocess_cmd="",
postprocess_cmd="dan",
param_initializer=self._param_initializer,
name=scope_name+'encoder')
next_sent_feat = fluid.layers.slice(
input=enc_out, axes=[1], starts=[0], ends=[1])
next_sent_feat = fluid.layers.reshape(next_sent_feat, [-1, next_sent_feat.shape[-1]])
next_sent_feat = fluid.layers.fc(
input=next_sent_feat,
size=self._emb_size,
act="tanh",
param_attr=fluid.ParamAttr(
name=scope_name+"pooled_fc.w_0", initializer=self._param_initializer),
bias_attr=scope_name+"pooled_fc.b_0")
return {'embedding_table': embedding_table,
'word_embedding': emb_out,
'encoder_outputs': enc_out,
'sentence_embedding': next_sent_feat,
'sentence_pair_embedding': next_sent_feat}
def postprocess(self, rt_outputs):
pass
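# --- Illustrative sketch (not part of this commit): the additive attention
# --- bias built in build() above, in plain numpy. scale(x, scale=10000.0,
# --- bias=-1.0, bias_after_scale=False) computes (x - 1) * 10000, so
# --- real-token pairs get bias 0 and any pair involving padding gets -10000
# --- before the softmax.
#
# import numpy as np
# input_mask = np.array([[[1.], [1.], [0.]]])        # batch of 1; 3rd token is padding
# attn = np.matmul(input_mask, input_mask.transpose(0, 2, 1))
# bias = (attn - 1.0) * 10000.0                      # 0 for valid pairs, -10000 otherwise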
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer encoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
param_initializer=None,
name='multi_head_att'):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
they will not considered in attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input=queries,
size=d_key * n_head,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name=name + '_query_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_query_fc.b_0')
k = layers.fc(input=keys,
size=d_key * n_head,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name=name + '_key_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_key_fc.b_0')
v = layers.fc(input=values,
size=d_value * n_head,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name=name + '_value_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_value_fc.b_0')
return q, k, v
def __split_heads(x, n_head):
"""
        Reshape the last dimension of input tensor x so that it becomes two
dimensions and then transpose. Specifically, input a tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] then output a tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
hidden_size = x.shape[-1]
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
        # permute the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
"""
        Transpose and then reshape the last two dimensions of input tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_key**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
# Since the inplace reshape in __split_heads changes the shape of k and
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k = cache["k"] = layers.concat(
[layers.reshape(
cache["k"], shape=[0, 0, d_model]), k], axis=1)
v = cache["v"] = layers.concat(
[layers.reshape(
cache["v"], shape=[0, 0, d_model]), v], axis=1)
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
dropout_rate)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name=name + '_output_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_output_fc.b_0')
return proj_out
def positionwise_feed_forward(x,
d_inner_hid,
d_hid,
dropout_rate,
hidden_act,
param_initializer=None,
name='ffn'):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
act=hidden_act,
param_attr=fluid.ParamAttr(
name=name + '_fc_0.w_0',
initializer=param_initializer),
bias_attr=name + '_fc_0.b_0')
if dropout_rate:
hidden = layers.dropout(
hidden,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
out = layers.fc(input=hidden,
size=d_hid,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name=name + '_fc_1.w_0', initializer=param_initializer),
bias_attr=name + '_fc_1.b_0')
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.,
name=''):
"""
    Add residual connection, layer normalization and dropout to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out_dtype = out.dtype
if out_dtype == fluid.core.VarDesc.VarType.FP16:
out = layers.cast(x=out, dtype="float32")
out = layers.layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.ParamAttr(
name=name + '_layer_norm_scale',
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
name=name + '_layer_norm_bias',
initializer=fluid.initializer.Constant(0.)))
if out_dtype == fluid.core.VarDesc.VarType.FP16:
out = layers.cast(x=out, dtype="float16")
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
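# process_cmd is a string of single-letter ops applied in order: "a" adds the
# residual, "n" applies layer normalization, "d" applies dropout. The backbones
# above pass preprocess_cmd="" (identity) and postprocess_cmd="dan"
# (dropout -> add residual -> layer norm), while the defaults "n"/"da" below
# give the pre-norm variant.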
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd="n",
postprocess_cmd="da",
param_initializer=None,
name=''):
"""The encoder layers that can be stacked to form a deep encoder.
This module consits of a multi-head (self) attention followed by
position-wise feed-forward networks and both the two components companied
with the post_process_layer to add residual connection, layer normalization
and droput.
"""
attn_output = multi_head_attention(
pre_process_layer(
enc_input,
preprocess_cmd,
prepostprocess_dropout,
name=name + '_pre_att'),
None,
None,
attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
param_initializer=param_initializer,
name=name + '_multi_head_att')
attn_output = post_process_layer(
enc_input,
attn_output,
postprocess_cmd,
prepostprocess_dropout,
name=name + '_post_att')
ffd_output = positionwise_feed_forward(
pre_process_layer(
attn_output,
preprocess_cmd,
prepostprocess_dropout,
name=name + '_pre_ffn'),
d_inner_hid,
d_model,
relu_dropout,
hidden_act,
param_initializer=param_initializer,
name=name + '_ffn')
return post_process_layer(
attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout,
name=name + '_post_ffn')
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd="n",
postprocess_cmd="da",
param_initializer=None,
name=''):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd,
postprocess_cmd,
param_initializer=param_initializer,
name=name + '_layer_' + str(i))
enc_input = enc_output
enc_output = pre_process_layer(
enc_output, preprocess_cmd, prepostprocess_dropout, name="post_encoder")
return enc_output
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
BACKBONE_DIR='paddlepalm.backbone'
TASK_INSTANCE_DIR='paddlepalm.task_instance'
READER_DIR='paddlepalm.reader'
PARADIGM_DIR='paddlepalm.task_paradigm'
OPTIMIZER_DIR='paddlepalm.optimizer'
OPTIMIZE_METHOD='optimize'
REQUIRED_ARGS={
'task_instance': str,
'backbone': str,
'optimizer': str,
'learning_rate': float,
'batch_size': int
}
OPTIONAL_ARGS={
'mix_ratio': str,
'target_tag': str,
    'reuse_tag': str
}
TASK_REQUIRED_ARGS={
'paradigm': str,
'reader': str,
'train_file': str
}
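# --- Illustrative sketch (not part of this commit): configs satisfying
# --- REQUIRED_ARGS and TASK_REQUIRED_ARGS above. Names and paths are
# --- hypothetical.
#
# mtl_conf = {'task_instance': 'mrqa', 'backbone': 'ernie',
#             'optimizer': 'adam', 'learning_rate': 3e-5, 'batch_size': 32}
# task_conf = {'paradigm': 'mrc', 'reader': 'mrc4ernie',
#              'train_file': 'data/mrqa/train.json'}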
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""v1.1"""
class reader(object):
"""interface of data manager."""
def __init__(self, config):
assert isinstance(config, dict)
    # @property
    # def inputs_attr(self):
    #     """Describes the attributes of the reader's input objects: each object's
    #     name, shape and dtype. Scalar objects (str, int, float, ...) take an
    #     empty list [] as shape; a dimension of variable length is set to -1.
    #     Return:
    #         dict. Attribute descriptions of the input objects. For example,
    #         a text classification task may need the input text and its label id:
    #             {"text": ([], 'str'),
    #              "label": ([], 'int')}
    #         a tagging task may need the token sequence and the matching tags:
    #             {"tokens", ([-1], 'str'),
    #              "tags", ([-1], 'str')}
    #         a machine reading comprehension task may need the context, the
    #         question, the answer, and the answer span boundaries:
    #             {"paragraph", ([], 'str'),
    #              "question", ([], 'str'),
    #              "start_position", ([], 'int')
    #     """
    #     raise NotImplementedError()
    @property
    def outputs_attr(self):
        """Describes the attributes of the reader's output objects (the objects
        it yields): each object's name, shape and dtype. Scalar objects (str,
        int, float, ...) take an empty list [] as shape; a dimension of
        variable length is set to -1.
        Note: when training with mini-batch gradient descent, regular input
        objects should carry a batch_size dimension (usually -1).
        Return:
            dict. Attribute descriptions of the output objects. For example,
            for text classification and matching tasks, the yielded outputs may
            contain the following objects (downstream backbones and tasks
            access them on demand):
            {"token_ids": ([-1, max_len], 'int64'),
             "input_ids": ([-1, max_len], 'int64'),
             "segment_ids": ([-1, max_len], 'int64'),
             "input_mask": ([-1, max_len], 'float32'),
             "label": ([-1], 'int')}
        """
        raise NotImplementedError()
    # def parse_line(self):
    #     """Internally the framework describes each sample with a dict whose keys
    #     come from inputs_attr and whose values conform to the corresponding
    #     attr description. This method parses a text line into such a dict.
    #     The default parse_line reads json-formatted dataset files in which
    #     every line is one json-described sample.
    #     Users can override this method to adapt to other dataset formats,
    #     e.g. csv or even tfrecord files.
    #     """
    #     raise NotImplementedError()
    #
    # def tokenize(self, line):
    #     """The framework ships with built-in tokenizers such as the word piece
    #     tokenizer; users can pick one via the tokenizer hyper-parameter, or
    #     override this method to plug in a custom tokenizer when none of the
    #     built-in ones fits.
    #     Args:
    #         - line: a unicode string.
    #     Return:
    #         a list of tokens
    #     """
    #     raise NotImplementedError()
    def iterator(self):
        """Dataset traversal interface. Note: when traversal reaches the tail
        of the dataset, this interface should reset the cursor automatically
        and start a fresh pass from the head.
        Yield:
            (dict) elements that meet the requirements in outputs_attr
        """
        raise NotImplementedError()
    @property
    def num_examples(self):
        """Number of samples in the dataset, i.e. the number of samples the
        iterator yields per epoch. Note: with strategies such as sliding
        windows that may change the sample count, this should return the
        actual number at runtime."""
        raise NotImplementedError()
class backbone(object):
"""interface of backbone model."""
def __init__(self, config, phase):
"""
Args:
config: dict类型。描述了 多任务配置文件+预训练模型配置文件 中定义超参数
phase: str类型。运行阶段,目前支持train和predict
"""
assert isinstance(config, dict)
@property
def inputs_attr(self):
"""描述backbone从reader处需要得到的输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。对各个输入对象的属性描述。例如,
对于文本分类和匹配任务,bert backbone依赖的reader对象主要包含如下的对象
{"token_ids": ([-1, max_len], 'int64'),
"input_ids": ([-1, max_len], 'int64'),
"segment_ids": ([-1, max_len], 'int64'),
"input_mask": ([-1, max_len], 'float32')}"""
raise NotImplementedError()
@property
def outputs_attr(self):
"""描述backbone输出对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。对各个输出对象的属性描述。例如,
对于文本分类和匹配任务,bert backbone的输出内容可能包含如下的对象
{"word_emb": ([-1, max_seqlen, word_emb_size], 'float32'),
"sentence_emb": ([-1, hidden_size], 'float32'),
"sim_vec": ([-1, hidden_size], 'float32')}"""
raise NotImplementedError()
def build(self, inputs):
"""建立backbone的计算图。将符合inputs_attr描述的静态图Variable输入映射成符合outputs_attr描述的静态图Variable输出。
Args:
inputs: dict类型。字典中包含inputs_attr中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象
Return:
需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
"""
raise NotImplementedError()
class task_paradigm(object):
def __init__(self, config, phase, backbone_config):
"""
config: dict类型。描述了 任务实例(task instance)+多任务配置文件 中定义超参数
phase: str类型。运行阶段,目前支持train和predict
"""
@property
def inputs_attrs(self):
"""描述task_layer需要从reader, backbone等输入对象集合所读取到的输入对象的属性,第一级key为对象集和的名字,如backbone,reader等(后续会支持更灵活的输入),第二级key为对象集和中各对象的属性,包括对象的名字,shape和dtype。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。对各个对象集及其输入对象的属性描述。"""
raise NotImplementedError()
@property
def outputs_attr(self):
"""描述task输出对象的属性,包括对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。对各个输入对象的属性描述。注意,训练阶段必须包含名为loss的输出对象。
"""
raise NotImplementedError()
def build(self, inputs):
"""建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。
Args:
inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象
Return:
需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
"""
raise NotImplementedError()
def postprocess(self, rt_outputs):
"""每个训练或推理step后针对当前batch的task_layer的runtime计算结果进行相关后处理。注意,rt_outputs除了包含build方法,还自动包含了loss的计算结果。"""
pass
def post_postprocess(self, global_buffer):
pass
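# --- Illustrative sketch (not part of this commit): a minimal reader that
# --- conforms to the interface above; the data and field names are toy
# --- examples.
#
# class ToyReader(reader):
#     def __init__(self, config):
#         super(ToyReader, self).__init__(config)
#         self._data = [{"text": "hello", "label": 0},
#                       {"text": "world", "label": 1}]
#     @property
#     def outputs_attr(self):
#         return {"text": ([], 'str'), "label": ([], 'int')}
#     def iterator(self):
#         while True:              # auto-reset: restart from the head at the end
#             for sample in self._data:
#                 yield sample
#     @property
#     def num_examples(self):
#         return len(self._data)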
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization and learning rate scheduling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
""" Applies linear warmup of learning rate from 0 and decay to 0."""
with fluid.default_main_program()._lr_schedule_guard():
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="scheduled_learning_rate")
global_step = fluid.layers.learning_rate_scheduler._decay_step_counter()
with fluid.layers.control_flow.Switch() as switch:
with switch.case(global_step < warmup_steps):
warmup_lr = learning_rate * (global_step / warmup_steps)
fluid.layers.tensor.assign(warmup_lr, lr)
with switch.default():
decayed_lr = fluid.layers.learning_rate_scheduler.polynomial_decay(
learning_rate=learning_rate,
decay_steps=num_train_steps,
end_learning_rate=0.0,
power=1.0,
cycle=False)
fluid.layers.tensor.assign(decayed_lr, lr)
return lr
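# --- Illustrative sketch (not part of this commit): the schedule above in
# --- plain Python, for intuition. With power=1.0 and end_learning_rate=0.0
# --- the decay branch is simply linear.
#
# def lr_at(step, base_lr, warmup_steps, num_train_steps):
#     if step < warmup_steps:
#         return base_lr * step / float(warmup_steps)        # linear warmup from 0
#     frac = min(step, num_train_steps) / float(num_train_steps)
#     return base_lr * (1.0 - frac)                          # linear decay to 0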
def optimize(loss, config, max_train_steps=None, warmup_steps=0, train_program=None):
if warmup_steps > 0:
decay_strategy = config.get('lr_scheduler', 'linear_warmup_decay')
if decay_strategy == 'noam_decay':
scheduled_lr = fluid.layers.learning_rate_scheduler\
.noam_decay(1/(warmup_steps *(config['learning_rate'] ** 2)),
warmup_steps)
elif decay_strategy == 'linear_warmup_decay':
scheduled_lr = linear_warmup_decay(config['learning_rate'], warmup_steps,
max_train_steps)
else:
raise ValueError("Unkown lr_scheduler, should be "
"'noam_decay' or 'linear_warmup_decay'")
optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr)
else:
optimizer = fluid.optimizer.Adam(learning_rate=config['learning_rate'])
scheduled_lr = config['learning_rate']
clip_norm_thres = 1.0
# When using mixed precision training, scale the gradient clip threshold
# by loss_scaling
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm_thres))
def exclude_from_weight_decay(name):
if name.find("layer_norm") > -1:
return True
bias_suffix = ["_bias", "_b", ".b_0"]
for suffix in bias_suffix:
if name.endswith(suffix):
return True
return False
param_list = dict()
for param in train_program.global_block().all_parameters():
param_list[param.name] = param * 1.0
param_list[param.name].stop_gradient = True
_, param_grads = optimizer.minimize(loss)
if config.get('weight_decay', 0) > 0:
for param, grad in param_grads:
if exclude_from_weight_decay(param.name):
continue
with param.block.program._optimized_guard(
[param, grad]), fluid.framework.name_scope("weight_decay"):
updated_param = param - param_list[
param.name] * config['weight_decay'] * scheduled_lr
fluid.layers.assign(output=param, input=updated_param)
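# --- Illustrative invocation (not part of this commit); the config keys follow
# --- this module, while the values and the loss variable are hypothetical.
#
# loss = ...  # built by the task paradigm
# optimize(loss,
#          {'learning_rate': 3e-5, 'weight_decay': 0.01,
#           'lr_scheduler': 'linear_warmup_decay'},
#          max_train_steps=10000, warmup_steps=1000,
#          train_program=fluid.default_main_program())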
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
"""
Args:
phase: train, eval, pred
"""
self._is_training = phase == 'train'
reader = ClassifyReader(config['vocab_path'],
max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', False),
for_cn=config.get('for_cn', False),
random_seed=config.get('seed', None))
self._reader = reader
self._dev_count = dev_count
self._batch_size = config['batch_size']
self._max_seq_len = config['max_seq_len']
if phase == 'train':
self._input_file = config['train_file']
            self._num_epochs = None  # keep the iterator from terminating
self._shuffle = config.get('shuffle', False)
self._shuffle_buffer = config.get('shuffle_buffer', 5000)
elif phase == 'eval':
self._input_file = config['dev_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
elif phase == 'pred':
self._input_file = config['pred_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
self._phase = phase
self._print_first_n = config.get('print_first_n', 1)
@property
def outputs_attr(self):
if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1,1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
}
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32']
}
def load_data(self):
        self._data_generator = self._reader.data_generator(
            self._input_file, self._batch_size, self._num_epochs,
            dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
def iterator(self):
def list_to_dict(x):
names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
'label_ids', 'unique_ids']
outputs = {n: i for n,i in zip(names, x)}
del outputs['unique_ids']
if not self._is_training:
del outputs['label_ids']
return outputs
for batch in self._data_generator():
yield list_to_dict(batch)
def get_epoch_outputs(self):
return {'examples': self._reader.get_examples(self._phase),
'features': self._reader.get_features(self._phase)}
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
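# --- Illustrative usage (not part of this commit); the vocab/data paths are
# --- hypothetical.
#
# r = Reader({'vocab_path': 'vocab.txt', 'max_seq_len': 128, 'batch_size': 32,
#             'train_file': 'train.tsv'}, phase='train')
# r.load_data()
# batch = next(r.iterator())   # dict with token_ids, position_ids, segment_ids, ...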
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
"""
Args:
phase: train, eval, pred
"""
self._is_training = phase == 'train'
reader = ClassifyReader(config['vocab_path'],
max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', False),
for_cn=config.get('for_cn', False),
random_seed=config.get('seed', None))
self._reader = reader
self._dev_count = dev_count
self._batch_size = config['batch_size']
self._max_seq_len = config['max_seq_len']
if phase == 'train':
self._input_file = config['train_file']
            self._num_epochs = None  # keep the iterator from terminating
self._shuffle = config.get('shuffle', False)
self._shuffle_buffer = config.get('shuffle_buffer', 5000)
elif phase == 'eval':
self._input_file = config['dev_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
elif phase == 'pred':
self._input_file = config['pred_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
self._phase = phase
self._print_first_n = config.get('print_first_n', 1)
@property
def outputs_attr(self):
if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1,1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
}
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32']
}
def load_data(self):
        self._data_generator = self._reader.data_generator(
            self._input_file, self._batch_size, self._num_epochs,
            dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
def iterator(self):
def list_to_dict(x):
names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
'label_ids', 'unique_ids']
outputs = {n: i for n,i in zip(names, x)}
del outputs['unique_ids']
if not self._is_training:
del outputs['label_ids']
return outputs
for batch in self._data_generator():
yield list_to_dict(batch)
def get_epoch_outputs(self):
return {'examples': self._reader.get_examples(self._phase),
'features': self._reader.get_features(self._phase)}
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import MaskLMReader
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
"""
Args:
phase: train, eval, pred
"""
self._is_training = phase == 'train'
reader = MaskLMReader(config['vocab_path'],
max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', False),
for_cn=config.get('for_cn', False),
random_seed=config.get('seed', None))
self._reader = reader
self._dev_count = dev_count
self._batch_size = config['batch_size']
self._max_seq_len = config['max_seq_len']
if phase == 'train':
self._input_file = config['train_file']
            self._num_epochs = None  # keep the iterator from terminating
self._shuffle = config.get('shuffle', False)
self._shuffle_buffer = config.get('shuffle_buffer', 5000)
elif phase == 'eval':
self._input_file = config['dev_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
elif phase == 'pred':
self._input_file = config['pred_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
self._phase = phase
self._print_first_n = config.get('print_first_n', 1)
@property
def outputs_attr(self):
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1, -1, 1], 'int64'],
"mask_label": [[-1, 1], 'int64'],
"mask_pos": [[-1, 1], 'int64']
}
def load_data(self):
        self._data_generator = self._reader.data_generator(
            self._input_file, self._batch_size, self._num_epochs,
            dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
def iterator(self):
def list_to_dict(x):
names = ['token_ids', 'position_ids', 'segment_ids', 'input_mask',
'task_ids', 'mask_label', 'mask_pos']
outputs = {n: i for n,i in zip(names, x)}
return outputs
for batch in self._data_generator():
yield list_to_dict(batch)
def get_epoch_outputs(self):
return {'examples': self._reader.get_examples(self._phase),
'features': self._reader.get_features(self._phase)}
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import MRCReader
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
"""
Args:
phase: train, eval, pred
"""
self._is_training = phase == 'train'
reader = MRCReader(config['vocab_path'],
max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', False),
tokenizer='FullTokenizer',
for_cn=config.get('for_cn', False),
doc_stride=config['doc_stride'],
max_query_length=config['max_query_len'],
random_seed=config.get('seed', None))
self._reader = reader
self._dev_count = dev_count
self._batch_size = config['batch_size']
self._max_seq_len = config['max_seq_len']
if phase == 'train':
self._input_file = config['train_file']
# self._num_epochs = config['num_epochs']
            self._num_epochs = None  # keep the iterator from terminating
self._shuffle = config.get('shuffle', False)
self._shuffle_buffer = config.get('shuffle_buffer', 5000)
        elif phase == 'eval':
self._input_file = config['dev_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
elif phase == 'pred':
self._input_file = config['pred_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
self._phase = phase
self._print_first_n = config.get('print_first_n', 1)
        # TODO: support a version without sliding windows
self._with_slide_window = config.get('with_slide_window', False)
@property
def outputs_attr(self):
if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"start_positions": [[-1, 1], 'int64'],
"end_positions": [[-1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
}
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"unique_ids": [[-1, 1], 'int64']
}
@property
def epoch_outputs_attr(self):
if not self._is_training:
return {"examples": None,
"features": None}
def load_data(self):
        self._data_generator = self._reader.data_generator(
            self._input_file, self._batch_size, self._num_epochs,
            dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
def iterator(self):
def list_to_dict(x):
names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
'start_positions', 'end_positions', 'unique_ids']
outputs = {n: i for n,i in zip(names, x)}
if self._is_training:
del outputs['unique_ids']
else:
del outputs['start_positions']
del outputs['end_positions']
return outputs
for batch in self._data_generator():
yield list_to_dict(batch)
def get_epoch_outputs(self):
return {'examples': self._reader.get_examples(self._phase),
'features': self._reader.get_features(self._phase)}
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
"""
Add mask for batch_tokens, return out, mask_label, mask_pos;
Note: mask_pos responding the batch_tokens after padded;
"""
max_len = max([len(sent) for sent in batch_tokens])
mask_label = []
mask_pos = []
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
prob_index += pre_sent_len
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index + token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
        # ensure that at least one word in each sentence is masked
while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return batch_tokens, mask_label, mask_pos
def prepare_batch_data(insts,
total_token_num,
max_len=None,
voc_size=0,
pad_id=None,
cls_id=None,
sep_id=None,
mask_id=None,
return_input_mask=True,
return_max_len=True,
return_num_token=False):
"""
1. generate Tensor of data
2. generate Tensor of position
3. generate self attention mask, [shape: batch_size * max_len * max_len]
"""
batch_src_ids = [inst[0] for inst in insts]
batch_sent_ids = [inst[1] for inst in insts]
batch_pos_ids = [inst[2] for inst in insts]
labels_list = []
# compatible with mrqa, whose example includes start/end positions,
# or unique id
for i in range(3, len(insts[0]), 1):
labels = [inst[i] for inst in insts]
labels = np.array(labels).astype("int64").reshape([-1, 1])
labels_list.append(labels)
# First step: do mask without padding
    # mask_id defaults to None; guard the comparison so Python 3 does not
    # raise on None >= 0
    if mask_id is not None and mask_id >= 0:
out, mask_label, mask_pos = mask(
batch_src_ids,
total_token_num,
vocab_size=voc_size,
CLS=cls_id,
SEP=sep_id,
MASK=mask_id)
else:
out = batch_src_ids
# Second step: padding
src_id, self_input_mask = pad_batch_data(
out,
max_len=max_len,
pad_idx=pad_id, return_input_mask=True)
pos_id = pad_batch_data(
batch_pos_ids,
max_len=max_len,
pad_idx=pad_id,
return_pos=False,
return_input_mask=False)
sent_id = pad_batch_data(
batch_sent_ids,
max_len=max_len,
pad_idx=pad_id,
return_pos=False,
return_input_mask=False)
    if mask_id is not None and mask_id >= 0:
return_list = [
src_id, pos_id, sent_id, self_input_mask, mask_label, mask_pos
] + labels_list
else:
return_list = [src_id, pos_id, sent_id, self_input_mask] + labels_list
return return_list if len(return_list) > 1 else return_list[0]
def pad_batch_data(insts,
max_len=None,
pad_idx=0,
return_pos=False,
return_input_mask=False,
return_max_len=False,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and input mask.
"""
return_list = []
if max_len is None:
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array([
list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts
])
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
# position data
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
if return_input_mask:
# This is used to avoid attention on paddings.
input_mask_data = np.array([[1] * len(inst) + [0] *
(max_len - len(inst)) for inst in insts])
input_mask_data = np.expand_dims(input_mask_data, axis=-1)
return_list += [input_mask_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
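# --- Illustrative check (not part of this commit): padding two instances.
# pad_batch_data([[5, 6], [7]], pad_idx=0, return_input_mask=True) returns
#   ids  of shape (2, 2, 1): [[[5], [6]], [[7], [0]]]
#   mask of shape (2, 2, 1): [[[1.], [1.]], [[1.], [0.]]]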
if __name__ == "__main__":
pass
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange
def mask(batch_tokens,
seg_labels,
mask_word_tags,
total_token_num,
vocab_size,
CLS=1,
SEP=2,
MASK=3):
"""
Add mask for batch_tokens, return out, mask_label, mask_pos;
Note: mask_pos responding the batch_tokens after padded;
"""
max_len = max([len(sent) for sent in batch_tokens])
mask_label = []
mask_pos = []
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
mask_word = mask_word_tags[sent_index]
prob_index += pre_sent_len
if mask_word:
beg = 0
for token_index, token in enumerate(sent):
seg_label = seg_labels[sent_index][token_index]
if seg_label == 1:
continue
if beg == 0:
if seg_label != -1:
beg = token_index
continue
prob = prob_mask[prob_index + beg]
if prob > 0.15:
pass
else:
for index in xrange(beg, token_index):
prob = prob_mask[prob_index + index]
base_prob = 1.0
if index == beg:
base_prob = 0.15
if base_prob * 0.2 < prob <= base_prob:
mask_label.append(sent[index])
sent[index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + index)
elif base_prob * 0.1 < prob <= base_prob * 0.2:
mask_label.append(sent[index])
sent[index] = replace_ids[prob_index + index]
mask_flag = True
mask_pos.append(sent_index * max_len + index)
else:
mask_label.append(sent[index])
mask_pos.append(sent_index * max_len + index)
if seg_label == -1:
beg = 0
else:
beg = token_index
else:
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index +
token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return batch_tokens, mask_label, mask_pos
def pad_batch_data(insts,
pad_idx=0,
return_pos=False,
return_input_mask=False,
return_max_len=False,
return_num_token=False,
return_seq_lens=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
# position data
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
if return_input_mask:
# This is used to avoid attention on paddings.
input_mask_data = np.array([[1] * len(inst) + [0] *
(max_len - len(inst)) for inst in insts])
input_mask_data = np.expand_dims(input_mask_data, axis=-1)
return_list += [input_mask_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
if return_seq_lens:
seq_lens = np.array([len(inst) for inst in insts])
return_list += [seq_lens.astype("int64").reshape([-1, 1])]
return return_list if len(return_list) > 1 else return_list[0]
if __name__ == "__main__":
pass
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
"""
Add mask for batch_tokens, return out, mask_label, mask_pos;
Note: mask_pos responding the batch_tokens after padded;
"""
max_len = max([len(sent) for sent in batch_tokens])
mask_label = []
mask_pos = []
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
prob_index += pre_sent_len
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index + token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
        # ensure that at least one word in each sentence is masked
while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return batch_tokens, mask_label, mask_pos
def prepare_batch_data(insts,
total_token_num,
max_len=None,
voc_size=0,
pad_id=None,
cls_id=None,
sep_id=None,
mask_id=None,
task_id=0,
return_input_mask=True,
return_max_len=True,
return_num_token=False):
"""
1. generate Tensor of data
2. generate Tensor of position
3. generate self attention mask, [shape: batch_size * max_len * max_len]
"""
batch_src_ids = [inst[0] for inst in insts]
batch_sent_ids = [inst[1] for inst in insts]
batch_pos_ids = [inst[2] for inst in insts]
    # Should this be done the other way around??? Otherwise the word embedding
    # unfolded in the task layer is the padded one, whose word indices no
    # longer match the indices without padding.
# First step: do mask without padding
out, mask_label, mask_pos = mask(
batch_src_ids,
total_token_num,
vocab_size=voc_size,
CLS=cls_id,
SEP=sep_id,
MASK=mask_id)
# Second step: padding
src_id, self_input_mask = pad_batch_data(
out,
max_len=max_len,
pad_idx=pad_id, return_input_mask=True)
pos_id = pad_batch_data(
batch_pos_ids,
max_len=max_len,
pad_idx=pad_id,
return_pos=False,
return_input_mask=False)
sent_id = pad_batch_data(
batch_sent_ids,
max_len=max_len,
pad_idx=pad_id,
return_pos=False,
return_input_mask=False)
task_ids = np.ones_like(
src_id, dtype="int64") * task_id
return_list = [
src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos
]
return return_list if len(return_list) > 1 else return_list[0]
def pad_batch_data(insts,
max_len=None,
pad_idx=0,
return_pos=False,
return_input_mask=False,
return_max_len=False,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and input mask.
"""
return_list = []
if max_len is None:
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array([
list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts
])
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
# position data
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
if return_input_mask:
# This is used to avoid attention on paddings.
input_mask_data = np.array([[1] * len(inst) + [0] *
(max_len - len(inst)) for inst in insts])
input_mask_data = np.expand_dims(input_mask_data, axis=-1)
return_list += [input_mask_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
if __name__ == "__main__":
pass
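# Hedged usage sketch (illustration only, not part of the library). All token
# ids below are made-up placeholders; real values come from the vocabulary.
# It exercises the full pipeline above: mask the raw batch, then pad it and
# build the padding-based input mask.
def _demo_prepare_batch_data():
    # each instance is (token_ids, segment_ids, position_ids)
    insts = [
        ([101, 5, 6, 102], [0, 0, 0, 0], [0, 1, 2, 3]),
        ([101, 7, 8, 9, 102], [0, 0, 0, 0, 0], [0, 1, 2, 3, 4]),
    ]
    total_token_num = sum(len(inst[0]) for inst in insts)
    src_id, pos_id, sent_id, input_mask, task_ids, mask_label, mask_pos = \
        prepare_batch_data(insts, total_token_num, voc_size=1000,
                           pad_id=0, cls_id=101, sep_id=102, mask_id=103)
    # src_id / pos_id / sent_id: [batch, max_len, 1] int64
    # input_mask: [batch, max_len, 1] float32; mask_label / mask_pos: [-1, 1]
    print(src_id.shape, input_mask.shape, mask_label.shape, mask_pos.shape)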
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tokenization  # provides printable_text() used below; exact module path is assumed
class MRQAExample(object):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def __init__(self,
qas_id,
question_text,
doc_tokens,
orig_answer_text=None,
start_position=None,
end_position=None,
is_impossible=False):
self.qas_id = qas_id
self.question_text = question_text
self.doc_tokens = doc_tokens
self.orig_answer_text = orig_answer_text
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
s += ", question_text: %s" % (
tokenization.printable_text(self.question_text))
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position:
s += ", start_position: %d" % (self.start_position)
if self.end_position:
s += ", end_position: %d" % (self.end_position)
if self.is_impossible:
s += ", is_impossible: %r" % (self.is_impossible)
return s
class MRQAFeature(object):
"""A single set of features of data."""
def __init__(self,
unique_id,
example_index,
doc_span_index,
tokens,
token_to_orig_map,
token_is_max_context,
input_ids,
input_mask,
segment_ids,
start_position=None,
end_position=None,
is_impossible=None):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
self.tokens = tokens
self.token_to_orig_map = token_to_orig_map
self.token_is_max_context = token_is_max_context
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader as base_reader
from paddlepalm.interface import task_paradigm as base_paradigm
import os
import json
from paddle import fluid
class TaskInstance(object):
def __init__(self, name, id, config={}, verbose=True):
self._name = name
self._config = config
self._verbose = verbose
self._save_infermodel_path = os.path.join(self._config['save_path'], 'infer_model')
self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt')
# the following flags can be fetched from the instance config file
self._is_target = config.get('is_target', True)
self._is_first_target = config.get('is_first_target', False)
self._task_reuse_scope = config.get('task_reuse_scope', name)
self._feeded_var_names = None
self._target_vars = None
# training process management
self._mix_ratio = None
self._expected_train_steps = None
self._expected_train_epochs = None
self._steps_pur_epoch = None
self._cur_train_epoch = 0
self._cur_train_step = 0
self._train_finish = False
# holds the dataset reader for each run phase (train / eval / pred); key is the phase name, value is a Reader instance
self._reader = {'train': None, 'eval': None, 'pred': None}
self._input_layer = None
self._inputname_to_varname = {}
self._task_layer = {'train': None, 'eval': None, 'pred': None}
self._pred_input_name_list = []
self._pred_input_varname_list = []
self._pred_fetch_name_list = []
self._pred_fetch_var_list = []
self._Reader = None
self._Paradigm = None
self._exe = fluid.Executor(fluid.CPUPlace())
self._save_protocol = {
'input_names': 'self._pred_input_name_list',
'input_varnames': 'self._pred_input_varname_list',
'fetch_list': 'self._pred_fetch_name_list'}
def build_task_layer(self, net_inputs, phase):
output_vars = self._task_layer[phase].build(net_inputs)
if phase == 'pred':
if output_vars is not None:
self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items())
else:
self._pred_fetch_name_list = []
self._pred_fetch_var_list = []
return output_vars
def postprocess(self, rt_outputs, phase):
return self._task_layer[phase].postprocess(rt_outputs)
def epoch_postprocess(self, epoch_inputs, phase):
return self._task_layer[phase].epoch_postprocess(epoch_inputs)
def save(self, suffix=''):
dirpath = self._save_infermodel_path + suffix
self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]
fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe)
# fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, params_filename='__params__')
print(self._name + ': inference model saved at ' + dirpath)
conf = {}
for k, strv in self._save_protocol.items():
# use eval rather than exec here: exec cannot rebind a local name under Python 3
conf[k] = eval(strv)
with open(os.path.join(dirpath, '__conf__'), 'w') as writer:
writer.write(json.dumps(conf, indent=1))
def load(self, infer_model_path=None):
if infer_model_path is None:
infer_model_path = self._save_infermodel_path
for k,v in json.load(open(os.path.join(infer_model_path, '__conf__'))).items():
strv = self._save_protocol[k]
exec('{}=v'.format(strv))
pred_prog, self._pred_input_varname_list, self._pred_fetch_var_list = \
fluid.io.load_inference_model(infer_model_path, self._exe)
# pred_prog, self._pred_input_varname_list, self._pred_fetch_var_list = \
# fluid.io.load_inference_model(infer_model_path, self._exe, params_filename='__params__')
print(self._name+': inference model loaded from ' + infer_model_path)
return pred_prog
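    # Hedged example of the __conf__ JSON written by save() above; the keys
    # follow self._save_protocol, the variable names are illustrative:
    #
    #   {
    #    "input_names": ["token_ids", "position_ids", "segment_ids", "input_mask"],
    #    "input_varnames": ["token_ids", "position_ids", "segment_ids", "input_mask"],
    #    "fetch_list": ["logits"]
    #   }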
@property
def name(self):
return self._name
@property
def Reader(self):
return self._Reader
@Reader.setter
def Reader(self, cls):
assert base_reader.__name__ == cls.__bases__[-1].__name__, \
"expect: {}, receive: {}.".format(base_reader.__name__, \
cls.__bases__[-1].__name__)
self._Reader = cls
@property
def Paradigm(self):
return self._Paradigm
@Paradigm.setter
def Paradigm(self, cls):
assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \
"expect: {}, receive: {}.".format(base_paradigm.__name__, \
cls.__bases__[-1].__name__)
self._Paradigm = cls
@property
def config(self):
return self._config
@property
def reader(self):
return self._reader
@property
def pred_input(self):
return zip(*[self._pred_input_name_list, self._pred_input_varname_list])
@pred_input.setter
def pred_input(self, val):
assert isinstance(val, dict)
self._pred_input_name_list, self._pred_input_varname_list = \
zip(*[[k, v.name] for k,v in val.items()])
# print(self._pred_input_name_list)
@property
def pred_fetch_list(self):
return [self._pred_fetch_name_list, self._pred_fetch_var_list]
@property
def task_layer(self):
return self._task_layer
@property
def is_first_target(self):
return self._is_first_target
@is_first_target.setter
def is_first_target(self, value):
self._is_first_target = bool(value)
if self._is_first_target:
assert self._is_target, "ERROR: only a target task can be set as the main task."
if self._verbose and self._is_first_target:
print("{}: set as main task".format(self._name))
@property
def is_target(self):
if self._is_target is not None:
return self._is_target
else:
raise ValueError("{}: is_target is None".format(self._name))
@is_target.setter
def is_target(self, value):
self._is_target = bool(value)
if self._verbose:
if self._is_target:
print('{}: set as target task.'.format(self._name))
else:
print('{}: set as aux task.'.format(self._name))
@property
def mix_ratio(self):
if self._mix_ratio is not None:
return self._mix_ratio
else:
raise ValueError("{}: mix_ratio is None".format(self._name))
@mix_ratio.setter
def mix_ratio(self, value):
self._mix_ratio = float(value)
if self._verbose:
print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio))
@property
def expected_train_steps(self):
return self._expected_train_steps
@expected_train_steps.setter
def expected_train_steps(self, value):
self._expected_train_steps = value
self._expected_train_epochs = value / float(self._steps_pur_epoch)
@property
def expected_train_epochs(self):
return self._expected_train_epochs
@property
def cur_train_epoch(self):
return self._cur_train_epoch
@cur_train_epoch.setter
def cur_train_epoch(self, value):
self._cur_train_epoch = value
@property
def cur_train_step(self):
return self._cur_train_step
@cur_train_step.setter
def cur_train_step(self, value):
self._cur_train_step = value
if self._cur_train_step > self._steps_pur_epoch:
self._cur_train_epoch += 1
self._cur_train_step = 1
if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps:
self._train_finish = True
# fluid.io.save_inference_model(self._save_infermodel_path, )
@property
def steps_pur_epoch(self):
return self._steps_pur_epoch
@steps_pur_epoch.setter
def steps_pur_epoch(self, value):
self._steps_pur_epoch = value
@property
def train_finish(self):
return self._train_finish
@property
def task_reuse_scope(self):
if self._task_reuse_scope is not None:
return self._task_reuse_scope
else:
raise ValueError("{}: task_reuse_scope is None".format(self._name))
@task_reuse_scope.setter
def task_reuse_scope(self, scope_name):
self._task_reuse_scope = str(scope_name)
if self._verbose:
print('{}: task_reuse_scope is set to {}'.format(self._name, self._task_reuse_scope))
def check_instances(insts):
"""to check ids, first_target"""
pass
def _check_ids():
pass
def _check_targets():
pass
def _check_reuse_scopes():
pass
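# Hedged sketch of the step/epoch bookkeeping above (commented out because
# TaskInstance.__init__ creates a fluid.Executor; save_path is a placeholder):
#
#   inst = TaskInstance('mrqa', 0, config={'save_path': '/tmp/mrqa'})
#   inst.steps_pur_epoch = 100          # set by the controller in practice
#   inst.expected_train_steps = 250
#   for _ in range(250):
#       inst.cur_train_step += 1        # rolls the epoch over every 100 steps
#   assert inst.train_finish            # 50 steps into epoch 2: 50 + 2*100 >= 250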
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddlepalm.interface import task_paradigm
from paddle.fluid import layers
class TaskParadigm(task_paradigm):
'''
classification
'''
def __init__(self, config, phase):
self._is_training = phase == 'train'
self.sent_emb_size = config['hidden_size']
self.num_classes = config['n_classes']
@property
def inputs_attrs(self):
return {'backbone': {"sentence_emb": [[-1, self.sent_emb_size], 'float32']},
'reader': {"label_ids": [[-1, 1], 'int64']}}
@property
def outputs_attrs(self):
if self._is_training:
return {'loss': [[1], 'float32']}
else:
return {'logits': [[-1, self.num_classes], 'float32']}
def build(self, **inputs):
sent_emb = inputs['backbone']['sentence_emb']
logits = fluid.layers.fc(
input=sent_emb,
size=self.num_classes,
param_attr=fluid.ParamAttr(
name="cls_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.1)),
bias_attr=fluid.ParamAttr(
name="cls_out_b", initializer=fluid.initializer.Constant(0.)))
if self._is_training:
label_ids = inputs['reader']['label_ids']
loss = fluid.layers.softmax_with_cross_entropy(
logits=logits, label=label_ids)
loss = layers.mean(loss)
return {"loss": loss}
else:
return {"logits": logits}
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddlepalm.interface import task_paradigm
from paddle.fluid import layers
class TaskParadigm(task_paradigm):
'''
matching
'''
def __init__(self, config, phase, backbone_config=None):
self._is_training = phase == 'train'
self._hidden_size = backbone_config['hidden_size']
@property
def inputs_attrs(self):
if self._is_training:
reader = {"label_ids": [[-1, 1], 'int64']}
else:
reader = {}
bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']}
return {'reader': reader, 'backbone': bb}
@property
def outputs_attrs(self):
if self._is_training:
return {"loss": [[1], 'float32']}
else:
return {"logits": [[-1, 1], 'float32']}
def build(self, inputs):
if self._is_training:
labels = inputs["reader"]["label_ids"]
cls_feats = inputs["backbone"]["sentence_pair_embedding"]
cls_feats = fluid.layers.dropout(
x=cls_feats,
dropout_prob=0.1,
dropout_implementation="upscale_in_train")
logits = fluid.layers.fc(
input=cls_feats,
size=2,
param_attr=fluid.ParamAttr(
name="cls_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_out_b",
initializer=fluid.initializer.Constant(0.)))
if self._is_training:
ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
logits=logits, label=labels, return_softmax=True)
loss = fluid.layers.mean(x=ce_loss)
return {'loss': loss}
else:
return {'logits': logits}
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddlepalm.interface import task_paradigm
from paddle.fluid import layers
from paddlepalm.backbone.utils.transformer import pre_process_layer
class TaskParadigm(task_paradigm):
'''
matching
'''
def __init__(self, config, phase, backbone_config=None):
self._is_training = phase == 'train'
self._emb_size = backbone_config['hidden_size']
self._hidden_size = backbone_config['hidden_size']
self._vocab_size = backbone_config['vocab_size']
self._hidden_act = backbone_config['hidden_act']
self._initializer_range = backbone_config['initializer_range']
@property
def inputs_attrs(self):
reader = {
"mask_label": [[-1, 1], 'int64'],
"mask_pos": [[-1, 1], 'int64']}
if not self._is_training:
del reader['mask_label']
bb = {
"encoder_outputs": [[-1, -1, self._hidden_size], 'float32'],
"embedding_table": [[-1, self._vocab_size, self._emb_size], 'float32']}
return {'reader': reader, 'backbone': bb}
@property
def outputs_attrs(self):
if self._is_training:
return {"loss": [[1], 'float32']}
else:
return {"logits": [[-1], 'float32']}
def build(self, inputs):
if self._is_training:
mask_label = inputs["reader"]["mask_label"]
mask_pos = inputs["reader"]["mask_pos"]
word_emb = inputs["backbone"]["embedding_table"]
enc_out = inputs["backbone"]["encoder_outputs"]
emb_size = word_emb.shape[-1]
_param_initializer = fluid.initializer.TruncatedNormal(
scale=self._initializer_range)
mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
reshaped_emb_out = fluid.layers.reshape(
x=enc_out, shape=[-1, emb_size])
# extract masked tokens' feature
mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
# transform: fc
mask_trans_feat = fluid.layers.fc(
input=mask_feat,
size=emb_size,
act=self._hidden_act,
param_attr=fluid.ParamAttr(
name='mask_lm_trans_fc.w_0',
initializer=_param_initializer),
bias_attr=fluid.ParamAttr(name='mask_lm_trans_fc.b_0'))
# transform: layer norm
mask_trans_feat = pre_process_layer(
mask_trans_feat, 'n', name='mask_lm_trans')
mask_lm_out_bias_attr = fluid.ParamAttr(
name="mask_lm_out_fc.b_0",
initializer=fluid.initializer.Constant(value=0.0))
# print fluid.default_main_program().global_block()
# fc_out = fluid.layers.matmul(
# x=mask_trans_feat,
# y=fluid.default_main_program().global_block().var(
# _word_emb_name),
# transpose_y=True)
fc_out = fluid.layers.matmul(
x=mask_trans_feat,
y=word_emb,
transpose_y=True)
fc_out += fluid.layers.create_parameter(
shape=[self._vocab_size],
dtype='float32',
attr=mask_lm_out_bias_attr,
is_bias=True)
if self._is_training:
mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
logits=fc_out, label=mask_label)
loss = fluid.layers.mean(mask_lm_loss)
return {'loss': loss}
else:
return {'logits': fc_out}
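# Hedged numpy illustration of the gather trick above: flattening
# [batch, max_len, emb] to [batch * max_len, emb] lets the flat positions
# sent_index * max_len + token_index produced by mask() select exactly the
# masked tokens' features.
#
#   import numpy as np
#   enc = np.arange(24.0).reshape(2, 3, 4)        # [batch=2, max_len=3, emb=4]
#   flat = enc.reshape(-1, 4)                     # [6, 4]
#   mask_pos = np.array([0 * 3 + 1, 1 * 3 + 2])   # tokens (0, 1) and (1, 2)
#   print(flat[mask_pos])                         # the two masked feature rows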
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
MAXLEN = 70
def print_dict(dic, title=""):
if title:
title = ' ' + title + ' '
left_len = (MAXLEN - len(title)) // 2
title = '-' * left_len + title
right_len = MAXLEN - len(title)
title = title + '-' * right_len
else:
title = '-' * MAXLEN
print(title)
for name in dic:
print("{: <25}\t{}".format(str(name), str(dic[name])))
print("")
# print("-" * MAXLEN + '\n')
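# Hedged usage sketch:
#
#   print_dict({'lr': 1e-5, 'batch_size': 32}, title='train config')
#
# prints a 70-character ruled title line, one left-aligned "name<TAB>value"
# row per key, then a blank line.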
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def is_whitespace(c):
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
return True
return False