diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py
new file mode 100755
index 0000000000000000000000000000000000000000..30fac4dd7802d091e1ccb37f92f6477de28d2144
--- /dev/null
+++ b/paddlepalm/mtl_controller.py
@@ -0,0 +1,736 @@
+# -*- coding: UTF-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import os
+import sys
+import importlib
+import multiprocessing
+from paddle import fluid
+from paddle.fluid import layers
+import yaml
+import json
+import logging
+import time
+import numpy as np
+
+from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint
+from paddlepalm.utils.config_helper import PDConfig
+from paddlepalm.utils.print_helper import print_dict
+from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs
+from paddlepalm.distribute import data_feeder
+
+from default_settings import *
+from task_instance import TaskInstance, check_instances
+
+
+DEBUG = False
+VERBOSE = 0
+
+
+def _get_basename(f):
+    return os.path.splitext(f)[0]
+
+
+def _get_suffix(f):
+    return os.path.splitext(f)[-1]
+
+
+def _parse_yaml(f, asdict=True, support_cmd_line=False):
+    assert os.path.exists(f), "file {} not found.".format(f)
+    if support_cmd_line:
+        args = PDConfig(yaml_file=f, fuse_args=True)
+        args.build()
+        return args.asdict() if asdict else args
+    else:
+        if asdict:
+            with open(f, "r") as fin:
+                yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
+            return yaml_config
+        else:
+            raise NotImplementedError()
+
+
+def _parse_json(f, asdict=True, support_cmd_line=False):
+    assert os.path.exists(f), "file {} not found.".format(f)
+    if support_cmd_line:
+        args = PDConfig(json_file=f, fuse_args=support_cmd_line)
+        args.build()
+        return args.asdict() if asdict else args
+    else:
+        if asdict:
+            with open(f, "r") as fin:
+                config = json.load(fin)
+            return config
+        else:
+            raise NotImplementedError()
+
+
+def _parse_list(string, astype=str):
+    assert isinstance(string, str), "{} is not a string.".format(string)
+    if ',' not in string:
+        return [astype(string)]
+    string = string.replace(',', ' ')
+    return [astype(i) for i in string.split()]
+
+
+def _try_float(s):
+    try:
+        return float(s)
+    except (TypeError, ValueError):
+        return s
+
+
+def _check_conf(conf, checklist=None):
+    assert isinstance(conf, dict), "{} is not a dict.".format(conf)
+    ret = {}
+    for k, v in conf.items():
+        if isinstance(v, str):
+            v = _try_float(v)
+        ret[k] = v
+    if checklist is not None:
+        for k, t in checklist:
+            assert k in ret, "required argument {} does NOT exist in config file.".format(k)
+            assert isinstance(ret[k], t), "value type of argument {} should be {}".format(k, t)
+    return ret
+
+
+# TODO: add a None mechanism that allows hidden_size, batch_size and seqlen to be set to None
+def _check_io(in_attr, out_attr, strict=False, in_name="left", out_name="right"):
+    for name, attr in in_attr.items():
+        assert name in out_attr, in_name + ': ' + name + ' not found in ' + out_name
+        if attr != out_attr[name]:
+            if strict:
+                raise ValueError(name + ': shape or dtype not consistent!')
+            else:
+                logging.warning('{}: shape or dtype not consistent!\n{}:\n{}\n{}:\n{}'.format(name, in_name, attr, out_name, out_attr[name]))
+
+
+def _merge_conf(conf1, conf2, conf1_first=True, strict=False):
+    assert isinstance(conf1, dict), "{} is not a dict.".format(conf1)
+    assert isinstance(conf2, dict), "{} is not a dict.".format(conf2)
+    base_conf = conf2 if conf1_first else conf1
+    base_conf = base_conf.copy()
+    new_conf = conf1 if conf1_first else conf2
+
+    for k, v in new_conf.items():
+        if k in base_conf:
+            if base_conf[k] != v:
+                # warn (instead of raising) so that the override below is actually applied
+                logging.warning("value of argument {} has been updated to {}.".format(k, v))
+        else:
+            if strict:
+                continue
+
+        base_conf[k] = v
+    return base_conf
+
+
+def _encode_inputs(inputs, scope_name, sep='/', cand_set=None):
+    outputs = {}
+    for k, v in inputs.items():
+        if cand_set is not None:
+            if k in cand_set:
+                outputs[k] = v
+            if scope_name + sep + k in cand_set:
+                outputs[scope_name + sep + k] = v
+        else:
+            outputs[scope_name + sep + k] = v
+    return outputs
+
+
+def _decode_inputs(inputs, scope_name, sep='/', keep_unk_keys=True):
+    outputs = {}
+    for name, value in inputs.items():
+        # backbone-level vars are also made available to tasks
+        if keep_unk_keys and sep not in name:
+            outputs[name] = value
+        # vars owned by this instance
+        if name.startswith(scope_name + '/'):
+            outputs[name[len(scope_name + '/'):]] = value
+    return outputs
+
+
+def _init_env(use_gpu):
+    if use_gpu:
+        place = fluid.CUDAPlace(0)
+        dev_count = fluid.core.get_cuda_device_count()
+    else:
+        place = fluid.CPUPlace()
+        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+    return fluid.Executor(place), dev_count
+
+
+def _fit_attr(conf, fit_attr, strict=False):
+    for i, attr in fit_attr.items():
+        if i not in conf:
+            if strict:
+                raise Exception('Argument {} is required to create a controller.'.format(i))
+            else:
+                continue
+        conf[i] = attr(conf[i])
+    return conf
+
+
+# `unicode` only exists on Python 2; build a string-type tuple that works on both.
+try:
+    _string_types = (str, unicode)
+except NameError:
+    _string_types = (str,)
+
+
+def create_feed_batch_process_fn(net_inputs):
+
+    def feed_batch_process_fn(data):
+        temp = {}
+        # net_inputs maps reader fields to either variable names or Variables
+        for q, var in net_inputs.items():
+            if isinstance(var, _string_types):
+                temp[var] = data[q]
+            else:
+                temp[var.name] = data[q]
+        return temp
+
+    return feed_batch_process_fn
+
+
+class Controller(object):
+
+    def __init__(self, config, task_dir='.', for_train=True):
+        """
+        Args:
+            config: (str|dict) when str, the path of a yaml-format config file; when dict, the config itself.
+        """
+
+        self._for_train = for_train
+        assert isinstance(config, str) or isinstance(config, dict), "a config dict or config file path is required to create a Controller."
+
+        if isinstance(config, str):
+            mtl_conf = _parse_yaml(config, support_cmd_line=True)
+        else:
+            mtl_conf = config
+
+        mtl_conf = _check_conf(mtl_conf)
+        mtl_conf = _fit_attr(mtl_conf, REQUIRED_ARGS, strict=True)
+        mtl_conf = _fit_attr(mtl_conf, OPTIONAL_ARGS, strict=False)
+        self.mtl_conf = mtl_conf  # pred() reads this to locate saved inference models
+
+        exe, dev_count = _init_env(use_gpu=mtl_conf.get('use_gpu', True))
+        self.exe = exe
+        self.dev_count = dev_count
+
+        print_dict(mtl_conf, title='global configuration')
+
+        # parse task instances and target tags
+        instnames = _parse_list(mtl_conf['task_instance'])
+        assert len(instnames) == len(set(instnames)), "repeated task_instance is NOT supported."
+        num_instances = len(instnames)
+        self.num_instances = num_instances
+
+        instname_to_conf = {}
+        instname_to_id = {}
+        for id, instname in enumerate(instnames):
+            instpath = os.path.join(task_dir, instname + '.yaml')
+            conf = _parse_yaml(instpath, support_cmd_line=False)
+            # conf = _check_conf(conf, TASK_INSTANCE_REQUIRED_ARGS)
+            conf = _check_conf(conf)
+            temp_conf = _merge_conf(mtl_conf, conf, strict=True)
+            print_dict(temp_conf, title='{} configuration'.format(instname))
+            conf = _merge_conf(mtl_conf, conf)
+
+            instname_to_conf[instname] = conf
+            instname_to_id[instname] = id
+
+        # prepare backbone
+        if 'backbone_config_path' in mtl_conf:
+            bb_conf = _parse_json(mtl_conf['backbone_config_path'])
+            bb_conf = _merge_conf(mtl_conf, bb_conf)
+        else:
+            bb_conf = mtl_conf
+        print_dict(bb_conf, title='backbone configuration')
+
+        bb_name = mtl_conf['backbone']
+        bb_mod = importlib.import_module(BACKBONE_DIR + '.' + bb_name)
+        Backbone = getattr(bb_mod, 'Model')
+
+        # create task instances
+        instances = []
+        for name in instnames:
+            instances.append(TaskInstance(name, instname_to_id[name], instname_to_conf[name]))
+
+        check_instances(instances)
+
+        # parse target_tag
+        if 'target_tag' in mtl_conf:
+            target_tag = str(mtl_conf['target_tag'])
+            tags = _parse_list(target_tag, astype=int)
+            assert len(tags) == len(instnames), "number of target_tag is NOT consistent with that in task_instance."
+            for tag, inst in zip(tags, instances):
+                inst.is_target = tag
+        else:
+            tags = [i.is_target for i in instances]
+        num_targets = sum(tags)
+        num_auxes = num_instances - num_targets
+
+        # parse mix ratios
+        if 'mix_ratio' in mtl_conf:
+            mix_ratio = str(mtl_conf['mix_ratio'])
+            mrs = _parse_list(mix_ratio, astype=float)
+            assert len(mrs) == num_instances, "number of mix_ratios is NOT consistent with num_instances."
+        else:
+            mrs = [1.0] * num_instances
+
+        for mr, inst in zip(mrs, instances):
+            inst.mix_ratio = mr
+
+        # parse task layer reuse tags
+        instname_to_reusehost = {i: i for i in instnames}
+        if 'task_reuse_tag' in mtl_conf:
+            tags = _parse_list(mtl_conf['task_reuse_tag'], astype=int)
+            assert len(tags) == num_targets, 'number of task_reuse_tags is NOT consistent with the number of target tasks.'
+        else:
+            tags = []
+            mapper = {}
+            # map names to instances so that reuse-scope chains can be followed
+            name_to_instance = {i.name: i for i in instances}
+            for inst in instances:
+                history = set()
+                history.add(inst.name)
+                cur_inst = inst
+                while True:
+                    if cur_inst.task_reuse_scope in history:
+                        mapper[inst.name] = len(tags)
+                        break
+                    elif cur_inst.task_reuse_scope in mapper:
+                        mapper[inst.name] = mapper[cur_inst.task_reuse_scope]
+                        break
+                    else:
+                        cur_inst = name_to_instance[cur_inst.task_reuse_scope]
+                        history.add(cur_inst.name)
+
+                tags.append(mapper[inst.name])
+
+        for i in range(1, num_instances):
+            for j in range(i):
+                if tags[i] == tags[j]:
+                    assert instances[i].Paradigm == \
+                        instances[j].Paradigm, \
+                        "paradigm of reuse tasks should be consistent"
+                    instances[i].task_reuse_scope = instances[j].name
+                    break
+
+        self.instances = instances
+        self.mrs = mrs
+        self.Backbone = Backbone
+        self.bb_conf = bb_conf
+        self.bb_name = bb_name
+
+        self.has_init_train = False
+        self.has_init_pred = False
+
+        if self._for_train:
+            print("initializing for training...")
+            self._init_train()
+            self.has_init_train = True
+
+    def _init_train(self):
+
+        instances = self.instances
+        Backbone = self.Backbone
+        bb_conf = self.bb_conf
+        bb_name = self.bb_name
+        dev_count = self.dev_count
+        num_instances = len(instances)
+        mrs = self.mrs
+
+        # set first_target/main task instance
+        main_inst = None
+        for inst in instances:
+            if inst.is_target:
+                main_inst = inst
+                inst.is_first_target = True
+                break
+        assert main_inst is not None, "at least one task instance should be set as target."
+        main_conf = main_inst.config
+        if not os.path.exists(main_conf['save_path']):
+            os.makedirs(main_conf['save_path'])
+            os.makedirs(os.path.join(main_conf['save_path'], 'ckpt'))
+
+        # prepare backbone
+        train_backbone = Backbone(bb_conf, phase='train')
+        pred_backbone = Backbone(bb_conf, phase='pred')
+
+        # create reader, task
+        # then check i/o across reader, backbone and task_layer
+        task_attrs = []
+        pred_task_attrs = []
+        for inst in instances:
+            train_reader = inst.Reader(inst.config, phase='train')
+            inst.reader['train'] = train_reader
+            train_parad = inst.Paradigm(inst.config, phase='train', backbone_config=bb_conf)
+            inst.task_layer['train'] = train_parad
+            task_attr_from_reader = _encode_inputs(train_parad.inputs_attrs['reader'], inst.name)
+            task_attrs.append(task_attr_from_reader)
+
+            _check_io(train_backbone.inputs_attr, train_reader.outputs_attr, in_name=bb_name + '_backbone', out_name='reader.train')
+            _check_io(train_parad.inputs_attrs['reader'], train_reader.outputs_attr, in_name='task_paradigm.train.reader', out_name='reader.train')
+            _check_io(train_parad.inputs_attrs['backbone'], train_backbone.outputs_attr, in_name='task_paradigm.train.backbone', out_name=bb_name + '_backbone')
+
+            if inst.is_target:
+                if 'pred_file' not in inst.config:
+                    inst.config['pred_file'] = ''
+                pred_reader = inst.Reader(inst.config, phase='pred')
+                pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=bb_conf)
+                inst.task_layer['pred'] = pred_parad
+                task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], inst.name)
+                pred_task_attrs.append(task_attr_from_reader)
+                _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name + '_backbone', out_name='reader.pred')
+                _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred')
+                _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name + '_backbone')
+
+        # merge reader input attrs from backbone and task_instances
+        joint_input_names, joint_shape_and_dtypes, name_to_position = merge_input_attrs(train_backbone.inputs_attr, task_attrs)
+        pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
+        # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN]
+
+        if DEBUG:
+            print('----- for debug -----')
+            print('joint input names:')
+            print(joint_input_names)
+            print('joint input shape and dtypes:')
+            print(joint_shape_and_dtypes)
+
+        # load data
+        for inst in instances:
+            print(inst.name + ": preparing data...", end='')
+            inst.reader['train'].load_data()
+            print('ok!')
+
+        # merge dataset iterators and create net input vars
+        iterators = []
+        prefixes = []
+        mrs = []
+        for inst in instances:
+            iterators.append(inst.reader['train'].iterator())
+            prefixes.append(inst.name)
+            mrs.append(inst.mix_ratio)
+
+        joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict')
+        self._joint_iterator_fn = joint_iterator_fn
+
+        input_attrs = [[i, j, k] for i, (j, k) in zip(joint_input_names, joint_shape_and_dtypes)]
+        pred_input_attrs = [[i, j, k] for i, (j, k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)]
+        # net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3)
+        net_inputs = create_net_inputs(input_attrs, async=False)
+        self._net_inputs = net_inputs
+
+        # build backbone and task layers
+        train_prog = fluid.default_main_program()
+        train_init_prog = fluid.default_startup_program()
+        bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_')
+        assert sorted(bb_output_vars.keys()) == sorted(train_backbone.outputs_attr.keys())
+
+        pred_prog = fluid.Program()
+        pred_init_prog = fluid.Program()
+
+        with fluid.program_guard(main_program=pred_prog, startup_program=pred_init_prog):
+            pred_net_inputs = create_net_inputs(pred_input_attrs)
+            pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_')
+
+        fluid.framework.switch_main_program(train_prog)
+        fluid.framework.switch_startup_program(train_init_prog)
+
+        task_output_vars = {}
+        for inst in instances:
+            task_inputs = {'backbone': bb_output_vars}
+            task_inputs_from_reader = _decode_inputs(net_inputs, inst.name)
+            task_inputs['reader'] = task_inputs_from_reader
+
+            scope = inst.task_reuse_scope + '/'
+            with fluid.unique_name.guard(scope):
+                output_vars = inst.build_task_layer(task_inputs, phase='train', scope=scope)
+                output_vars = {inst.name + '/' + key: val for key, val in output_vars.items()}
+                old = len(task_output_vars)  # for debug
+                task_output_vars.update(output_vars)
+                assert len(task_output_vars) - old == len(output_vars)  # for debug
+
+            # prepare predict vars for saving inference model
+            if inst.is_target:
+
+                with fluid.program_guard(pred_prog, pred_init_prog):
+                    cur_inputs = _decode_inputs(pred_net_inputs, inst.name)
+                    inst.pred_input = cur_inputs
+                    pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
+                    scope = inst.task_reuse_scope + '/'
+                    with fluid.unique_name.guard(scope):
+                        inst.build_task_layer(pred_task_inputs, phase='pred', scope=scope)
+
+        bb_fetches = {k: v.name for k, v in bb_output_vars.items()}
+        task_fetches = {k: v.name for k, v in task_output_vars.items()}
+        fetches = task_fetches
+        fetches['__task_id'] = net_inputs['__task_id'].name
+
+        # compute loss
+        task_id_var = net_inputs['__task_id']
+        task_id_vec = fluid.one_hot(task_id_var, num_instances)
+        losses = fluid.layers.concat([task_output_vars[inst.name + '/loss'] for inst in instances], axis=0)
+        loss = layers.reduce_sum(task_id_vec * losses)
+
+        main_reader = main_inst.reader['train']
+
+        num_examples = main_reader.num_examples
+        for inst in instances:
+            max_train_steps = int(main_conf['num_epochs'] * inst.mix_ratio * (num_examples // main_conf['batch_size'] // dev_count))
+            if inst.is_target:
+                print('{}: expected train steps {}.'.format(inst.name, max_train_steps))
+            inst.steps_pur_epoch = inst.reader['train'].num_examples // main_conf['batch_size'] // dev_count
+            inst.expected_train_steps = max_train_steps
+
+        global_max_train_steps = int(main_conf['num_epochs'] * sum(mrs) * (num_examples // main_conf['batch_size'] // dev_count))
+        print('Estimated overall train steps {}.'.format(global_max_train_steps))
+
+        if 'warmup_proportion' in main_conf and main_conf['warmup_proportion'] > 0:
+            warmup_steps = int(global_max_train_steps * main_conf['warmup_proportion'])
+            print('Warmup steps: ' + str(warmup_steps))
+        else:
+            warmup_steps = 0
+
+        # build optimizer
+        if 'optimizer' in main_conf:
+            optim_mod = importlib.import_module(OPTIMIZER_DIR + '.' + main_conf['optimizer'])
+            optimize = getattr(optim_mod, OPTIMIZE_METHOD)
+            optimize(loss, main_conf, max_train_steps, warmup_steps, fluid.default_main_program())
+
+            loss.persistable = True
+            if main_conf.get('use_ema', False):
+                assert 'ema_decay' in main_conf, "ema_decay should be set when use_ema is enabled."
+                ema = fluid.optimizer.ExponentialMovingAverage(main_conf['ema_decay'])
+                ema.update()
+
+        # prepare for train
+        self.train_backbone = train_backbone
+        self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
+        self.saver_program = fluid.default_main_program()
+
+        self.main_inst = main_inst
+        self.main_conf = main_conf  # read by load_pretrain() and the __main__ entry
+        self.fetches = fetches
+        self.has_init_train = True
+        self.has_init_pred = True
+
+        self.exe.run(fluid.default_startup_program())
+        print("\nRandomly initialized parameters.\n")
+
+    def _init_pred(self, instance, infer_model_path):
+        inst = instance
+        if 'pred_output_path' not in inst.config:
+            inst.config['pred_output_path'] = os.path.join(inst.config.get('save_path', '.'), inst.name)
+
+        if not os.path.exists(inst.config['pred_output_path']):
+            os.makedirs(inst.config['pred_output_path'])
+
+        pred_backbone = self.Backbone(self.bb_conf, phase='pred')
+        pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=self.bb_conf)
+        inst.task_layer['pred'] = pred_parad
+        pred_joint_input_names, pred_joint_shape_and_dtypes, name_to_position = merge_input_attrs(
+            pred_backbone.inputs_attr, inst.task_layer['pred'].inputs_attrs['reader'],
+            insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
+
+        pred_prog = inst.load(infer_model_path)
+        pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
+        if inst.reader['pred'] is None:
+            pred_reader = inst.Reader(inst.config, phase='pred')
+            inst.reader['pred'] = pred_reader
+        return pred_prog
+
+    def load_pretrain(self, pretrain_path=None):
+        # load pretrain model (or ckpt)
+        if pretrain_path is None:
+            assert 'pretrain_path' in self.main_conf, "pretrain_path NOT set."
+            pretrain_path = self.main_conf['pretrain_path']
+
+        init_pretraining_params(
+            self.exe,
+            pretrain_path,
+            main_program=fluid.default_startup_program())
+
+    def train(self):
+
+        if not self.has_init_train:
+            self._init_train()
+            self.has_init_train = True
+
+        instances = self.instances
+        num_instances = self.num_instances
+        main_inst = self.main_inst
+        main_conf = main_inst.config
+
+        backbone = self.train_backbone
+        train_program = self.train_program
+        saver_program = self.saver_program
+        fetches = self.fetches
+
+        finish = []
+        for inst in instances:
+            if inst.is_target:
+                if inst.expected_train_steps > 0:
+                    finish.append(False)
+                else:
+                    finish.append(True)
+                    print(inst.name + ': train finished!')
+                    inst.save()
+
+        def train_finish():
+            for inst in instances:
+                if inst.is_target:
+                    if not inst.train_finish:
+                        return False
+            return True
+
+        # do training
+        fetch_names, fetch_list = zip(*fetches.items())
+
+        main_step = 0  # only count for main task
+        global_step = 0  # count for all tasks
+        epoch = 0
+        time_begin = time.time()
+        backbone_buffer = []
+
+        feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs)
+        distribute_feeder = data_feeder(self._joint_iterator_fn, feed_batch_process_fn)
+
+        # palm.distribute.reader(self._joint_iterator_fn, self._net_inputs, prefetch_steps=2)
+
+        while not train_finish():
+            feed, mask = next(distribute_feeder)
+            rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list)
+            # drop the outputs of padded (masked-out) device batches
+            while not mask.pop():
+                rt_outputs.pop()
+
+            rt_outputs = {k: v for k, v in zip(fetch_names, rt_outputs)}
+            rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist()
+            rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id
+            cur_task = instances[rt_task_id]
+
+            backbone_rt_outputs = {k: v for k, v in rt_outputs.items() if '/' not in k}
+            backbone_buffer.append(backbone.postprocess(backbone_rt_outputs))
+
+            task_rt_outputs = {k[len(cur_task.name + '/'):]: v for k, v in rt_outputs.items() if k.startswith(cur_task.name + '/')}
+            instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs)
+
+            global_step += 1
+            cur_task.cur_train_step += 1
+
+            cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch
+            if cur_task.is_target and cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0:
+                cur_task.save(suffix='.step' + str(cur_task_global_step))
+
+            if global_step % main_conf.get('print_every_n_steps', 5) == 0:
+                loss = rt_outputs[cur_task.name + '/loss']
+                loss = np.mean(np.squeeze(loss)).tolist()
+
+                time_end = time.time()
+                time_cost = time_end - time_begin
+                print("Global step: {}. Task: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format(
+                    global_step, cur_task.name, cur_task.cur_train_step, cur_task.steps_pur_epoch, cur_task.cur_train_epoch,
+                    loss, main_conf.get('print_every_n_steps', 5) / time_cost))
+                time_begin = time.time()
+
+            if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
+                print(cur_task.name + ': train finished!')
+                cur_task.save()
+
+            if 'save_ckpt_every_n_steps' in main_conf and global_step % main_conf['save_ckpt_every_n_steps'] == 0:
+                save_path = os.path.join(main_conf['save_path'], 'ckpt',
+                                         "step_" + str(global_step))
+                fluid.io.save_persistables(self.exe, save_path, saver_program)
+                print('checkpoint has been saved at ' + save_path)
+
+        save_path = os.path.join(main_conf['save_path'], 'ckpt',
+                                 "step_" + str(global_step))
+        fluid.io.save_persistables(self.exe, save_path, saver_program)
+        print('checkpoint has been saved at ' + save_path)
+
+        print("ALL tasks train finished, exiting...")
+
+    def pred(self, task_instance, inference_model_dir=None):
+        if self._for_train:
+            raise Exception('This controller is a trainer. Please build a new controller with for_train=False for predicting.')
+
+        assert isinstance(task_instance, str)
+        if isinstance(inference_model_dir, str):
+            assert os.path.exists(inference_model_dir), inference_model_dir + " not found."
+        # if not self.has_init_pred and inference_model_dir is None:
+        #     raise ValueError('infer_model_path is required for prediction.')
+        if inference_model_dir is None:
+            assert 'save_path' in self.mtl_conf, "one of `inference_model_dir` and `save_path` should be set to load the inference model."
+            inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model')
+
+        instance = None
+        for inst in self.instances:
+            if inst.name == task_instance:
+                instance = inst
+                break
+
+        if instance is None:
+            raise ValueError(task_instance + ' is not a valid task_instance.')
+
+        pred_prog = self._init_pred(instance, inference_model_dir)
+
+        inst = instance
+        print(inst.name + ": loading data...")
+        inst.reader['pred'].load_data()
+        fetch_names, fetch_vars = inst.pred_fetch_list
+
+        print('predicting...')
+        feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input)
+        distribute_feeder = data_feeder(inst.reader['pred'].iterator, feed_batch_process_fn, prefetch_steps=1)
+
+        for feed, mask in distribute_feeder:
+            rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)
+            # split each fetched tensor into per-device chunks
+            splited_rt_outputs = []
+            for item in rt_outputs:
+                splited_rt_outputs.append(np.split(item, len(mask)))
+
+            # assert len(rt_outputs) == len(mask), [len(rt_outputs), len(mask)]
+            # drop chunks that correspond to padded (masked-out) device batches
+            while not mask.pop():
+                for item in splited_rt_outputs:
+                    item.pop()
+            # concat the remaining per-device chunks back into whole-batch outputs
+            rt_outputs = []
+            for item in splited_rt_outputs:
+                rt_outputs.append(np.concatenate(item))
+
+            rt_outputs = {k: v for k, v in zip(fetch_names, rt_outputs)}
+            inst.postprocess(rt_outputs, phase='pred')
+        if inst.task_layer['pred'].epoch_inputs_attrs:
+            reader_outputs = inst.reader['pred'].get_epoch_outputs()
+        else:
+            reader_outputs = None
+        inst.epoch_postprocess({'reader': reader_outputs}, phase='pred')
+
+
+if __name__ == '__main__':
+    assert len(sys.argv) == 2, "Usage: python mtl_controller.py <config_path>"
+    conf_path = sys.argv[1]
+    del sys.argv[1]
+    controller = Controller(conf_path)
+    if controller.main_conf['do_train']:
+        controller.train()
+
+
+__all__ = ["Controller"]
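
Reviewer note: a minimal end-to-end usage sketch of the Controller added above. The file names and paths ('config.yaml', 'tasks', 'pretrain/ernie/params', 'output/mrqa/infer_model') and the 'mrqa' task name are illustrative assumptions, not part of this diff; the sketch also assumes the package exports Controller as paddlepalm.Controller.

    # usage sketch with made-up paths; mirrors the train/pred flow defined above
    import paddlepalm as palm

    # training: reads the global config plus one <task_name>.yaml per entry in task_instance
    controller = palm.Controller('config.yaml', task_dir='tasks', for_train=True)
    controller.load_pretrain('pretrain/ernie/params')
    controller.train()

    # prediction requires a fresh controller built with for_train=False
    controller = palm.Controller('config.yaml', task_dir='tasks', for_train=False)
    controller.pred('mrqa', inference_model_dir='output/mrqa/infer_model')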
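
Reviewer note: the `<task_name>/<field>` naming scheme implemented by _encode_inputs/_decode_inputs is easy to miss in review. The standalone snippet below (plain dicts, no Paddle, invented field names) mirrors the decode path with keep_unk_keys=True, under which unprefixed backbone fields stay visible to every task.

    # illustration only; mirrors _decode_inputs for scope_name='mrqa'
    inputs = {
        'token_ids': 'backbone-level var',         # no '/', shared by all tasks
        'mrqa/start_positions': 'task-level var',  # owned by instance "mrqa"
        'mlm4mrqa/mask_pos': 'task-level var',     # owned by another instance
    }

    scope, sep = 'mrqa', '/'
    decoded = {}
    for name, value in inputs.items():
        if sep not in name:               # keep_unk_keys=True: pass backbone vars through
            decoded[name] = value
        if name.startswith(scope + sep):  # strip this instance's prefix
            decoded[name[len(scope + sep):]] = value

    print(decoded)
    # {'token_ids': 'backbone-level var', 'start_positions': 'task-level var'}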
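
Reviewer note: the step accounting in _init_train computes each task's expected steps as num_epochs * mix_ratio * (steps of the main target task per epoch per device). A worked example with invented numbers:

    # illustrative numbers only; mirrors the max_train_steps formula above
    num_epochs, batch_size, dev_count = 2, 32, 1
    num_examples = 6400                          # examples of the main (first target) task
    mix_ratios = {'mrqa': 1.0, 'mlm4mrqa': 0.5}

    steps_per_epoch = num_examples // batch_size // dev_count   # 200
    for name, mr in mix_ratios.items():
        print(name, int(num_epochs * mr * steps_per_epoch))     # mrqa 400, mlm4mrqa 200
    print('overall', int(num_epochs * sum(mix_ratios.values()) * steps_per_epoch))  # 600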