Commit c6e33be8, authored by xixiaoyao

fix bugs

Parent b44f5e58
@@ -154,7 +154,11 @@ class task_paradigm(object):
         raise NotImplementedError()
 
-    def build(self, inputs):
+    @property
+    def epoch_inputs_attrs(self):
+        return {}
+
+    def build(self, inputs, scope_name=""):
         """Build the computation graph of the task_layer: map the static-graph Variables from each object set, as described by inputs_attrs, to output static-graph Variables matching outputs_attr.
         Args:
             inputs: dict. Maps the object names in inputs_attrs to computation-graph Variables; inputs contains at least the objects defined in inputs_attr.
@@ -168,6 +172,6 @@ class task_paradigm(object):
         """Post-process the task_layer's runtime outputs for the current batch after each training or inference step. Note that besides the outputs of build, rt_outputs automatically contains the computed loss as well."""
         pass
 
-    def post_postprocess(self, global_buffer):
+    def epoch_postprocess(self, post_inputs):
         pass
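Under the updated interface, a concrete paradigm implements build with the new scope_name argument and may override the epoch_inputs_attrs / epoch_postprocess hooks. A minimal sketch, with empty bodies just to illustrate the signatures (the class name DummyParadigm is illustrative and not part of this commit):

    from paddlepalm.interface import task_paradigm

    class DummyParadigm(task_paradigm):
        """Illustrative subclass showing only the hooks touched by this commit."""

        @property
        def epoch_inputs_attrs(self):
            # Declare nothing extra; the controller then skips reader epoch outputs.
            return {}

        def build(self, inputs, scope_name=""):
            # Map input Variables to output Variables here (omitted in this sketch).
            return {}

        def postprocess(self, rt_outputs):
            # Per-batch handling of runtime outputs (omitted).
            pass

        def epoch_postprocess(self, post_inputs):
            # Per-epoch handling; post_inputs carries whatever epoch_inputs_attrs requested.
            pass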
@@ -182,21 +182,17 @@ def _fit_attr(conf, fit_attr, strict=False):
 class Controller(object):
 
-    def __init__(self, config=None, task_dir='.', for_train=True):
+    def __init__(self, config, task_dir='.', for_train=True):
         """
         Args:
             config: (str|dict) when given as a string, the path of a yaml-format config file;
         """
 
         self._for_train = for_train
-        # default mtl_conf
-        # if config is None and config_path is None:
-        #     raise ValueError('For config and config_path, at least one of them should be set.')
+        assert isinstance(config, str) or isinstance(config, dict), "a config dict or config file path is required to create a Controller."
         if isinstance(config, str):
             mtl_conf = _parse_yaml(config, support_cmd_line=True)
-        # if config is not None:
-        #     mtl_conf = _merge_conf(config, mtl_conf)
         else:
             mtl_conf = config
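After this change a Controller can no longer be built without a config; callers must pass either a yaml path or an already-parsed dict. A usage sketch, assuming Controller is importable from paddlepalm.mtl_controller and with a placeholder config path:

    from paddlepalm.mtl_controller import Controller  # adjust the import to your install

    # from a yaml config file (path is a placeholder)
    controller = Controller(config='demo/mtl_config.yaml', for_train=True)

    # or from an in-memory dict following the same schema as the yaml file
    # controller = Controller(config=parsed_conf_dict, for_train=True)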
@@ -518,6 +514,11 @@ class Controller(object):
     def _init_pred(self, instance, infer_model_path):
         inst = instance
+        if 'pred_output_path' not in inst.config:
+            inst.config['pred_output_path'] = os.path.join(inst.config.get('save_path', '.'), inst.name)
+
+        if not os.path.exists(inst.config['pred_output_path']):
+            os.makedirs(inst.config['pred_output_path'])
 
         pred_backbone = self.Backbone(self.bb_conf, phase='pred')
         pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=self.bb_conf)
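The fallback simply nests the task instance name under save_path, so with the illustrative values below the predictions land in output/senti_cls:

    import os

    config = {'save_path': 'output'}   # no pred_output_path set (illustrative)
    task_name = 'senti_cls'            # instance name (illustrative)
    pred_output_path = os.path.join(config.get('save_path', '.'), task_name)
    print(pred_output_path)            # 'output/senti_cls' on POSIX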
@@ -563,7 +564,12 @@ class Controller(object):
         finish = []
         for inst in instances:
             if inst.is_target:
-                finish.append(False)
+                if inst.expected_train_steps > 0:
+                    finish.append(False)
+                else:
+                    finish.append(True)
+                    print(inst.name+': train finished!')
+                    inst.save()
 
         def train_finish():
             for inst in instances:
@@ -641,9 +647,11 @@ class Controller(object):
         pred_prog = self._init_pred(instance, inference_model_dir)
 
         inst = instance
+        print(inst.name+": loading data...")
         inst.reader['pred'].load_data()
         fetch_names, fetch_vars = inst.pred_fetch_list
 
+        print('predicting...')
         mapper = {k:v for k,v in inst.pred_input}
         buf = []
         for feed in inst.reader['pred'].iterator():
@@ -653,12 +661,13 @@ class Controller(object):
             rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)
             rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
 
             inst.postprocess(rt_outputs, phase='pred')
-        reader_outputs = inst.reader['pred'].get_epoch_outputs()
+        if inst.task_layer['pred'].epoch_inputs_attrs:
+            reader_outputs = inst.reader['pred'].get_epoch_outputs()
+        else:
+            reader_outputs = None
+
         inst.epoch_postprocess({'reader':reader_outputs}, phase='pred')
 
 if __name__ == '__main__':
     assert len(sys.argv) == 2, "Usage: python mtl_controller.py <mtl_conf_path>"
     conf_path = sys.argv[1]
...
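For reference, the dict the controller now passes to inst.epoch_postprocess in the pred phase has one of two shapes; the 'examples'/'features' payload mirrors what the reader's get_epoch_outputs returned, and the values below are placeholders:

    # Task layer declares epoch_inputs_attrs: reader epoch outputs are fetched.
    post_inputs = {'reader': {'examples': [...], 'features': [...]}}

    # Task layer declares none: the reader epoch outputs are skipped entirely.
    post_inputs = {'reader': None}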
@@ -93,10 +93,6 @@ class Reader(reader):
             for batch in self._data_generator():
                 yield list_to_dict(batch)
 
-    def get_epoch_outputs(self):
-        return {'examples': self._reader.get_examples(self._phase),
-                'features': self._reader.get_features(self._phase)}
-
     @property
     def num_examples(self):
         return self._reader.get_num_examples(phase=self._phase)
...
@@ -14,8 +14,10 @@
 # limitations under the License.
 
 import paddle.fluid as fluid
-from paddlepalm.interface import task_paradigm
 from paddle.fluid import layers
+from paddlepalm.interface import task_paradigm
+import numpy as np
+import os
 
 class TaskParadigm(task_paradigm):
     '''
@@ -35,6 +37,8 @@ class TaskParadigm(task_paradigm):
             self._dropout_prob = config['dropout_prob']
         else:
             self._dropout_prob = backbone_config.get('hidden_dropout_prob', 0.0)
+        self._pred_output_path = config.get('pred_output_path', None)
+        self._preds = []
 
     @property
     def inputs_attrs(self):
@@ -78,3 +82,20 @@ class TaskParadigm(task_paradigm):
         else:
             return {"logits":logits}
 
+    def postprocess(self, rt_outputs):
+        if not self._is_training:
+            logits = rt_outputs['logits']
+            preds = np.argmax(logits, -1)
+            self._preds.extend(preds.tolist())
+
+    def epoch_postprocess(self, post_inputs):
+        # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
+        if not self._is_training:
+            if self._pred_output_path is None:
+                raise ValueError('argument pred_output_path not found in config. Please add it into config dict/file.')
+            with open(os.path.join(self._pred_output_path, 'predictions.json'), 'w') as writer:
+                for p in self._preds:
+                    writer.write(str(p)+'\n')
+            print('Predictions saved at '+os.path.join(self._pred_output_path, 'predictions.json'))
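The resulting predictions.json holds one integer label index per line, so it can be read back with a few lines of Python (the path below is a placeholder):

    import os

    pred_file = os.path.join('output/senti_cls', 'predictions.json')  # placeholder path
    with open(pred_file) as f:
        preds = [int(line.strip()) for line in f]
    print(preds[:10])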
@@ -14,8 +14,10 @@
 # limitations under the License.
 
 import paddle.fluid as fluid
-from paddlepalm.interface import task_paradigm
 from paddle.fluid import layers
+from paddlepalm.interface import task_paradigm
+import numpy as np
+import os
 
 class TaskParadigm(task_paradigm):
     '''
@@ -35,6 +37,9 @@ class TaskParadigm(task_paradigm):
         else:
             self._dropout_prob = backbone_config.get('hidden_dropout_prob', 0.0)
+        self._pred_output_path = config.get('pred_output_path', None)
+        self._preds = []
 
     @property
     def inputs_attrs(self):
@@ -50,7 +55,7 @@ class TaskParadigm(task_paradigm):
         if self._is_training:
             return {"loss": [[1], 'float32']}
         else:
-            return {"logits": [[-1, 1], 'float32']}
+            return {"logits": [[-1, 2], 'float32']}
 
     def build(self, inputs, scope_name=""):
         if self._is_training:
@@ -81,3 +86,20 @@ class TaskParadigm(task_paradigm):
         else:
             return {'logits': logits}
 
+    def postprocess(self, rt_outputs):
+        if not self._is_training:
+            logits = rt_outputs['logits']
+            preds = np.argmax(logits, -1)
+            self._preds.extend(preds.tolist())
+
+    def epoch_postprocess(self, post_inputs):
+        # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
+        if not self._is_training:
+            if self._pred_output_path is None:
+                raise ValueError('argument pred_output_path not found in config. Please add it into config dict/file.')
+            with open(os.path.join(self._pred_output_path, 'predictions.json'), 'w') as writer:
+                for p in self._preds:
+                    writer.write(str(p)+'\n')
+            print('Predictions saved at '+os.path.join(self._pred_output_path, 'predictions.json'))