未验证 提交 c4b03ce3 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update base_head.py

上级 f6579ca0
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import json import json
import copy
class Head(object): class Head(object):
...@@ -31,19 +32,26 @@ class Head(object): ...@@ -31,19 +32,26 @@ class Head(object):
@property @property
def inputs_attrs(self): def inputs_attrs(self):
"""描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象。使用字典进行描述,字典的key为输出对象所在的 """step级别的任务输入对象声明。
组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为
输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维 描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象(每个step获取一次)。使用字典进行描述,
度长度可变时,shape中的相应维度设置为-1。 字典的key为输出对象所在的组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。
输出对象集使用字典描述,key为输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象
的shape和dtype。当某个输出对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return: Return:
dict类型。描述该任务头所依赖的来自各个组件的输出对象。""" dict类型。描述该任务头所依赖的step级输入,即来自各个组件的输出对象。"""
raise NotImplementedError() raise NotImplementedError()
@property @property
def outputs_attr(self): def outputs_attr(self):
"""描述该任务头的输出对象,包括每个输出对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个 """step级别的任务输出对象声明。
训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方法中进行当前step的后处理。
当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 描述该任务头的输出对象(每个step输出一次),包括每个输出对象的名字,shape和dtype。输出对象会被加入到
fetch_list中,从而在每个训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方
法中进行当前step的后处理。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],
当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return: Return:
dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。 dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。
""" """
...@@ -52,21 +60,36 @@ class Head(object): ...@@ -52,21 +60,36 @@ class Head(object):
@property @property
def epoch_inputs_attrs(self): def epoch_inputs_attrs(self):
"""epoch级别的任务输入对象声明。
描述该任务所依赖的来自reader、backbone和来自其他任务头的输出对象(每个epoch结束后产生一次),如完整的
样本集,有效的样本数等。使用字典进行描述,字典的key为输出对象所在的组件(如’reader‘,’backbone‘等),
value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为输出对象的名字(该名字需保证在相关
组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维度长度可变时,shape中的相
应维度设置为-1。
Return:
dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。
"""
return {} return {}
def build(self, inputs, scope_name=""): def build(self, inputs, scope_name=""):
"""建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。 """建立任务头的计算图。
将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。
Args: Args:
inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象 inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象
Return: Return:
需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
""" """
raise NotImplementedError() raise NotImplementedError()
def batch_postprocess(self, rt_outputs): def batch_postprocess(self, rt_outputs):
"""每个训练或推理step后针对当前batch的task_layer的runtime计算结果进行相关后处理。注意,rt_outputs除了包含build方法,还自动包含了loss的计算结果。""" """batch/step级别的后处理。
每个训练或推理step后针对当前batch的任务头输出对象的实时计算结果来进行相关后处理。
默认将输出结果存储到缓冲区self._results_buffer中。"""
if isinstance(rt_outputs, dict): if isinstance(rt_outputs, dict):
keys = rt_outputs.keys() keys = rt_outputs.keys()
vals = [rt_outputs[k] for k in keys] vals = [rt_outputs[k] for k in keys]
...@@ -79,8 +102,25 @@ class Head(object): ...@@ -79,8 +102,25 @@ class Head(object):
print('WARNING: irregular output results. visualize failed.') print('WARNING: irregular output results. visualize failed.')
self._results_buffer.append(rt_outputs) self._results_buffer.append(rt_outputs)
return None return None
def reset(self):
"""清空该任务头的缓冲区(在训练或推理过程中积累的处理结果)"""
self._results_buffer = []
def get_results(self):
"""返回当前任务头积累的处理结果。"""
return copy.deepcopy(self._results_buffer)
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs=None, output_dir=None):
"""epoch级别的后处理。
每个训练或推理epoch结束后,对积累的各样本的后处理结果results进行后处理。默认情况下,当output_dir为None时,直接将results打印到
屏幕上。当指定output_dir时,将results存储在指定的文件夹内,并以任务头所处阶段来作为存储文件的文件名。
Args:
post_inputs: 当声明的epoch_inputs_attr不为空时,该参数会携带对应的输入变量的内容。
output_dir: 积累结果的保存路径。
"""
if output_dir is not None: if output_dir is not None:
for i in self._results_buffer: for i in self._results_buffer:
print(i) print(i)
...@@ -90,6 +130,3 @@ class Head(object): ...@@ -90,6 +130,3 @@ class Head(object):
with open(os.path.join(output_dir, self._phase), 'w') as writer: with open(os.path.join(output_dir, self._phase), 'w') as writer:
for i in self._results_buffer: for i in self._results_buffer:
writer.write(json.dumps(i)+'\n') writer.write(json.dumps(i)+'\n')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册