Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
c76454e7
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
5
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c76454e7
编写于
3月 30, 2020
作者:
X
Xiaoyao Xi
提交者:
GitHub
3月 30, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update and rename customization.md to customization_cn.md
上级
b7e9830f
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
11 addition
and
6 deletion
+11
-6
customization_cn.md
customization_cn.md
+11
-6
未找到文件。
customization.md
→
customization
_cn
.md
浏览文件 @
c76454e7
# PALM组件定制化教程
PALM支持对如下组件自定义:
-
head
-
**head**
定义一个新的任务输出头,接收来自backbone和reader的输入,输出训练阶段的loss和预测阶段的预测结果。例如:分类任务头,序列标注任务头,机器阅读理解任务头等。
-
backbone
-
**backbone**
定义一个新的主干网络,接收来自reader的文本相关的序列特征输入(如token ids),输出文本的特征向量表示(如词向量、上下文相关的词向量表示、句子向量等)。例如:BERT encoder,CNN encoder等。
-
reader
-
**reader**
定义一个新的数据集载入与预处理模块,接收来自原始数据集文件的输入(纯文本,原始标签等),输出文本相关的序列特征(如token ids,position ids等)。例如:文本分类数据集处理模块;文本匹配数据集处理模块等。
-
optimizer
-
**optimizer**
定义一个新的优化器
-
lr_sched
-
**lr_sched**
定义一种新的学习率规划策略
PALM中的每个组件均使用类来描述,因此可以允许存在内部记忆(成员变量)。
...
...
@@ -38,11 +38,13 @@ head的接口类(Interface)位于`paddlepalm/head/base_head.py`。
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
json
import
copy
class
Head
(
object
):
def
__init__
(
self
,
phase
=
'train'
):
"""该函数完成一个任务头的构造,至少需要包含一个phase参数。
注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。
...
...
@@ -53,6 +55,7 @@ class Head(object):
self
.
_phase
=
phase
self
.
_prog
=
None
self
.
_results_buffer
=
[]
@
property
def
inputs_attrs
(
self
):
"""step级别的任务输入对象声明。
...
...
@@ -78,6 +81,7 @@ class Head(object):
"""
raise
NotImplementedError
()
@
property
def
epoch_inputs_attrs
(
self
):
"""epoch级别的任务输入对象声明。
...
...
@@ -102,6 +106,7 @@ class Head(object):
需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
"""
raise
NotImplementedError
()
def
batch_postprocess
(
self
,
rt_outputs
):
"""batch/step级别的后处理。
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录