未验证 提交 1acb2cec 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #20 from xixiaoyao/master

fix bugs
......@@ -154,7 +154,11 @@ class task_paradigm(object):
raise NotImplementedError()
def build(self, inputs):
@property
def epoch_inputs_attrs(self):
return {}
def build(self, inputs, scope_name=""):
"""建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。
Args:
inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象
......@@ -168,6 +172,6 @@ class task_paradigm(object):
"""每个训练或推理step后针对当前batch的task_layer的runtime计算结果进行相关后处理。注意,rt_outputs除了包含build方法,还自动包含了loss的计算结果。"""
pass
def post_postprocess(self, global_buffer):
def epoch_postprocess(self, post_inputs):
pass
......@@ -182,21 +182,17 @@ def _fit_attr(conf, fit_attr, strict=False):
class Controller(object):
def __init__(self, config=None, task_dir='.', for_train=True):
def __init__(self, config, task_dir='.', for_train=True):
"""
Args:
config: (str|dict) 字符串类型时,给出yaml格式的config配置文件路径;
"""
self._for_train = for_train
# default mtl_conf
# if config is None and config_path is None:
# raise ValueError('For config and config_path, at least one of them should be set.')
assert isinstance(config, str) or isinstance(config, dict), "a config dict or config file path is required to create a Controller."
if isinstance(config, str):
mtl_conf = _parse_yaml(config, support_cmd_line=True)
# if config is not None:
# mtl_conf = _merge_conf(config, mtl_conf)
else:
mtl_conf = config
......@@ -518,6 +514,11 @@ class Controller(object):
def _init_pred(self, instance, infer_model_path):
inst = instance
if 'pred_output_path' not in inst.config:
inst.config['pred_output_path'] = os.path.join(inst.config.get('save_path', '.'), inst.name)
if not os.path.exists(inst.config['pred_output_path']):
os.makedirs(inst.config['pred_output_path'])
pred_backbone = self.Backbone(self.bb_conf, phase='pred')
pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=self.bb_conf)
......@@ -563,7 +564,12 @@ class Controller(object):
finish = []
for inst in instances:
if inst.is_target:
finish.append(False)
if inst.expected_train_steps > 0:
finish.append(False)
else:
finish.append(True)
print(inst.name+': train finished!')
inst.save()
def train_finish():
for inst in instances:
......@@ -641,9 +647,11 @@ class Controller(object):
pred_prog = self._init_pred(instance, inference_model_dir)
inst = instance
print(inst.name+": loading data...")
inst.reader['pred'].load_data()
fetch_names, fetch_vars = inst.pred_fetch_list
print('predicting...')
mapper = {k:v for k,v in inst.pred_input}
buf = []
for feed in inst.reader['pred'].iterator():
......@@ -653,12 +661,13 @@ class Controller(object):
rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)
rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
inst.postprocess(rt_outputs, phase='pred')
reader_outputs = inst.reader['pred'].get_epoch_outputs()
if inst.task_layer['pred'].epoch_inputs_attrs:
reader_outputs = inst.reader['pred'].get_epoch_outputs()
else:
reader_outputs = None
inst.epoch_postprocess({'reader':reader_outputs}, phase='pred')
if __name__ == '__main__':
assert len(sys.argv) == 2, "Usage: python mtl_controller.py <mtl_conf_path>"
conf_path = sys.argv[1]
......
......@@ -93,10 +93,6 @@ class Reader(reader):
for batch in self._data_generator():
yield list_to_dict(batch)
def get_epoch_outputs(self):
return {'examples': self._reader.get_examples(self._phase),
'features': self._reader.get_features(self._phase)}
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
......
......@@ -14,8 +14,10 @@
# limitations under the License.
import paddle.fluid as fluid
from paddlepalm.interface import task_paradigm
from paddle.fluid import layers
from paddlepalm.interface import task_paradigm
import numpy as np
import os
class TaskParadigm(task_paradigm):
'''
......@@ -35,6 +37,8 @@ class TaskParadigm(task_paradigm):
self._dropout_prob = config['dropout_prob']
else:
self._dropout_prob = backbone_config.get('hidden_dropout_prob', 0.0)
self._pred_output_path = config.get('pred_output_path', None)
self._preds = []
@property
def inputs_attrs(self):
......@@ -78,3 +82,20 @@ class TaskParadigm(task_paradigm):
else:
return {"logits":logits}
def postprocess(self, rt_outputs):
if not self._is_training:
logits = rt_outputs['logits']
preds = np.argmax(logits, -1)
self._preds.extend(preds.tolist())
def epoch_postprocess(self, post_inputs):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training:
if self._pred_output_path is None:
raise ValueError('argument pred_output_path not found in config. Please add it into config dict/file.')
with open(os.path.join(self._pred_output_path, 'predictions.json'), 'w') as writer:
for p in self._preds:
writer.write(str(p)+'\n')
print('Predictions saved at '+os.path.join(self._pred_output_path, 'predictions.json'))
......@@ -14,8 +14,10 @@
# limitations under the License.
import paddle.fluid as fluid
from paddlepalm.interface import task_paradigm
from paddle.fluid import layers
from paddlepalm.interface import task_paradigm
import numpy as np
import os
class TaskParadigm(task_paradigm):
'''
......@@ -35,6 +37,9 @@ class TaskParadigm(task_paradigm):
else:
self._dropout_prob = backbone_config.get('hidden_dropout_prob', 0.0)
self._pred_output_path = config.get('pred_output_path', None)
self._preds = []
@property
def inputs_attrs(self):
......@@ -50,7 +55,7 @@ class TaskParadigm(task_paradigm):
if self._is_training:
return {"loss": [[1], 'float32']}
else:
return {"logits": [[-1, 1], 'float32']}
return {"logits": [[-1, 2], 'float32']}
def build(self, inputs, scope_name=""):
if self._is_training:
......@@ -81,3 +86,20 @@ class TaskParadigm(task_paradigm):
else:
return {'logits': logits}
def postprocess(self, rt_outputs):
if not self._is_training:
logits = rt_outputs['logits']
preds = np.argmax(logits, -1)
self._preds.extend(preds.tolist())
def epoch_postprocess(self, post_inputs):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training:
if self._pred_output_path is None:
raise ValueError('argument pred_output_path not found in config. Please add it into config dict/file.')
with open(os.path.join(self._pred_output_path, 'predictions.json'), 'w') as writer:
for p in self._preds:
writer.write(str(p)+'\n')
print('Predictions saved at '+os.path.join(self._pred_output_path, 'predictions.json'))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册