From f6579ca0b24a612c4b0a137b188ebcf7590973e0 Mon Sep 17 00:00:00 2001 From: Xiaoyao Xi <24541791+xixiaoyao@users.noreply.github.com> Date: Fri, 27 Mar 2020 20:21:29 +0800 Subject: [PATCH] Update base_head.py --- paddlepalm/head/base_head.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/paddlepalm/head/base_head.py b/paddlepalm/head/base_head.py index 2446885..9d4a614 100644 --- a/paddlepalm/head/base_head.py +++ b/paddlepalm/head/base_head.py @@ -19,9 +19,10 @@ import json class Head(object): def __init__(self, phase='train'): - """ - config: dict类型。描述了 任务实例(task instance)+多任务配置文件 中定义超参数 - phase: str类型。运行阶段,目前支持train和predict + """该函数完成一个任务头的构造,至少需要包含一个phase参数。 + 注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。 + Args: + phase: str类型。用于区分任务头被调用时所处的任务运行阶段,目前支持训练阶段train和预测阶段predict """ self._stop_gradient = {} self._phase = phase @@ -30,17 +31,21 @@ class Head(object): @property def inputs_attrs(self): - """描述task_layer需要从reader, backbone等输入对象集合所读取到的输入对象的属性,第一级key为对象集和的名字,如backbone,reader等(后续会支持更灵活的输入),第二级key为对象集和中各对象的属性,包括对象的名字,shape和dtype。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 + """描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象。使用字典进行描述,字典的key为输出对象所在的 + 组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为 + 输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维 + 度长度可变时,shape中的相应维度设置为-1。 Return: - dict类型。对各个对象集及其输入对象的属性描述。""" + dict类型。描述该任务头所依赖的来自各个组件的输出对象。""" raise NotImplementedError() @property def outputs_attr(self): - """描述task输出对象的属性,包括对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 + """描述该任务头的输出对象,包括每个输出对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个 + 训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方法中进行当前step的后处理。 当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 Return: - dict类型。对各个输入对象的属性描述。注意,训练阶段必须包含名为loss的输出对象。 + dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。 """ raise NotImplementedError() @@ -49,17 +54,6 @@ class Head(object): def epoch_inputs_attrs(self): return {} - # def stop_gradient(source, inputs): - # # if self._inputs is None: - # # raise Exception('You need to build this head first before stop gradient.') - # self._inputs = inputs - # for name, var in self._inputs[source].items(): - # # cur_block = self._prog.current_block() - # var = fluid.layers.assign(var) - # var.stop_gradient = True - # self._inputs[name] = var - # return self._inputs - def build(self, inputs, scope_name=""): """建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。 Args: -- GitLab