Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
f6579ca0
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
7
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f6579ca0
编写于
3月 27, 2020
作者:
X
Xiaoyao Xi
提交者:
GitHub
3月 27, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update base_head.py
上级
c00b77fe
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
12 addition
and
18 deletion
+12
-18
paddlepalm/head/base_head.py
paddlepalm/head/base_head.py
+12
-18
未找到文件。
paddlepalm/head/base_head.py
浏览文件 @
f6579ca0
...
...
@@ -19,9 +19,10 @@ import json
class
Head
(
object
):
def
__init__
(
self
,
phase
=
'train'
):
"""
config: dict类型。描述了 任务实例(task instance)+多任务配置文件 中定义超参数
phase: str类型。运行阶段,目前支持train和predict
"""该函数完成一个任务头的构造,至少需要包含一个phase参数。
注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。
Args:
phase: str类型。用于区分任务头被调用时所处的任务运行阶段,目前支持训练阶段train和预测阶段predict
"""
self
.
_stop_gradient
=
{}
self
.
_phase
=
phase
...
...
@@ -30,17 +31,21 @@ class Head(object):
@
property
def
inputs_attrs
(
self
):
"""描述task_layer需要从reader, backbone等输入对象集合所读取到的输入对象的属性,第一级key为对象集和的名字,如backbone,reader等(后续会支持更灵活的输入),第二级key为对象集和中各对象的属性,包括对象的名字,shape和dtype。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
"""描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象。使用字典进行描述,字典的key为输出对象所在的
组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为
输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维
度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。
对各个对象集及其输入对象的属性描述
。"""
dict类型。
描述该任务头所依赖的来自各个组件的输出对象
。"""
raise
NotImplementedError
()
@
property
def
outputs_attr
(
self
):
"""描述task输出对象的属性,包括对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
"""描述该任务头的输出对象,包括每个输出对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个
训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方法中进行当前step的后处理。
当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。
对各个输入对象的属性描述。注意,训练阶段
必须包含名为loss的输出对象。
dict类型。
描述该任务头所产生的输出对象。注意,在训练阶段时
必须包含名为loss的输出对象。
"""
raise
NotImplementedError
()
...
...
@@ -49,17 +54,6 @@ class Head(object):
def
epoch_inputs_attrs
(
self
):
return
{}
# def stop_gradient(source, inputs):
# # if self._inputs is None:
# # raise Exception('You need to build this head first before stop gradient.')
# self._inputs = inputs
# for name, var in self._inputs[source].items():
# # cur_block = self._prog.current_block()
# var = fluid.layers.assign(var)
# var.stop_gradient = True
# self._inputs[name] = var
# return self._inputs
def
build
(
self
,
inputs
,
scope_name
=
""
):
"""建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。
Args:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录