Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
c4b03ce3
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
7
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c4b03ce3
编写于
3月 29, 2020
作者:
X
Xiaoyao Xi
提交者:
GitHub
3月 29, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update base_head.py
上级
f6579ca0
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
53 addition
and
16 deletion
+53
-16
paddlepalm/head/base_head.py
paddlepalm/head/base_head.py
+53
-16
未找到文件。
paddlepalm/head/base_head.py
浏览文件 @
c4b03ce3
...
...
@@ -15,6 +15,7 @@
import
os
import
json
import
copy
class
Head
(
object
):
...
...
@@ -31,19 +32,26 @@ class Head(object):
@
property
def
inputs_attrs
(
self
):
"""描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象。使用字典进行描述,字典的key为输出对象所在的
组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为
输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维
度长度可变时,shape中的相应维度设置为-1。
"""step级别的任务输入对象声明。
描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象(每个step获取一次)。使用字典进行描述,
字典的key为输出对象所在的组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。
输出对象集使用字典描述,key为输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象
的shape和dtype。当某个输出对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。描述该任务头所依赖的来自各个组件的输出对象。"""
dict类型。描述该任务头所依赖的
step级输入,即
来自各个组件的输出对象。"""
raise
NotImplementedError
()
@
property
def
outputs_attr
(
self
):
"""描述该任务头的输出对象,包括每个输出对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个
训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方法中进行当前step的后处理。
当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
"""step级别的任务输出对象声明。
描述该任务头的输出对象(每个step输出一次),包括每个输出对象的名字,shape和dtype。输出对象会被加入到
fetch_list中,从而在每个训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方
法中进行当前step的后处理。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],
当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。
"""
...
...
@@ -52,21 +60,36 @@ class Head(object):
@
property
def
epoch_inputs_attrs
(
self
):
"""epoch级别的任务输入对象声明。
描述该任务所依赖的来自reader、backbone和来自其他任务头的输出对象(每个epoch结束后产生一次),如完整的
样本集,有效的样本数等。使用字典进行描述,字典的key为输出对象所在的组件(如’reader‘,’backbone‘等),
value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为输出对象的名字(该名字需保证在相关
组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维度长度可变时,shape中的相
应维度设置为-1。
Return:
dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。
"""
return
{}
def
build
(
self
,
inputs
,
scope_name
=
""
):
"""建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。
"""建立任务头的计算图。
将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。
Args:
inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象
Return:
需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
"""
raise
NotImplementedError
()
def
batch_postprocess
(
self
,
rt_outputs
):
"""每个训练或推理step后针对当前batch的task_layer的runtime计算结果进行相关后处理。注意,rt_outputs除了包含build方法,还自动包含了loss的计算结果。"""
"""batch/step级别的后处理。
每个训练或推理step后针对当前batch的任务头输出对象的实时计算结果来进行相关后处理。
默认将输出结果存储到缓冲区self._results_buffer中。"""
if
isinstance
(
rt_outputs
,
dict
):
keys
=
rt_outputs
.
keys
()
vals
=
[
rt_outputs
[
k
]
for
k
in
keys
]
...
...
@@ -79,8 +102,25 @@ class Head(object):
print
(
'WARNING: irregular output results. visualize failed.'
)
self
.
_results_buffer
.
append
(
rt_outputs
)
return
None
def
reset
(
self
):
"""清空该任务头的缓冲区(在训练或推理过程中积累的处理结果)"""
self
.
_results_buffer
=
[]
def
get_results
(
self
):
"""返回当前任务头积累的处理结果。"""
return
copy
.
deepcopy
(
self
.
_results_buffer
)
def
epoch_postprocess
(
self
,
post_inputs
,
output_dir
=
None
):
def
epoch_postprocess
(
self
,
post_inputs
=
None
,
output_dir
=
None
):
"""epoch级别的后处理。
每个训练或推理epoch结束后,对积累的各样本的后处理结果results进行后处理。默认情况下,当output_dir为None时,直接将results打印到
屏幕上。当指定output_dir时,将results存储在指定的文件夹内,并以任务头所处阶段来作为存储文件的文件名。
Args:
post_inputs: 当声明的epoch_inputs_attr不为空时,该参数会携带对应的输入变量的内容。
output_dir: 积累结果的保存路径。
"""
if
output_dir
is
not
None
:
for
i
in
self
.
_results_buffer
:
print
(
i
)
...
...
@@ -90,6 +130,3 @@ class Head(object):
with
open
(
os
.
path
.
join
(
output_dir
,
self
.
_phase
),
'w'
)
as
writer
:
for
i
in
self
.
_results_buffer
:
writer
.
write
(
json
.
dumps
(
i
)
+
'
\n
'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录