diff --git a/customization.md b/customization_cn.md similarity index 99% rename from customization.md rename to customization_cn.md index 4e946b085ad5775df53f9925c35f1c6e2cf773d0..bca62554efdd8b49540dbf63573904444d951c79 100644 --- a/customization.md +++ b/customization_cn.md @@ -1,16 +1,16 @@ - +# PALM组件定制化教程 PALM支持对如下组件自定义: -- head +- **head** 定义一个新的任务输出头,接收来自backbone和reader的输入,输出训练阶段的loss和预测阶段的预测结果。例如:分类任务头,序列标注任务头,机器阅读理解任务头等。 -- backbone +- **backbone** 定义一个新的主干网络,接收来自reader的文本相关的序列特征输入(如token ids),输出文本的特征向量表示(如词向量、上下文相关的词向量表示、句子向量等)。例如:BERT encoder,CNN encoder等。 -- reader +- **reader** 定义一个新的数据集载入与预处理模块,接收来自原始数据集文件的输入(纯文本,原始标签等),输出文本相关的序列特征(如token ids,position ids等)。例如:文本分类数据集处理模块;文本匹配数据集处理模块等。 -- optimizer +- **optimizer** 定义一个新的优化器 -- lr_sched +- **lr_sched** 定义一种新的学习率规划策略 PALM中的每个组件均使用类来描述,因此可以允许存在内部记忆(成员变量)。 @@ -38,11 +38,13 @@ head的接口类(Interface)位于`paddlepalm/head/base_head.py`。 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import json import copy class Head(object): +     def __init__(self, phase='train'):         """该函数完成一个任务头的构造,至少需要包含一个phase参数。         注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。 @@ -53,6 +55,7 @@ class Head(object):         self._phase = phase         self._prog = None         self._results_buffer = [] +     @property     def inputs_attrs(self):         """step级别的任务输入对象声明。 @@ -78,6 +81,7 @@ class Head(object):             """         raise NotImplementedError() +     @property     def epoch_inputs_attrs(self):         """epoch级别的任务输入对象声明。 @@ -102,6 +106,7 @@ class Head(object):            需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。         """         raise NotImplementedError() +     def batch_postprocess(self, rt_outputs):         """batch/step级别的后处理。