未验证 提交 c76454e7 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update and rename customization.md to customization_cn.md

上级 b7e9830f
# PALM组件定制化教程
PALM支持对如下组件自定义: PALM支持对如下组件自定义:
- head - **head**
定义一个新的任务输出头,接收来自backbone和reader的输入,输出训练阶段的loss和预测阶段的预测结果。例如:分类任务头,序列标注任务头,机器阅读理解任务头等。 定义一个新的任务输出头,接收来自backbone和reader的输入,输出训练阶段的loss和预测阶段的预测结果。例如:分类任务头,序列标注任务头,机器阅读理解任务头等。
- backbone - **backbone**
定义一个新的主干网络,接收来自reader的文本相关的序列特征输入(如token ids),输出文本的特征向量表示(如词向量、上下文相关的词向量表示、句子向量等)。例如:BERT encoder,CNN encoder等。 定义一个新的主干网络,接收来自reader的文本相关的序列特征输入(如token ids),输出文本的特征向量表示(如词向量、上下文相关的词向量表示、句子向量等)。例如:BERT encoder,CNN encoder等。
- reader - **reader**
定义一个新的数据集载入与预处理模块,接收来自原始数据集文件的输入(纯文本,原始标签等),输出文本相关的序列特征(如token ids,position ids等)。例如:文本分类数据集处理模块;文本匹配数据集处理模块等。 定义一个新的数据集载入与预处理模块,接收来自原始数据集文件的输入(纯文本,原始标签等),输出文本相关的序列特征(如token ids,position ids等)。例如:文本分类数据集处理模块;文本匹配数据集处理模块等。
- optimizer - **optimizer**
定义一个新的优化器 定义一个新的优化器
- lr_sched - **lr_sched**
定义一种新的学习率规划策略 定义一种新的学习率规划策略
PALM中的每个组件均使用类来描述,因此可以允许存在内部记忆(成员变量)。 PALM中的每个组件均使用类来描述,因此可以允许存在内部记忆(成员变量)。
...@@ -38,11 +38,13 @@ head的接口类(Interface)位于`paddlepalm/head/base_head.py`。 ...@@ -38,11 +38,13 @@ head的接口类(Interface)位于`paddlepalm/head/base_head.py`。
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import json import json
import copy import copy
class Head(object): class Head(object):
    def __init__(self, phase='train'):     def __init__(self, phase='train'):
        """该函数完成一个任务头的构造,至少需要包含一个phase参数。         """该函数完成一个任务头的构造,至少需要包含一个phase参数。
        注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。         注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。
...@@ -53,6 +55,7 @@ class Head(object): ...@@ -53,6 +55,7 @@ class Head(object):
        self._phase = phase         self._phase = phase
        self._prog = None         self._prog = None
        self._results_buffer = []         self._results_buffer = []
    @property     @property
    def inputs_attrs(self):     def inputs_attrs(self):
        """step级别的任务输入对象声明。         """step级别的任务输入对象声明。
...@@ -78,6 +81,7 @@ class Head(object): ...@@ -78,6 +81,7 @@ class Head(object):
            """             """
        raise NotImplementedError()         raise NotImplementedError()
    @property     @property
    def epoch_inputs_attrs(self):     def epoch_inputs_attrs(self):
        """epoch级别的任务输入对象声明。         """epoch级别的任务输入对象声明。
...@@ -102,6 +106,7 @@ class Head(object): ...@@ -102,6 +106,7 @@ class Head(object):
           需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。            需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
        """         """
        raise NotImplementedError()         raise NotImplementedError()
    def batch_postprocess(self, rt_outputs):     def batch_postprocess(self, rt_outputs):
        """batch/step级别的后处理。         """batch/step级别的后处理。
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册