提交 2efeb39b 编写于 作者: X xixiaoyao

fix save predict model

上级 d44b6381
*.pyc
__pycache__
pretrain_model
pretrain
output*
output_model
build
dist
......
......@@ -15,7 +15,6 @@ if __name__ == '__main__':
config = json.load(open('./pretrain/ernie/ernie_config.json'))
# ernie = palm.backbone.ERNIE(...)
ernie = palm.backbone.ERNIE.from_config(config)
# pred_ernie = palm.backbone.ERNIE.from_config(config, phase='pred')
# cls_reader2 = palm.reader.cls(train_file_topic, vocab_path, batch_size, max_seqlen)
# cls_reader3 = palm.reader.cls(train_file_subj, vocab_path, batch_size, max_seqlen)
......@@ -30,7 +29,6 @@ if __name__ == '__main__':
print(cls_reader.outputs_attr)
# 创建任务头(task head),如分类、匹配、机器阅读理解等。每个任务头有跟该任务相关的必选/可选参数。注意,任务头与reader是解耦合的,只要任务头依赖的数据集侧的字段能被reader提供,那么就是合法的
cls_head = palm.head.Classify(4, 1024, 0.1)
# cls_pred_head = palm.head.Classify(4, 1024, 0.1, phase='pred')
# 根据reader和任务头来创建一个训练器trainer,trainer代表了一个训练任务,内部维护着训练进程、和任务的关键信息,并完成合法性校验,该任务的模型保存、载入等相关规则控制
trainer = palm.Trainer('senti_cls', cls_reader, cls_head)
......@@ -64,7 +62,12 @@ if __name__ == '__main__':
# print(trainer.train_one_step(next(iterator_fn())))
# trainer.train_one_epoch()
trainer.train(iterator_fn, print_steps=1, save_steps=5, save_path='outputs/ckpt')
# for save predict model.
pred_ernie = palm.backbone.ERNIE.from_config(config, phase='pred')
cls_pred_head = palm.head.Classify(4, 1024, phase='pred')
trainer.build_predict_head(cls_pred_head, pred_ernie)
trainer.train(iterator_fn, print_steps=1, save_steps=5, save_path='outputs', save_type='ckpt,predict')
# trainer.save()
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""v1.1"""
class reader(object):
"""interface of data manager."""
def __init__(self, config):
assert isinstance(config, dict)
# @property
# def inputs_attr(self):
# """描述reader输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1.
# Return:
# dict类型。对各个输入对象的属性描述。例如,
# 对于文本分类任务,可能需要包含输入文本和所属标签的id
# {"text": ([], 'str'),
# "label": ([], 'int')}
# 对于标注任务,可能需要输入词序列和对应的标签
# {"tokens", ([-1], 'str'),
# "tags", ([-1], 'str')}
# 对于机器阅读理解任务,可能需要包含上下文、问题、回答、答案区域的起止位置等
# {"paragraph", ([], 'str'),
# "question", ([], 'str'),
# "start_position", ([], 'int')
# """
# raise NotImplementedError()
@property
def outputs_attr(self):
"""描述reader输出对象(被yield出的对象)的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
注意:当使用mini-batch梯度下降学习策略时,,应为常规的输入对象设置batch_size维度(一般为-1)
Return:
dict类型。对各个输入对象的属性描述。例如,
对于文本分类和匹配任务,yield的输出内容可能包含如下的对象(下游backbone和task可按需访问其中的对象)
{"token_ids": ([-1, max_len], 'int64'),
"input_ids": ([-1, max_len], 'int64'),
"segment_ids": ([-1, max_len], 'int64'),
"input_mask": ([-1, max_len], 'float32'),
"label": ([-1], 'int')}
"""
raise NotImplementedError()
# def parse_line(self):
# """框架内部使用字典描述每个样本,字典的key为inputs_attr,value为每个input对应的符合attr描述的值。
# 该函数负责将文本行解析成符合inputs_attr描述的字典类型的样本。默认的parse_line方法会读取json格式的数据集文件,数据集的每一行为json格式描述的样本。
# 用户可通过对该方法的继承改写来适配不同格式的数据集,例如csv格式甚至tfrecord文件。
# """
# raise NotImplementedError()
#
# def tokenize(self, line):
# """框架中内置了word piece tokenizer等分词器,用户可通过修改tokenizer超参数来制定使用的分词器,若内置的分词器均无法满足需求,用户可通过对该方法的继承改写来自定义分词器。
# Args:
# - line: a unicode string.
# Return:
# a list of tokens
# """
# raise NotImplementedError()
def iterator(self):
"""数据集遍历接口,注意,当数据集遍历到尾部时该接口应自动完成指针重置,即重新从数据集头部开始新的遍历。
Yield:
(dict) elements that meet the requirements in output_templete
"""
raise NotImplementedError()
@property
def num_examples(self):
"""数据集中的样本数量,即每个epoch中iterator所生成的样本数。注意,使用滑动窗口等可能导致数据集样本数发生变化的策略时,该接口应返回runtime阶段的实际样本数。"""
raise NotImplementedError()
class backbone(object):
"""interface of backbone model."""
def __init__(self, config, phase):
"""
Args:
config: dict类型。描述了 多任务配置文件+预训练模型配置文件 中定义超参数
phase: str类型。运行阶段,目前支持train和predict
"""
assert isinstance(config, dict)
@property
def inputs_attr(self):
"""描述backbone从reader处需要得到的输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。对各个输入对象的属性描述。例如,
对于文本分类和匹配任务,bert backbone依赖的reader对象主要包含如下的对象
{"token_ids": ([-1, max_len], 'int64'),
"input_ids": ([-1, max_len], 'int64'),
"segment_ids": ([-1, max_len], 'int64'),
"input_mask": ([-1, max_len], 'float32')}"""
raise NotImplementedError()
@property
def outputs_attr(self):
"""描述backbone输出对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。对各个输出对象的属性描述。例如,
对于文本分类和匹配任务,bert backbone的输出内容可能包含如下的对象
{"word_emb": ([-1, max_seqlen, word_emb_size], 'float32'),
"sentence_emb": ([-1, hidden_size], 'float32'),
"sim_vec": ([-1, hidden_size], 'float32')}"""
raise NotImplementedError()
def build(self, inputs):
"""建立backbone的计算图。将符合inputs_attr描述的静态图Variable输入映射成符合outputs_attr描述的静态图Variable输出。
Args:
inputs: dict类型。字典中包含inputs_attr中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象
Return:
需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
"""
raise NotImplementedError()
class task_paradigm(object):
def __init__(self, config, phase, backbone_config):
"""
config: dict类型。描述了 任务实例(task instance)+多任务配置文件 中定义超参数
phase: str类型。运行阶段,目前支持train和predict
"""
@property
def inputs_attrs(self):
"""描述task_layer需要从reader, backbone等输入对象集合所读取到的输入对象的属性,第一级key为对象集和的名字,如backbone,reader等(后续会支持更灵活的输入),第二级key为对象集和中各对象的属性,包括对象的名字,shape和dtype。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。对各个对象集及其输入对象的属性描述。"""
raise NotImplementedError()
@property
def outputs_attr(self):
"""描述task输出对象的属性,包括对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。对各个输入对象的属性描述。注意,训练阶段必须包含名为loss的输出对象。
"""
raise NotImplementedError()
@property
def epoch_inputs_attrs(self):
return {}
def build(self, inputs, scope_name=""):
"""建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。
Args:
inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象
Return:
需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
"""
raise NotImplementedError()
def postprocess(self, rt_outputs):
"""每个训练或推理step后针对当前batch的task_layer的runtime计算结果进行相关后处理。注意,rt_outputs除了包含build方法,还自动包含了loss的计算结果。"""
pass
def epoch_postprocess(self, post_inputs):
pass
......@@ -38,7 +38,7 @@ class Trainer(object):
self._reader = reader
self._pred_reader = None
self._task_head = task_head
self._pred_head = pred_head
self._pred_head = None
# if save_predict_model:
# self._save_predict_model = True
......@@ -89,20 +89,24 @@ class Trainer(object):
self._lock = False
self._build_forward = False
def build_predict_head(self, pred_backbone, pred_prog=None, pred_init_prog=None):
def build_predict_head(self, pred_head, pred_backbone, pred_prog=None, pred_init_prog=None):
self._pred_head = pred_head
# self._pred_reader = self._reader.clone(phase='pred')
pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name)
# pred_task_attr_from_reader = self._pred_head.inputs_attrs['reader']
# _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred')
# _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred')
# _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone')
pred_input_names, pred_shape_and_dtypes, _ = reader_helper.merge_input_attrs(backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
pred_input_names, pred_shape_and_dtypes, _ = reader_helper.merge_input_attrs(pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_input_names, pred_shape_and_dtypes)]
if pred_prog is None:
pred_prog = fluid.Program()
self._pred_prog = pred_prog
if pred_init_prog is None:
pred_init_prog = fluid.Program()
self._pred_init_prog = pred_init_prog
with fluid.program_guard(pred_prog, pred_init_prog):
pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs)
# pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_')
......@@ -121,8 +125,6 @@ class Trainer(object):
self._build_head(pred_task_inputs, phase='pred', scope=scope)
def build_forward(self, backbone, pred_backbone=None, train_prog=None, train_init_prog=None, pred_prog=None, pred_init_prog=None):
# assert self._backbone is not None, "backbone is required for Trainer to build net forward to run with single task mode"
......@@ -154,7 +156,6 @@ class Trainer(object):
print('joint input shape and dtypes:')
print(joint_shape_and_dtypes)
input_attrs = [[i, j, k] for i, (j,k) in zip(input_names, shape_and_dtypes)]
if train_prog is None:
......@@ -172,6 +173,7 @@ class Trainer(object):
# bb_output_vars = self._backbone.build(net_inputs, scope_name='__paddlepalm_')
bb_output_vars = backbone.build(net_inputs)
assert sorted(bb_output_vars.keys()) == sorted(backbone.outputs_attr.keys())
# self._bb_output_vars.keys
# fluid.framework.switch_main_program(train_prog)
......@@ -293,10 +295,14 @@ class Trainer(object):
pass
def train(self, iterator, save_path=None, save_steps=None, save_type='ckpt', print_steps=5):
"""
Argument:
save_type: ckpt, predict, pretrain
"""
save_type = save_type.split(',')
if 'predict' in save_type:
assert self._pred_head is not None, "Predict head not found! You should call set_predict_head first if you want to save predict model."
assert self._pred_head is not None, "Predict head not found! You should build_predict_head first if you want to save predict model."
assert save_path is not None and save_steps is not None, 'save_path and save_steps is required to save model.'
save_predict = True
if not os.path.exists(save_path):
......@@ -369,11 +375,11 @@ class Trainer(object):
# cur_task.save()
if (save_predict or save_ckpt) and self._cur_train_step % save_steps == 0:
if save_predict_model:
self.save(save_path, suffix='pred.step'+str(global_step))
if save_predict:
self.save(save_path, suffix='pred.step'+str(self._cur_train_step))
if save_ckpt:
fluid.io.save_persistables(self.exe, os.path.join(save_path, 'ckpt.step'+str(global_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(global_step)))
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
# save_path = os.path.join(main_conf['save_path'], 'ckpt',
# "step_" + str(global_step))
......@@ -422,7 +428,7 @@ class Trainer(object):
dirpath = save_path
self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]
prog = fluid.default_main_program().clone()
prog = self._pred_prog.clone()
fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)
conf = {}
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
"""
Args:
phase: train, eval, pred
"""
self._is_training = phase == 'train'
reader = ClassifyReader(config['vocab_path'],
max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', False),
for_cn=config.get('for_cn', False),
random_seed=config.get('seed', None))
self._reader = reader
self._dev_count = dev_count
self._batch_size = config['batch_size']
self._max_seq_len = config['max_seq_len']
self._num_classes = config['n_classes']
if phase == 'train':
self._input_file = config['train_file']
self._num_epochs = None # 防止iteartor终止
self._shuffle = config.get('shuffle', True)
# self._shuffle_buffer = config.get('shuffle_buffer', 5000)
elif phase == 'eval':
self._input_file = config['dev_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
elif phase == 'pred':
self._input_file = config['pred_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
self._phase = phase
# self._batch_size =
self._print_first_n = config.get('print_first_n', 0)
@property
def outputs_attr(self):
if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1,1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
}
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32']
}
def load_data(self):
self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
def iterator(self):
def list_to_dict(x):
names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
'label_ids', 'unique_ids']
outputs = {n: i for n,i in zip(names, x)}
del outputs['unique_ids']
if not self._is_training:
del outputs['label_ids']
return outputs
for batch in self._data_generator():
yield list_to_dict(batch)
def get_epoch_outputs(self):
return {'examples': self._reader.get_examples(self._phase),
'features': self._reader.get_features(self._phase)}
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import ClassifyReader
def match(vocab_path, max_seq_len, do_lower_case=True, phase, dev_count=1):
config={
xxx}
return Reader(config())
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
"""
Args:
phase: train, eval, pred
"""
self._is_training = phase == 'train'
reader = ClassifyReader(config['vocab_path'],
max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', True),
for_cn=config.get('for_cn', False),
random_seed=config.get('seed', None))
self._reader = reader
self._dev_count = dev_count
self._batch_size = config['batch_size']
self._max_seq_len = config['max_seq_len']
if phase == 'train':
self._input_file = config['train_file']
self._num_epochs = None # 防止iteartor终止
self._shuffle = config.get('shuffle', True)
self._shuffle_buffer = config.get('shuffle_buffer', 5000)
elif phase == 'eval':
self._input_file = config['dev_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
elif phase == 'pred':
self._input_file = config['pred_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
self._phase = phase
# self._batch_size =
self._print_first_n = config.get('print_first_n', 1)
@property
def outputs_attr(self):
if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"label_ids": [[-1,1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
}
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32']
}
def load_data(self):
self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
def iterator(self):
def list_to_dict(x):
names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
'label_ids', 'unique_ids']
outputs = {n: i for n,i in zip(names, x)}
del outputs['unique_ids']
if not self._is_training:
del outputs['label_ids']
return outputs
for batch in self._data_generator():
yield list_to_dict(batch)
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import MaskLMReader
import numpy as np
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
"""
Args:
phase: train, eval, pred
"""
self._is_training = phase == 'train'
reader = MaskLMReader(config['vocab_path'],
max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', False),
for_cn=config.get('for_cn', False),
random_seed=config.get('seed', None))
self._reader = reader
self._dev_count = dev_count
self._batch_size = config['batch_size']
self._max_seq_len = config['max_seq_len']
if phase == 'train':
self._input_file = config['train_file']
self._num_epochs = None # 防止iteartor终止
self._shuffle = config.get('shuffle', True)
self._shuffle_buffer = config.get('shuffle_buffer', 5000)
elif phase == 'eval':
self._input_file = config['dev_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
elif phase == 'pred':
self._input_file = config['pred_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
self._phase = phase
# self._batch_size =
self._print_first_n = config.get('print_first_n', 1)
@property
def outputs_attr(self):
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1, -1, 1], 'int64'],
"mask_label": [[-1, 1], 'int64'],
"mask_pos": [[-1, 1], 'int64'],
}
def load_data(self):
self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
def iterator(self):
def list_to_dict(x):
names = ['token_ids', 'position_ids', 'segment_ids', 'input_mask',
'task_ids', 'mask_label', 'mask_pos']
outputs = {n: i for n,i in zip(names, x)}
# outputs['batchsize_x_seqlen'] = [self._batch_size * len(outputs['token_ids'][0]) - 1]
return outputs
for batch in self._data_generator():
# print(np.shape(list_to_dict(batch)['token_ids']))
# print(list_to_dict(batch)['mask_label'].tolist())
yield list_to_dict(batch)
def get_epoch_outputs(self):
return {'examples': self._reader.get_examples(self._phase),
'features': self._reader.get_features(self._phase)}
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlepalm.interface import reader
from paddlepalm.reader.utils.reader4ernie import MRCReader
class Reader(reader):
def __init__(self, config, phase='train', dev_count=1, print_prefix=''):
"""
Args:
phase: train, eval, pred
"""
self._is_training = phase == 'train'
reader = MRCReader(config['vocab_path'],
max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', False),
tokenizer='FullTokenizer',
for_cn=config.get('for_cn', False),
doc_stride=config['doc_stride'],
max_query_length=config['max_query_len'],
random_seed=config.get('seed', None))
self._reader = reader
self._dev_count = dev_count
self._batch_size = config['batch_size']
self._max_seq_len = config['max_seq_len']
if phase == 'train':
self._input_file = config['train_file']
# self._num_epochs = config['num_epochs']
self._num_epochs = None # 防止iteartor终止
self._shuffle = config.get('shuffle', True)
self._shuffle_buffer = config.get('shuffle_buffer', 5000)
if phase == 'eval':
self._input_file = config['dev_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
elif phase == 'pred':
self._input_file = config['pred_file']
self._num_epochs = 1
self._shuffle = False
self._batch_size = config.get('pred_batch_size', self._batch_size)
self._phase = phase
# self._batch_size =
self._print_first_n = config.get('print_first_n', 1)
# TODO: without slide window version
self._with_slide_window = config.get('with_slide_window', False)
@property
def outputs_attr(self):
if self._is_training:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"start_positions": [[-1, 1], 'int64'],
"end_positions": [[-1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64']
}
else:
return {"token_ids": [[-1, -1, 1], 'int64'],
"position_ids": [[-1, -1, 1], 'int64'],
"segment_ids": [[-1, -1, 1], 'int64'],
"task_ids": [[-1, -1, 1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"unique_ids": [[-1, 1], 'int64']
}
@property
def epoch_outputs_attr(self):
if not self._is_training:
return {"examples": None,
"features": None}
def load_data(self):
self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
def iterator(self):
def list_to_dict(x):
names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
'start_positions', 'end_positions', 'unique_ids']
outputs = {n: i for n,i in zip(names, x)}
if self._is_training:
del outputs['unique_ids']
else:
del outputs['start_positions']
del outputs['end_positions']
return outputs
for batch in self._data_generator():
yield list_to_dict(batch)
def get_epoch_outputs(self):
return {'examples': self._reader.get_examples(self._phase),
'features': self._reader.get_features(self._phase)}
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
"""
Add mask for batch_tokens, return out, mask_label, mask_pos;
Note: mask_pos responding the batch_tokens after padded;
"""
max_len = max([len(sent) for sent in batch_tokens])
mask_label = []
mask_pos = []
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
prob_index += pre_sent_len
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index + token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
# ensure at least mask one word in a sentence
while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return batch_tokens, mask_label, mask_pos
def prepare_batch_data(insts,
total_token_num,
max_len=None,
voc_size=0,
pad_id=None,
cls_id=None,
sep_id=None,
mask_id=None,
return_input_mask=True,
return_max_len=True,
return_num_token=False):
"""
1. generate Tensor of data
2. generate Tensor of position
3. generate self attention mask, [shape: batch_size * max_len * max_len]
"""
batch_src_ids = [inst[0] for inst in insts]
batch_sent_ids = [inst[1] for inst in insts]
batch_pos_ids = [inst[2] for inst in insts]
labels_list = []
# compatible with mrqa, whose example includes start/end positions,
# or unique id
for i in range(3, len(insts[0]), 1):
labels = [inst[i] for inst in insts]
labels = np.array(labels).astype("int64").reshape([-1, 1])
labels_list.append(labels)
# First step: do mask without padding
if mask_id >= 0:
out, mask_label, mask_pos = mask(
batch_src_ids,
total_token_num,
vocab_size=voc_size,
CLS=cls_id,
SEP=sep_id,
MASK=mask_id)
else:
out = batch_src_ids
# Second step: padding
src_id, self_input_mask = pad_batch_data(
out,
max_len=max_len,
pad_idx=pad_id, return_input_mask=True)
pos_id = pad_batch_data(
batch_pos_ids,
max_len=max_len,
pad_idx=pad_id,
return_pos=False,
return_input_mask=False)
sent_id = pad_batch_data(
batch_sent_ids,
max_len=max_len,
pad_idx=pad_id,
return_pos=False,
return_input_mask=False)
if mask_id >= 0:
return_list = [
src_id, pos_id, sent_id, self_input_mask, mask_label, mask_pos
] + labels_list
else:
return_list = [src_id, pos_id, sent_id, self_input_mask] + labels_list
return return_list if len(return_list) > 1 else return_list[0]
def pad_batch_data(insts,
max_len=None,
pad_idx=0,
return_pos=False,
return_input_mask=False,
return_max_len=False,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and input mask.
"""
return_list = []
if max_len is None:
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array([
list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts
])
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
# position data
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
if return_input_mask:
# This is used to avoid attention on paddings.
input_mask_data = np.array([[1] * len(inst) + [0] *
(max_len - len(inst)) for inst in insts])
input_mask_data = np.expand_dims(input_mask_data, axis=-1)
return_list += [input_mask_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
if __name__ == "__main__":
pass
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange
def mask(batch_tokens,
seg_labels,
mask_word_tags,
total_token_num,
vocab_size,
CLS=1,
SEP=2,
MASK=3):
"""
Add mask for batch_tokens, return out, mask_label, mask_pos;
Note: mask_pos responding the batch_tokens after padded;
"""
max_len = max([len(sent) for sent in batch_tokens])
mask_label = []
mask_pos = []
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
mask_word = mask_word_tags[sent_index]
prob_index += pre_sent_len
if mask_word:
beg = 0
for token_index, token in enumerate(sent):
seg_label = seg_labels[sent_index][token_index]
if seg_label == 1:
continue
if beg == 0:
if seg_label != -1:
beg = token_index
continue
prob = prob_mask[prob_index + beg]
if prob > 0.15:
pass
else:
for index in xrange(beg, token_index):
prob = prob_mask[prob_index + index]
base_prob = 1.0
if index == beg:
base_prob = 0.15
if base_prob * 0.2 < prob <= base_prob:
mask_label.append(sent[index])
sent[index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + index)
elif base_prob * 0.1 < prob <= base_prob * 0.2:
mask_label.append(sent[index])
sent[index] = replace_ids[prob_index + index]
mask_flag = True
mask_pos.append(sent_index * max_len + index)
else:
mask_label.append(sent[index])
mask_pos.append(sent_index * max_len + index)
if seg_label == -1:
beg = 0
else:
beg = token_index
else:
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index +
token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return batch_tokens, mask_label, mask_pos
def pad_batch_data(insts,
pad_idx=0,
return_pos=False,
return_input_mask=False,
return_max_len=False,
return_num_token=False,
return_seq_lens=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
# position data
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
if return_input_mask:
# This is used to avoid attention on paddings.
input_mask_data = np.array([[1] * len(inst) + [0] *
(max_len - len(inst)) for inst in insts])
input_mask_data = np.expand_dims(input_mask_data, axis=-1)
return_list += [input_mask_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
if return_seq_lens:
seq_lens = np.array([len(inst) for inst in insts])
return_list += [seq_lens.astype("int64").reshape([-1, 1])]
return return_list if len(return_list) > 1 else return_list[0]
if __name__ == "__main__":
pass
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
"""
Add mask for batch_tokens, return out, mask_label, mask_pos;
Note: mask_pos responding the batch_tokens after padded;
"""
max_len = max([len(sent) for sent in batch_tokens])
mask_label = []
mask_pos = []
prob_mask = np.random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
mask_flag = False
prob_index += pre_sent_len
for token_index, token in enumerate(sent):
prob = prob_mask[prob_index + token_index]
if prob > 0.15:
continue
elif 0.03 < prob <= 0.15:
# mask
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
elif 0.015 < prob <= 0.03:
# random replace
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
sent[token_index] = replace_ids[prob_index + token_index]
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
else:
# keep the original token
if token != SEP and token != CLS:
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
pre_sent_len = len(sent)
# ensure at least mask one word in a sentence
while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
mask_flag = True
mask_pos.append(sent_index * max_len + token_index)
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return batch_tokens, mask_label, mask_pos
def prepare_batch_data(insts,
total_token_num,
max_len=None,
voc_size=0,
pad_id=None,
cls_id=None,
sep_id=None,
mask_id=None,
task_id=0,
return_input_mask=True,
return_max_len=True,
return_num_token=False):
"""
1. generate Tensor of data
2. generate Tensor of position
3. generate self attention mask, [shape: batch_size * max_len * max_len]
"""
batch_src_ids = [inst[0] for inst in insts]
batch_sent_ids = [inst[1] for inst in insts]
batch_pos_ids = [inst[2] for inst in insts]
# 这里是否应该反过来???否则在task layer里展开后的word embedding是padding后的,这时候word的index是跟没有padding时的index对不上的?
# First step: do mask without padding
out, mask_label, mask_pos = mask(
batch_src_ids,
total_token_num,
vocab_size=voc_size,
CLS=cls_id,
SEP=sep_id,
MASK=mask_id)
# Second step: padding
src_id, self_input_mask = pad_batch_data(
out,
max_len=max_len,
pad_idx=pad_id, return_input_mask=True)
pos_id = pad_batch_data(
batch_pos_ids,
max_len=max_len,
pad_idx=pad_id,
return_pos=False,
return_input_mask=False)
sent_id = pad_batch_data(
batch_sent_ids,
max_len=max_len,
pad_idx=pad_id,
return_pos=False,
return_input_mask=False)
task_ids = np.ones_like(
src_id, dtype="int64") * task_id
return_list = [
src_id, pos_id, sent_id, self_input_mask, task_ids, mask_label, mask_pos
]
return return_list if len(return_list) > 1 else return_list[0]
def pad_batch_data(insts,
max_len=None,
pad_idx=0,
return_pos=False,
return_input_mask=False,
return_max_len=False,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and input mask.
"""
return_list = []
if max_len is None:
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array([
list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts
])
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
# position data
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
if return_input_mask:
# This is used to avoid attention on paddings.
input_mask_data = np.array([[1] * len(inst) + [0] *
(max_len - len(inst)) for inst in insts])
input_mask_data = np.expand_dims(input_mask_data, axis=-1)
return_list += [input_mask_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
if __name__ == "__main__":
pass
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class MRQAExample(object):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def __init__(self,
qas_id,
question_text,
doc_tokens,
orig_answer_text=None,
start_position=None,
end_position=None,
is_impossible=False):
self.qas_id = qas_id
self.question_text = question_text
self.doc_tokens = doc_tokens
self.orig_answer_text = orig_answer_text
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
s += ", question_text: %s" % (
tokenization.printable_text(self.question_text))
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position:
s += ", start_position: %d" % (self.start_position)
if self.start_position:
s += ", end_position: %d" % (self.end_position)
if self.start_position:
s += ", is_impossible: %r" % (self.is_impossible)
return s
class MRQAFeature(object):
"""A single set of features of data."""
def __init__(self,
unique_id,
example_index,
doc_span_index,
tokens,
token_to_orig_map,
token_is_max_context,
input_ids,
input_mask,
segment_ids,
start_position=None,
end_position=None,
is_impossible=None):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
self.tokens = tokens
self.token_to_orig_map = token_to_orig_map
self.token_is_max_context = token_is_max_context
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
此差异已折叠。
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import layers
from paddlepalm.interface import task_paradigm
import numpy as np
import os
class TaskParadigm(task_paradigm):
'''
classification
'''
def __init__(self, config, phase, backbone_config=None):
self._is_training = phase == 'train'
self._hidden_size = backbone_config['hidden_size']
self.num_classes = config['n_classes']
if 'initializer_range' in config:
self._param_initializer = config['initializer_range']
else:
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=backbone_config.get('initializer_range', 0.02))
if 'dropout_prob' in config:
self._dropout_prob = config['dropout_prob']
else:
self._dropout_prob = backbone_config.get('hidden_dropout_prob', 0.0)
self._pred_output_path = config.get('pred_output_path', None)
self._preds = []
@property
def inputs_attrs(self):
if self._is_training:
reader = {"label_ids": [[-1, 1], 'int64']}
else:
reader = {}
bb = {"sentence_embedding": [[-1, self._hidden_size], 'float32']}
return {'reader': reader, 'backbone': bb}
@property
def outputs_attrs(self):
if self._is_training:
return {'loss': [[1], 'float32']}
else:
return {'logits': [[-1, self.num_classes], 'float32']}
def build(self, inputs, scope_name=''):
sent_emb = inputs['backbone']['sentence_embedding']
if self._is_training:
label_ids = inputs['reader']['label_ids']
cls_feats = fluid.layers.dropout(
x=sent_emb,
dropout_prob=self._dropout_prob,
dropout_implementation="upscale_in_train")
logits = fluid.layers.fc(
input=sent_emb,
size=self.num_classes,
param_attr=fluid.ParamAttr(
name=scope_name+"cls_out_w",
initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(
name=scope_name+"cls_out_b", initializer=fluid.initializer.Constant(0.)))
if self._is_training:
loss = fluid.layers.softmax_with_cross_entropy(
logits=logits, label=label_ids)
loss = layers.mean(loss)
return {"loss": loss}
else:
return {"logits":logits}
def postprocess(self, rt_outputs):
if not self._is_training:
logits = rt_outputs['logits']
preds = np.argmax(logits, -1)
self._preds.extend(preds.tolist())
def epoch_postprocess(self, post_inputs):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training:
if self._pred_output_path is None:
raise ValueError('argument pred_output_path not found in config. Please add it into config dict/file.')
with open(os.path.join(self._pred_output_path, 'predictions.json'), 'w') as writer:
for p in self._preds:
writer.write(str(p)+'\n')
print('Predictions saved at '+os.path.join(self._pred_output_path, 'predictions.json'))
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid import layers
from paddlepalm.interface import task_paradigm
import numpy as np
import os
class TaskParadigm(task_paradigm):
'''
matching
'''
def __init__(self, config, phase, backbone_config=None):
self._is_training = phase == 'train'
self._hidden_size = backbone_config['hidden_size']
if 'initializer_range' in config:
self._param_initializer = config['initializer_range']
else:
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=backbone_config.get('initializer_range', 0.02))
if 'dropout_prob' in config:
self._dropout_prob = config['dropout_prob']
else:
self._dropout_prob = backbone_config.get('hidden_dropout_prob', 0.0)
self._pred_output_path = config.get('pred_output_path', None)
self._preds = []
@property
def inputs_attrs(self):
if self._is_training:
reader = {"label_ids": [[-1, 1], 'int64']}
else:
reader = {}
bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']}
return {'reader': reader, 'backbone': bb}
@property
def outputs_attrs(self):
if self._is_training:
return {"loss": [[1], 'float32']}
else:
return {"logits": [[-1, 2], 'float32']}
def build(self, inputs, scope_name=""):
if self._is_training:
labels = inputs["reader"]["label_ids"]
cls_feats = inputs["backbone"]["sentence_pair_embedding"]
if self._is_training:
cls_feats = fluid.layers.dropout(
x=cls_feats,
dropout_prob=self._dropout_prob,
dropout_implementation="upscale_in_train")
logits = fluid.layers.fc(
input=cls_feats,
size=2,
param_attr=fluid.ParamAttr(
name=scope_name+"cls_out_w",
initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(
name=scope_name+"cls_out_b",
initializer=fluid.initializer.Constant(0.)))
if self._is_training:
ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
logits=logits, label=labels, return_softmax=True)
loss = fluid.layers.mean(x=ce_loss)
return {'loss': loss}
else:
return {'logits': logits}
def postprocess(self, rt_outputs):
if not self._is_training:
logits = rt_outputs['logits']
preds = np.argmax(logits, -1)
self._preds.extend(preds.tolist())
def epoch_postprocess(self, post_inputs):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training:
if self._pred_output_path is None:
raise ValueError('argument pred_output_path not found in config. Please add it into config dict/file.')
with open(os.path.join(self._pred_output_path, 'predictions.json'), 'w') as writer:
for p in self._preds:
writer.write(str(p)+'\n')
print('Predictions saved at '+os.path.join(self._pred_output_path, 'predictions.json'))
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddlepalm.interface import task_paradigm
from paddle.fluid import layers
from paddlepalm.backbone.utils.transformer import pre_process_layer
class TaskParadigm(task_paradigm):
'''
matching
'''
def __init__(self, config, phase, backbone_config=None):
self._is_training = phase == 'train'
self._emb_size = backbone_config['hidden_size']
self._hidden_size = backbone_config['hidden_size']
self._vocab_size = backbone_config['vocab_size']
self._hidden_act = backbone_config['hidden_act']
self._initializer_range = backbone_config['initializer_range']
@property
def inputs_attrs(self):
reader = {
"mask_label": [[-1, 1], 'int64'],
"mask_pos": [[-1, 1], 'int64']}
if not self._is_training:
del reader['mask_label']
del reader['batchsize_x_seqlen']
bb = {
"encoder_outputs": [[-1, -1, self._hidden_size], 'float32'],
"embedding_table": [[-1, self._vocab_size, self._emb_size], 'float32']}
return {'reader': reader, 'backbone': bb}
@property
def outputs_attrs(self):
if self._is_training:
return {"loss": [[1], 'float32']}
else:
return {"logits": [[-1], 'float32']}
def build(self, inputs, scope_name=""):
mask_pos = inputs["reader"]["mask_pos"]
if self._is_training:
mask_label = inputs["reader"]["mask_label"]
max_position = inputs["reader"]["batchsize_x_seqlen"] - 1
mask_pos = fluid.layers.elementwise_min(mask_pos, max_position)
mask_pos.stop_gradient = True
word_emb = inputs["backbone"]["embedding_table"]
enc_out = inputs["backbone"]["encoder_outputs"]
emb_size = word_emb.shape[-1]
_param_initializer = fluid.initializer.TruncatedNormal(
scale=self._initializer_range)
reshaped_emb_out = fluid.layers.reshape(
x=enc_out, shape=[-1, emb_size])
# extract masked tokens' feature
mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
# transform: fc
mask_trans_feat = fluid.layers.fc(
input=mask_feat,
size=emb_size,
act=self._hidden_act,
param_attr=fluid.ParamAttr(
name=scope_name+'mask_lm_trans_fc.w_0',
initializer=_param_initializer),
bias_attr=fluid.ParamAttr(name=scope_name+'mask_lm_trans_fc.b_0'))
# transform: layer norm
mask_trans_feat = pre_process_layer(
mask_trans_feat, 'n', name=scope_name+'mask_lm_trans')
mask_lm_out_bias_attr = fluid.ParamAttr(
name=scope_name+"mask_lm_out_fc.b_0",
initializer=fluid.initializer.Constant(value=0.0))
fc_out = fluid.layers.matmul(
x=mask_trans_feat,
y=word_emb,
transpose_y=True)
fc_out += fluid.layers.create_parameter(
shape=[self._vocab_size],
dtype='float32',
attr=mask_lm_out_bias_attr,
is_bias=True)
if self._is_training:
mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
logits=fc_out, label=mask_label)
loss = fluid.layers.mean(mask_lm_loss)
return {'loss': loss}
else:
return {'logits': fc_out}
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册