PALM支持对如下组件自定义: - head 定义一个新的任务输出头,接收来自backbone和reader的输入,输出训练阶段的loss和预测阶段的预测结果。例如:分类任务头,序列标注任务头,机器阅读理解任务头等。 - backbone 定义一个新的主干网络,接收来自reader的文本相关的序列特征输入(如token ids),输出文本的特征向量表示(如词向量、上下文相关的词向量表示、句子向量等)。例如:BERT encoder,CNN encoder等。 - reader 定义一个新的数据集载入与预处理模块,接收来自原始数据集文件的输入(纯文本,原始标签等),输出文本相关的序列特征(如token ids,position ids等)。例如:文本分类数据集处理模块;文本匹配数据集处理模块等。 - optimizer 定义一个新的优化器 - lr_sched 定义一种新的学习率规划策略 PALM中的每个组件均使用类来描述,因此可以允许存在内部记忆(成员变量)。 新增某种类型的组件时,只需要实现该组件类型所在目录下的接口类中所描述的方法。若希望新增的组件跟框架的某个内置组件功能相似,那么实现新增组件时,可以继承自已有的内置组件,且仅对需要变动的方法进行修改即可。 ### head自定义 head的接口类(Interface)位于`paddlepalm/head/base_head.py`。 该接口类定义如下: ```python # -*- coding: UTF-8 -*- #   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # #     http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import json import copy class Head(object):     def __init__(self, phase='train'):         """该函数完成一个任务头的构造,至少需要包含一个phase参数。         注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。         Args:             phase: str类型。用于区分任务头被调用时所处的任务运行阶段,目前支持训练阶段train和预测阶段predict             """         self._stop_gradient = {}         self._phase = phase         self._prog = None         self._results_buffer = []     @property     def inputs_attrs(self):         """step级别的任务输入对象声明。         描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象(每个step获取一次)。使用字典进行描述,         字典的key为输出对象所在的组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。         输出对象集使用字典描述,key为输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象         的shape和dtype。当某个输出对象的某个维度长度可变时,shape中的相应维度设置为-1。         Return:             dict类型。描述该任务头所依赖的step级输入,即来自各个组件的输出对象。"""         raise NotImplementedError()     @property     def outputs_attr(self):         """step级别的任务输出对象声明。         描述该任务头的输出对象(每个step输出一次),包括每个输出对象的名字,shape和dtype。输出对象会被加入到         fetch_list中,从而在每个训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方         法中进行当前step的后处理。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],         当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。         Return:             dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。             """         raise NotImplementedError()     @property     def epoch_inputs_attrs(self):         """epoch级别的任务输入对象声明。         描述该任务所依赖的来自reader、backbone和来自其他任务头的输出对象(每个epoch结束后产生一次),如完整的         样本集,有效的样本数等。使用字典进行描述,字典的key为输出对象所在的组件(如’reader‘,’backbone‘等),         value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为输出对象的名字(该名字需保证在相关         组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维度长度可变时,shape中的相         应维度设置为-1。                  Return:             dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。         """         return {}     def build(self, inputs, scope_name=""):         """建立任务头的计算图。         将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。         Args:             inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象         Return:            需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。         """         raise NotImplementedError()     def batch_postprocess(self, rt_outputs):         """batch/step级别的后处理。         每个训练或推理step后针对当前batch的任务头输出对象的实时计算结果来进行相关后处理。         默认将输出结果存储到缓冲区self._results_buffer中。"""         if isinstance(rt_outputs, dict):             keys = rt_outputs.keys()             vals = [rt_outputs[k] for k in keys]             lens = [len(v) for v in vals]             if len(set(lens)) == 1:                 results = [dict(zip(*[keys, i])) for i in zip(*vals)]                 self._results_buffer.extend(results)                 return results             else:                 print('WARNING: irregular output results. visualize failed.')                 self._results_buffer.append(rt_outputs)         return None     def reset(self):         """清空该任务头的缓冲区(在训练或推理过程中积累的处理结果)"""         self._results_buffer = []     def get_results(self):         """返回当前任务头积累的处理结果。"""         return copy.deepcopy(self._results_buffer)              def epoch_postprocess(self, post_inputs=None, output_dir=None):         """epoch级别的后处理。         每个训练或推理epoch结束后,对积累的各样本的后处理结果results进行后处理。默认情况下,当output_dir为None时,直接将results打印到         屏幕上。当指定output_dir时,将results存储在指定的文件夹内,并以任务头所处阶段来作为存储文件的文件名。         Args:             post_inputs: 当声明的epoch_inputs_attr不为空时,该参数会携带对应的输入变量的内容。             output_dir: 积累结果的保存路径。         """         if output_dir is not None:             for i in self._results_buffer:                 print(i)         else:             if not os.path.exists(output_dir):                 os.makedirs(output_dir)             with open(os.path.join(output_dir, self._phase), 'w') as writer:                 for i in self._results_buffer:                     writer.write(json.dumps(i)+'\n') ``` 在基类的基础上,定义一个全新的Head时需要至少实现的方法有: - \_\_init\_\_ - inputs_attrs - outputs_attr - build 可以重写的方法有: - epoch_inputs_attrs - batch_postprocess - epoch_postprocess ### backbone自定义 backbone的接口类(Interface)位于`paddlepalm/backbone/base_backbone.py`。 该接口类定义如下: ```python # -*- coding: UTF-8 -*- #   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # #     http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. class Backbone(object):     """interface of backbone model."""     def __init__(self, phase):         """该函数完成一个主干网络的构造,至少需要包含一个phase参数。         注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。         Args:             phase: str类型。用于区分主干网络被调用时所处的运行阶段,目前支持训练阶段train和预测阶段predict             """         assert isinstance(config, dict)     @property     def inputs_attr(self):         """描述backbone从reader处需要得到的输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象         为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape         中的相应维度设置为-1。         Return:             dict类型。对各个输入对象的属性描述。例如,             对于文本分类和匹配任务,bert backbone依赖的reader对象主要包含如下的对象                 {"token_ids": ([-1, max_len], 'int64'),                  "input_ids": ([-1, max_len], 'int64'),                  "segment_ids": ([-1, max_len], 'int64'),                  "input_mask": ([-1, max_len], 'float32')}"""         raise NotImplementedError()     @property     def outputs_attr(self):         """描述backbone输出对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如         str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。                  Return:             dict类型。对各个输出对象的属性描述。例如,             对于文本分类和匹配任务,bert backbone的输出内容可能包含如下的对象                 {"word_emb": ([-1, max_seqlen, word_emb_size], 'float32'),                  "sentence_emb": ([-1, hidden_size], 'float32'),                  "sim_vec": ([-1, hidden_size], 'float32')}"""          raise NotImplementedError()     def build(self, inputs):         """建立backbone的计算图。将符合inputs_attr描述的静态图Variable输入映射成符合outputs_attr描述的静态图Variable输出。         Args:             inputs: dict类型。字典中包含inputs_attr中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象         Return:            需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。             """ raise NotImplementedError() ``` 在基类的基础上,定义一个全新的Backbone时需要至少实现的方法有: - \_\_init\_\_ - input_attrs - output_attr - build ### reader自定义 reader的接口类(Interface)位于`paddlepalm/reader/base_reader.py`。 该接口类定义如下: ```python # -*- coding: UTF-8 -*- #   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # #     http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from copy import copy class Reader(object):     """interface of data reader."""     def __init__(self, phase='train'):         """该函数完成一个Reader的构造,至少需要包含一个phase参数。         注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。         Args:             phase: str类型。用于区分主干网络被调用时所处的运行阶段,目前支持训练阶段train和预测阶段predict             """                  self._phase = phase         self._batch_size = None         self._num_epochs = 1         self._register = set()         self._registered_backbone = None     @classmethod     def create_register(self):         return set()              def clone(self, phase='train'):         """拷贝一个新的reader对象。"""         if phase == self._phase:             return copy(self)         else:             ret = copy(self)             ret._phase = phase             return ret     def require_attr(self, attr_name):         """在注册器中新增一个需要产生的对象。         Args:             attr_name: 需要产出的对象的对象名,例如’segment_ids‘。             """         self._register.add(attr_name)                  def register_with(self, backbone):         """根据backbone对输入对象的依赖,在注册器中对每个依赖的输入对象进行注册。         Args:             backbone: 需要对接的主干网络。         """         for attr in backbone.inputs_attr:             self.require_attr(attr)         self._registered_backbone = backbone     def get_registered_backbone(self):         """返回该reader所注册的backbone。"""         return self._registered_backbone     def _get_registed_attrs(self, attrs):         ret = {}         for i in self._register:             if i not in attrs:                 raise NotImplementedError('output attr {} is not found in this reader.'.format(i))             ret[i] = attrs[i]         return ret     def load_data(self, input_file, batch_size, num_epochs=None, \                   file_format='tsv', shuffle_train=True):         """将磁盘上的数据载入到reader中。         注意:实现该方法时需要同步创建self._batch_size和self._num_epochs。         Args:             input_file: 数据集文件路径。文件格式需要满足`file_format`参数的要求。             batch_size: 迭代器每次yield出的样本数量。注意:当环境中存在多个GPU时,batch_size需要保证被GPU卡数整除。             num_epochs: 数据集遍历次数。默认为None, 在单任务模式下代表遍历一次,在多任务模式下该参数会被上层的Trainer进行自动赋值。该参数仅对训练阶段有效。             file_format: 输入文件的文件格式。目前支持的格式: tsv. 默认为tsv.             shuffle_train: 是否打乱训练集中的样本。默认为True。该参数仅对训练阶段有效。         """         raise NotImplementedError()     @property     def outputs_attr(self):         """描述reader输出对象(被yield出的对象)的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据         类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。         注意:当使用mini-batch梯度下降学习策略时,,应为常规的输入对象设置batch_size维度(一般为-1)         Return:             dict类型。对各个输入对象的属性描述。例如,             对于文本分类和匹配任务,yield的输出内容可能包含如下的对象(下游backbone和task可按需访问其中的对象)                 {"token_ids": ([-1, max_len], 'int64'),                  "input_ids": ([-1, max_len], 'int64'),                  "segment_ids": ([-1, max_len], 'int64'),                  "input_mask": ([-1, max_len], 'float32'),                  "label": ([-1], 'int')}         """         raise NotImplementedError()          def _iterator(self):         """数据集遍历接口,注意,当数据集遍历到尾部时该接口应自动完成指针重置,即重新从数据集头部开始新的遍历。         Yield:             dict类型。符合outputs_attr描述的当前step的输出对象。         """         raise NotImplementedError()     def get_epoch_outputs(self):         """返回数据集每个epoch遍历后的输出对象。"""         raise NotImplementedError()     @property     def num_examples(self):         """数据集中的样本数量,即每个epoch中iterator所生成的样本数。注意,使用滑动窗口等可能导致数据集样本数发生变化的策略时         该接口应返回runtime阶段的实际样本数。"""         raise NotImplementedError()     @property     def num_epochs(self):         """数据集遍历次数"""         return self._num_epochs ``` 在基类的基础上,定义一个全新的Reader时需要至少实现的方法有: - \_\_init\_\_ - outputs_attr - load_data - _iterator - num_examples 可以重写的方法有: - get_epoch_outputs