未验证 提交 f6579ca0 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update base_head.py

上级 c00b77fe
...@@ -19,9 +19,10 @@ import json ...@@ -19,9 +19,10 @@ import json
class Head(object): class Head(object):
def __init__(self, phase='train'): def __init__(self, phase='train'):
""" """该函数完成一个任务头的构造,至少需要包含一个phase参数。
config: dict类型。描述了 任务实例(task instance)+多任务配置文件 中定义超参数 注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。
phase: str类型。运行阶段,目前支持train和predict Args:
phase: str类型。用于区分任务头被调用时所处的任务运行阶段,目前支持训练阶段train和预测阶段predict
""" """
self._stop_gradient = {} self._stop_gradient = {}
self._phase = phase self._phase = phase
...@@ -30,17 +31,21 @@ class Head(object): ...@@ -30,17 +31,21 @@ class Head(object):
@property @property
def inputs_attrs(self): def inputs_attrs(self):
"""描述task_layer需要从reader, backbone等输入对象集合所读取到的输入对象的属性,第一级key为对象集和的名字,如backbone,reader等(后续会支持更灵活的输入),第二级key为对象集和中各对象的属性,包括对象的名字,shape和dtype。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 """描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象。使用字典进行描述,字典的key为输出对象所在的
组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为
输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维
度长度可变时,shape中的相应维度设置为-1。
Return: Return:
dict类型。对各个对象集及其输入对象的属性描述。""" dict类型。描述该任务头所依赖的来自各个组件的输出对象。"""
raise NotImplementedError() raise NotImplementedError()
@property @property
def outputs_attr(self): def outputs_attr(self):
"""描述task输出对象的属性,包括对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 """描述该任务头的输出对象,包括每个输出对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个
训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方法中进行当前step的后处理。
当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return: Return:
dict类型。对各个输入对象的属性描述。注意,训练阶段必须包含名为loss的输出对象。 dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。
""" """
raise NotImplementedError() raise NotImplementedError()
...@@ -49,17 +54,6 @@ class Head(object): ...@@ -49,17 +54,6 @@ class Head(object):
def epoch_inputs_attrs(self): def epoch_inputs_attrs(self):
return {} return {}
# def stop_gradient(source, inputs):
# # if self._inputs is None:
# # raise Exception('You need to build this head first before stop gradient.')
# self._inputs = inputs
# for name, var in self._inputs[source].items():
# # cur_block = self._prog.current_block()
# var = fluid.layers.assign(var)
# var.stop_gradient = True
# self._inputs[name] = var
# return self._inputs
def build(self, inputs, scope_name=""): def build(self, inputs, scope_name=""):
"""建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。 """建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。
Args: Args:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册