diff --git a/paddlepalm/head/base_head.py b/paddlepalm/head/base_head.py index 9d4a614f21e0f8859df4c58691fabf31d5f6fa05..f71f491003b58a4a563babdd33931a1b069ce61a 100644 --- a/paddlepalm/head/base_head.py +++ b/paddlepalm/head/base_head.py @@ -15,6 +15,7 @@ import os import json +import copy class Head(object): @@ -31,19 +32,26 @@ class Head(object): @property def inputs_attrs(self): - """描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象。使用字典进行描述,字典的key为输出对象所在的 - 组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为 - 输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维 - 度长度可变时,shape中的相应维度设置为-1。 + """step级别的任务输入对象声明。 + + 描述该任务头所依赖的reader、backbone和来自其他任务头的输出对象(每个step获取一次)。使用字典进行描述, + 字典的key为输出对象所在的组件(如’reader‘,’backbone‘等),value为该组件下任务头所需要的输出对象集。 + 输出对象集使用字典描述,key为输出对象的名字(该名字需保证在相关组件的输出对象集中),value为该输出对象 + 的shape和dtype。当某个输出对象的某个维度长度可变时,shape中的相应维度设置为-1。 + Return: - dict类型。描述该任务头所依赖的来自各个组件的输出对象。""" + dict类型。描述该任务头所依赖的step级输入,即来自各个组件的输出对象。""" raise NotImplementedError() @property def outputs_attr(self): - """描述该任务头的输出对象,包括每个输出对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个 - 训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方法中进行当前step的后处理。 - 当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 + """step级别的任务输出对象声明。 + + 描述该任务头的输出对象(每个step输出一次),包括每个输出对象的名字,shape和dtype。输出对象会被加入到 + fetch_list中,从而在每个训练/推理step时得到实时的计算结果,该计算结果可以传入batch_postprocess方 + 法中进行当前step的后处理。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[], + 当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 + Return: dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。 """ @@ -52,21 +60,36 @@ class Head(object): @property def epoch_inputs_attrs(self): + """epoch级别的任务输入对象声明。 + + 描述该任务所依赖的来自reader、backbone和来自其他任务头的输出对象(每个epoch结束后产生一次),如完整的 + 样本集,有效的样本数等。使用字典进行描述,字典的key为输出对象所在的组件(如’reader‘,’backbone‘等), + value为该组件下任务头所需要的输出对象集。输出对象集使用字典描述,key为输出对象的名字(该名字需保证在相关 + 组件的输出对象集中),value为该输出对象的shape和dtype。当某个输出对象的某个维度长度可变时,shape中的相 + 应维度设置为-1。 + + Return: + dict类型。描述该任务头所产生的输出对象。注意,在训练阶段时必须包含名为loss的输出对象。 + """ return {} def build(self, inputs, scope_name=""): - """建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。 + """建立任务头的计算图。 + + 将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。 + Args: inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象 Return: 需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 - """ raise NotImplementedError() - def batch_postprocess(self, rt_outputs): - """每个训练或推理step后针对当前batch的task_layer的runtime计算结果进行相关后处理。注意,rt_outputs除了包含build方法,还自动包含了loss的计算结果。""" + """batch/step级别的后处理。 + + 每个训练或推理step后针对当前batch的任务头输出对象的实时计算结果来进行相关后处理。 + 默认将输出结果存储到缓冲区self._results_buffer中。""" if isinstance(rt_outputs, dict): keys = rt_outputs.keys() vals = [rt_outputs[k] for k in keys] @@ -79,8 +102,25 @@ class Head(object): print('WARNING: irregular output results. visualize failed.') self._results_buffer.append(rt_outputs) return None + + def reset(self): + """清空该任务头的缓冲区(在训练或推理过程中积累的处理结果)""" + self._results_buffer = [] + + def get_results(self): + """返回当前任务头积累的处理结果。""" + return copy.deepcopy(self._results_buffer) - def epoch_postprocess(self, post_inputs, output_dir=None): + def epoch_postprocess(self, post_inputs=None, output_dir=None): + """epoch级别的后处理。 + + 每个训练或推理epoch结束后,对积累的各样本的后处理结果results进行后处理。默认情况下,当output_dir为None时,直接将results打印到 + 屏幕上。当指定output_dir时,将results存储在指定的文件夹内,并以任务头所处阶段来作为存储文件的文件名。 + + Args: + post_inputs: 当声明的epoch_inputs_attr不为空时,该参数会携带对应的输入变量的内容。 + output_dir: 积累结果的保存路径。 + """ if output_dir is not None: for i in self._results_buffer: print(i) @@ -90,6 +130,3 @@ class Head(object): with open(os.path.join(output_dir, self._phase), 'w') as writer: for i in self._results_buffer: writer.write(json.dumps(i)+'\n') - - -