未验证 提交 f6579ca0 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update base_head.py

上级 c00b77fe
......@@ -19,9 +19,10 @@ import json
class Head(object):
def __init__(self, phase='train'):
"""
config: dict类型。描述了 任务实例(task instance)+多任务配置文件 中定义超参数
phase: str类型。运行阶段,目前支持train和predict
"""该函数完成一个任务头的构造,至少需要包含一个phase参数。
注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。
Args:
phase: str类型。用于区分任务头被调用时所处的任务运行阶段,目前支持训练阶段train和预测阶段predict
"""
self._stop_gradient = {}
self._phase = phase
......@@ -30,17 +31,21 @@ class Head(object):
@property
def inputs_attrs(self):
"""描述task_layer需要从reader, backbone等输入对象集合所读取到的输入对象的属性,第一级key为对象集和的名字,如backbone,reader等(后续会支持更灵活的输入),第二级key为对象集和中各对象的属性,包括对象的名字,shape和dtype。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
"""描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象。使用字典进行描述,字典的key为输出对象所在的
组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为
输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维
度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。对各个对象集及其输入对象的属性描述。"""
dict类型。描述该任务头所依赖的来自各个组件的输出对象。"""
raise NotImplementedError()
@property
def outputs_attr(self):
"""描述task输出对象的属性,包括对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
"""描述该任务头的输出对象,包括每个输出对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个
训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方法中进行当前step的后处理。
当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return:
dict类型。对各个输入对象的属性描述。注意,训练阶段必须包含名为loss的输出对象。
dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。
"""
raise NotImplementedError()
......@@ -49,17 +54,6 @@ class Head(object):
def epoch_inputs_attrs(self):
return {}
# def stop_gradient(source, inputs):
# # if self._inputs is None:
# # raise Exception('You need to build this head first before stop gradient.')
# self._inputs = inputs
# for name, var in self._inputs[source].items():
# # cur_block = self._prog.current_block()
# var = fluid.layers.assign(var)
# var.stop_gradient = True
# self._inputs[name] = var
# return self._inputs
def build(self, inputs, scope_name=""):
"""建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。
Args:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册