# -*- coding: UTF-8 -*- # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import os import sys import importlib import multiprocessing from paddle import fluid from paddle.fluid import layers import yaml import json import logging import time import numpy as np from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint from paddlepalm.utils.config_helper import PDConfig from paddlepalm.utils.print_helper import print_dict from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs from paddlepalm.distribute import data_feeder from default_settings import * from task_instance import TaskInstance, check_instances DEBUG=False VERBOSE=0 def _get_basename(f): return os.path.splitext(f)[0] def _get_suffix(f): return os.path.splitext(f)[-1] def _parse_yaml(f, asdict=True, support_cmd_line=False): assert os.path.exists(f), "file {} not found.".format(f) if support_cmd_line: args = PDConfig(yaml_file=f, fuse_args=True) args.build() return args.asdict() if asdict else args else: if asdict: with open(f, "r") as fin: yaml_config = yaml.load(fin, Loader=yaml.SafeLoader) return yaml_config else: raise NotImplementedError() def _parse_json(f, asdict=True, support_cmd_line=False): assert os.path.exists(f), "file {} not found.".format(f) if support_cmd_line: args = PDConfig(json_file=f, fuse_args=support_cmd_line) args.build() return args.asdict() if asdict else args else: if asdict: with open(f, "r") as fin: config = json.load(fin) return config else: raise NotImplementedError() def _parse_list(string, astype=str): assert isinstance(string, str), "{} is not a string.".format(string) if ',' not in string: return [astype(string)] string = string.replace(',', ' ') return [astype(i) for i in string.split()] def _try_float(s): try: float(s) return(float(s)) except: return s def _check_conf(conf, checklist=None): assert isinstance(conf, dict), "{} is not a dict.".format(conf) ret = {} for k,v in conf.items(): if isinstance(v, str): v = _try_float(v) ret[k] = v if checklist is not None: for k, t in checklist: assert k in ret, "required argument {} is NOT exist in config file.".format(k) assert isintance(ret[k], t), "value type of argument {} should be {}".format(k, t) return ret # TODO: 增加None机制,允许hidden size、batch size和seqlen设置为None def _check_io(in_attr, out_attr, strict=False, in_name="left", out_name="right"): for name, attr in in_attr.items(): assert name in out_attr, in_name+': '+name+' not found in '+out_name if attr != out_attr[name]: if strict: raise ValueError(name+': shape or dtype not consistent!') else: logging.warning('{}: shape or dtype not consistent!\n{}:\n{}\n{}:\n{}'.format(name, in_name, attr, out_name, out_attr[name])) def _merge_conf(conf1, conf2, conf1_first=True, strict=False): assert isinstance(conf1, dict), "{} is not a dict.".format(conf1) assert isinstance(conf2, dict), "{} is not a dict.".format(conf2) base_conf = conf2 if conf1_first else conf1 base_conf = base_conf.copy() new_conf = conf1 if conf1_first else conf2 for k, v in new_conf.items(): if k in base_conf: if base_conf[k] != v: raise Warning("value of argument {} has been updated to {}.".format(k, v)) else: if strict: continue base_conf[k] = v return base_conf def _encode_inputs(inputs, scope_name, sep='/', cand_set=None): outputs = {} for k, v in inputs.items(): if cand_set is not None: if k in cand_set: outputs[k] = v if scope_name+sep+k in cand_set: outputs[scope_name+sep+k] = v else: outputs[scope_name+sep+k] = v return outputs def _decode_inputs(inputs, scope_name, sep='/', keep_unk_keys=True): outputs = {} for name, value in inputs.items(): # var for backbone are also available to tasks if keep_unk_keys and sep not in name: outputs[name] = value # var for this inst if name.startswith(scope_name+'/'): outputs[name[len(scope_name+'/'):]] = value return outputs def _init_env(use_gpu): if use_gpu: place = fluid.CUDAPlace(0) dev_count = fluid.core.get_cuda_device_count() else: place = fluid.CPUPlace() dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) return fluid.Executor(place), dev_count def _fit_attr(conf, fit_attr, strict=False): for i, attr in fit_attr.items(): if i not in conf: if strict: raise Exception('Argument {} is required to create a controller.'.format(i)) else: continue conf[i] = attr(conf[i]) return conf def create_feed_batch_process_fn(net_inputs): def feed_batch_process_fn(data): temp = {} for q, var in net_inputs.items(): if isinstance(var, str) or isinstance(var, unicode): temp[var] = data[q] else: temp[var.name] = data[q] return temp return feed_batch_process_fn class Controller(object): def __init__(self, config, task_dir='.', for_train=True): """ Args: config: (str|dict) 字符串类型时,给出yaml格式的config配置文件路径; """ self._for_train = for_train assert isinstance(config, str) or isinstance(config, dict), "a config dict or config file path is required to create a Controller." if isinstance(config, str): mtl_conf = _parse_yaml(config, support_cmd_line=True) else: mtl_conf = config mtl_conf = _check_conf(mtl_conf) mtl_conf = _fit_attr(mtl_conf, REQUIRED_ARGS, strict=True) mtl_conf = _fit_attr(mtl_conf, OPTIONAL_ARGS, strict=False) exe, dev_count = _init_env(use_gpu=mtl_conf.get('use_gpu', True)) self.exe = exe self.dev_count = dev_count print_dict(mtl_conf, title='global configuration') # parse task instances and target tags instnames = _parse_list(mtl_conf['task_instance']) assert len(instnames) == len(set(instnames)), "repeated task_instance is NOT supported." num_instances = len(instnames) self.num_instances = num_instances instname_to_conf = {} instname_to_id = {} for id, instname in enumerate(instnames): instpath = os.path.join(task_dir, instname+'.yaml') conf = _parse_yaml(instpath, support_cmd_line=False) # conf = _check_conf(conf, TASK_INSTANCE_REQUIRED_ARGS) conf = _check_conf(conf) temp_conf = _merge_conf(mtl_conf, conf, strict=True) print_dict(temp_conf, title='{} configuration'.format(instname)) conf = _merge_conf(mtl_conf, conf) instname_to_conf[instname] = conf instname_to_id[instname] = id # prepare backbone if 'backbone_config_path' in mtl_conf: bb_conf = _parse_json(mtl_conf['backbone_config_path']) bb_conf = _merge_conf(mtl_conf, bb_conf) else: bb_conf = mtl_conf print_dict(bb_conf, title = 'backbone configuration'.format(instname)) bb_name = mtl_conf['backbone'] bb_mod = importlib.import_module(BACKBONE_DIR + '.' + bb_name) Backbone = getattr(bb_mod, 'Model') # create task instances instances = [] for name in instnames: instances.append(TaskInstance(name, instname_to_id[name], instname_to_conf[name])) check_instances(instances) # parse target_tag if 'target_tag' in mtl_conf: target_tag = str(mtl_conf['target_tag']) tags = _parse_list(target_tag, astype=int) assert len(tags) == len(instnames), "number of target_tag is NOT consistent with that in task_instance." for tag, inst in zip(tags, instances): inst.is_target = tag else: tags = [i.is_target for i in instances] num_targets = sum(tags) num_auxes = num_instances - num_targets # parse mix ratios if 'mix_ratio' in mtl_conf: mix_ratio = str(mtl_conf['mix_ratio']) mrs = _parse_list(mix_ratio, astype=float) assert len(mrs) == num_instances, "number of mix_ratios is NOT consistent with num_instances." else: mrs = [1.0] * num_instances for mr, inst in zip(mrs, instances): inst.mix_ratio = mr # parse task layer reuse tags instname_to_reusehost = {i:i for i in instnames} if 'task_reuse_tag' in mtl_conf: tags = _parse_list(mtl_conf['task_reuse_tag'], astype=int) assert len(tags) == num_targets, 'number of reuse_tags is NOT consistent with number of instances.' else: tags = [] mapper = {} for inst in instances: history = set() history.add(inst.name) cur_inst = inst while True: if cur_inst.task_reuse_scope in history: mapper[inst.name] = len(tags) break elif cur_inst.task_reuse_scope in mapper: mapper[inst.name] = mapper[cur_inst.task_reuse_scope] break else: cur_inst = name_to_instance[cur_inst.task_reuse_scope] history.add(cur_inst.name) tags.append(mapper[inst.name]) for i in range(1, num_instances): for j in range(i): if tags[i] == tags[j]: assert instances[i].Paradigm == \ instances[j].Paradigm, \ "paradigm of reuse tasks should be consistent" instances[i].task_reuse_scope = instances[j].name break self.instances = instances self.mrs = mrs self.Backbone = Backbone self.bb_conf = bb_conf self.bb_name = bb_name self.has_init_train = False self.has_init_pred = False if self._for_train: print("initialing for training...") self._init_train() self.has_init_train = True def _init_train(self): instances = self.instances Backbone = self.Backbone bb_conf = self.bb_conf bb_name = self.bb_name dev_count = self.dev_count num_instances = len(instances) mrs = self.mrs # set first_target/main task instance main_inst = None for inst in instances: if inst.is_target: main_inst = inst inst.is_first_target = True break main_conf = main_inst.config if not os.path.exists(main_conf['save_path']): os.makedirs(main_conf['save_path']) os.makedirs(os.path.join(main_conf['save_path'], 'ckpt')) # prepare backbone train_backbone = Backbone(bb_conf, phase='train') pred_backbone = Backbone(bb_conf, phase='pred') # create reader, task # then check i/o across reader, backbone and task_layer task_attrs = [] pred_task_attrs = [] for inst in instances: train_reader = inst.Reader(inst.config, phase='train') inst.reader['train'] = train_reader train_parad = inst.Paradigm(inst.config, phase='train', backbone_config=bb_conf) inst.task_layer['train'] = train_parad task_attr_from_reader = _encode_inputs(train_parad.inputs_attrs['reader'], inst.name) task_attrs.append(task_attr_from_reader) _check_io(train_backbone.inputs_attr, train_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.train') _check_io(train_parad.inputs_attrs['reader'], train_reader.outputs_attr, in_name='task_paradigm.train.reader', out_name='reader.train') _check_io(train_parad.inputs_attrs['backbone'], train_backbone.outputs_attr, in_name='task_paradigm.train.backbone', out_name=bb_name+'_backbone') if inst.is_target: if 'pred_file' not in inst.config: inst.config['pred_file'] = '' pred_reader = inst.Reader(inst.config, phase='pred') pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=bb_conf) inst.task_layer['pred'] = pred_parad task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], inst.name) pred_task_attrs.append(task_attr_from_reader) _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred') _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred') _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone') # merge reader input attrs from backbone and task_instances joint_input_names, joint_shape_and_dtypes, name_to_position = merge_input_attrs(train_backbone.inputs_attr, task_attrs) pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN] if DEBUG: print('----- for debug -----') print('joint input names:') print(joint_input_names) print('joint input shape and dtypes:') print(joint_shape_and_dtypes) # load data for inst in instances: print(inst.name+": preparing data...", end='') inst.reader['train'].load_data() print('ok!') # merge dataset iterators and create net input vars iterators = [] prefixes = [] mrs = [] for inst in instances: iterators.append(inst.reader['train'].iterator()) prefixes.append(inst.name) mrs.append(inst.mix_ratio) joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict') self._joint_iterator_fn = joint_iterator_fn input_attrs = [[i, j, k] for i, (j,k) in zip(joint_input_names, joint_shape_and_dtypes)] pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)] # net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3) net_inputs = create_net_inputs(input_attrs, async=False) self._net_inputs = net_inputs # build backbone and task layers train_prog = fluid.default_main_program() train_init_prog = fluid.default_startup_program() bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_') assert sorted(bb_output_vars.keys()) == sorted(train_backbone.outputs_attr.keys()) pred_prog = fluid.Program() pred_init_prog = fluid.Program() with fluid.program_guard(main_program = pred_prog, startup_program = pred_init_prog): pred_net_inputs = create_net_inputs(pred_input_attrs) pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_') fluid.framework.switch_main_program(train_prog) fluid.framework.switch_startup_program(train_init_prog) task_output_vars = {} for inst in instances: task_inputs = {'backbone': bb_output_vars} task_inputs_from_reader = _decode_inputs(net_inputs, inst.name) task_inputs['reader'] = task_inputs_from_reader scope = inst.task_reuse_scope + '/' with fluid.unique_name.guard(scope): output_vars = inst.build_task_layer(task_inputs, phase='train', scope=scope) output_vars = {inst.name+'/'+key: val for key, val in output_vars.items()} old = len(task_output_vars) # for debug task_output_vars.update(output_vars) assert len(task_output_vars) - old == len(output_vars) # for debug # prepare predict vars for saving inference model if inst.is_target: with fluid.program_guard(pred_prog, pred_init_prog): cur_inputs = _decode_inputs(pred_net_inputs, inst.name) inst.pred_input = cur_inputs pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs} scope = inst.task_reuse_scope + '/' with fluid.unique_name.guard(scope): inst.build_task_layer(pred_task_inputs, phase='pred', scope=scope) bb_fetches = {k: v.name for k,v in bb_output_vars.items()} task_fetches = {k: v.name for k,v in task_output_vars.items()} fetches = task_fetches fetches['__task_id'] = net_inputs['__task_id'].name # compute loss task_id_var = net_inputs['__task_id'] task_id_vec = layers.one_hot(task_id_var, num_instances) losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0) loss = layers.reduce_sum(task_id_vec * losses) main_reader = main_inst.reader['train'] num_examples = main_reader.num_examples for inst in instances: max_train_steps = int(main_conf['num_epochs']* inst.mix_ratio * (num_examples // main_conf['batch_size'] // dev_count)) if inst.is_target: print('{}: expected train steps {}.'.format(inst.name, max_train_steps)) inst.steps_pur_epoch = inst.reader['train'].num_examples // main_conf['batch_size'] // dev_count inst.expected_train_steps = max_train_steps global_max_train_steps = int(main_conf['num_epochs'] * sum(mrs) * (num_examples // main_conf['batch_size'] // dev_count)) print('Estimated overall train steps {}.'.format(global_max_train_steps)) if 'warmup_proportion' in main_conf and main_conf['warmup_proportion'] > 0: warmup_steps = int(global_max_train_steps * main_conf['warmup_proportion']) print('Warmup steps: '+str(warmup_steps)) else: warmup_steps = 0 # build optimizer if 'optimizer' in main_conf: optim_mod = importlib.import_module(OPTIMIZER_DIR + '.' + main_conf['optimizer']) optimize = getattr(optim_mod, OPTIMIZE_METHOD) optimize(loss, main_conf, max_train_steps, warmup_steps, fluid.default_main_program()) loss.persistable = True if main_conf.get('use_ema', False): assert 'ema_decay' in main_conf, "ema_decay should be set when use_ema is enabled." ema = fluid.optimizer.ExponentialMovingAverage(main_conf['ema_decay']) ema.update() # prepare for train self.train_backbone = train_backbone self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) self.saver_program = fluid.default_main_program() self.main_inst = main_inst self.fetches = fetches self.has_init_train = True self.has_init_pred = True self.exe.run(fluid.default_startup_program()) print("\nRandomly initialize parameters...\n") def _init_pred(self, instance, infer_model_path): inst = instance if 'pred_output_path' not in inst.config: inst.config['pred_output_path'] = os.path.join(inst.config.get('save_path', '.'), inst.name) if not os.path.exists(inst.config['pred_output_path']): os.makedirs(inst.config['pred_output_path']) pred_backbone = self.Backbone(self.bb_conf, phase='pred') pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=self.bb_conf) inst.task_layer['pred'] = pred_parad pred_joint_input_names, pred_joint_shape_and_dtypes, name_to_position = merge_input_attrs( pred_backbone.inputs_attr, inst.task_layer['pred'].inputs_attrs['reader'], insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) pred_prog = inst.load(infer_model_path) pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel() if inst.reader['pred'] is None: pred_reader = inst.Reader(inst.config, phase='pred') inst.reader['pred'] = pred_reader return pred_prog def load_pretrain(self, pretrain_path=None): # load pretrain model (or ckpt) if pretrain_path is None: assert 'pretrain_path' in self.main_conf, "pretrain_path NOT set." pretrain_path = self.main_conf['pretrain_path'] init_pretraining_params( self.exe, pretrain_path, main_program=fluid.default_startup_program()) def train(self): if not self.has_init_train: self._init_train() self.has_init_train = True instances = self.instances num_instances = self.num_instances main_inst = self.main_inst main_conf = main_inst.config backbone = self.train_backbone train_program = self.train_program saver_program = self.saver_program fetches = self.fetches finish = [] for inst in instances: if inst.is_target: if inst.expected_train_steps > 0: finish.append(False) else: finish.append(True) print(inst.name+': train finished!') inst.save() def train_finish(): for inst in instances: if inst.is_target: if not inst.train_finish: return False return True # do training fetch_names, fetch_list = zip(*fetches.items()) main_step = 0 # only count for main task global_step = 0 # count for all tasks epoch = 0 time_begin = time.time() backbone_buffer = [] feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs) distribute_feeder = data_feeder(self._joint_iterator_fn, feed_batch_process_fn) # palm.distribute.reader(self._joint_iterator_fn, self._net_inputs, prefetch_steps=2) while not train_finish(): feed, mask = next(distribute_feeder) rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list) while mask.pop() == False: rt_outputs.pop() rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist() rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id cur_task = instances[rt_task_id] backbone_rt_outputs = {k:v for k,v in rt_outputs.items() if '/' not in k} backbone_buffer.append(backbone.postprocess(backbone_rt_outputs)) task_rt_outputs = {k[len(cur_task.name+'/'):]: v for k,v in rt_outputs.items() if k.startswith(cur_task.name+'/')} instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs) global_step += 1 cur_task.cur_train_step += 1 if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0: cur_task.save(suffix='.step'+str(cur_task.cur_train_step)) if global_step % main_conf.get('print_every_n_steps', 5) == 0: loss = rt_outputs[cur_task.name+'/loss'] loss = np.mean(np.squeeze(loss)).tolist() time_end = time.time() time_cost = time_end - time_begin print("Global step: {}. Task: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format( global_step, cur_task.name, cur_task.cur_train_step, cur_task.steps_pur_epoch, cur_task.cur_train_epoch, loss, main_conf.get('print_every_n_steps', 5) / time_cost)) time_begin = time.time() if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps: print(cur_task.name+': train finished!') cur_task.save() if 'save_every_n_steps' in main_conf and global_step % main_conf['save_every_n_steps'] == 0: save_path = os.path.join(main_conf['save_path'], 'ckpt', "step_" + str(global_step)) fluid.io.save_persistables(self.exe, save_path, saver_program) print('checkpoint has been saved at '+save_path) save_path = os.path.join(main_conf['save_path'], 'ckpt', "step_" + str(global_step)) fluid.io.save_persistables(self.exe, save_path, saver_program) print('checkpoint has been saved at '+save_path) print("ALL tasks train finished, exiting...") def pred(self, task_instance, inference_model_dir=None): if self._for_train: raise Exception('This controller is a trainer. Please build a new controller with for_train=False for predicting.') assert isinstance(task_instance, str) if isinstance(inference_model_dir, str): assert os.path.exists(inference_model_dir), inference_model_dir+" not found." # if not self.has_init_pred and inference_model_dir is None: # raise ValueError('infer_model_path is required for prediction.') if inference_model_dir is None: assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model." inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model') instance = None for inst in self.instances: if inst.name == task_instance: instance = inst break if instance is None: raise ValueError(task_instance + ' is not a valid task_instance.') pred_prog = self._init_pred(instance, inference_model_dir) inst = instance print(inst.name+": loading data...") inst.reader['pred'].load_data() fetch_names, fetch_vars = inst.pred_fetch_list print('predicting...') feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input) distribute_feeder = data_feeder(inst.reader['pred'].iterator, feed_batch_process_fn, prefetch_steps=1) buf = [] for feed, mask in distribute_feeder: print('before run') rt_outputs = self.exe.run(pred_prog, feed, fetch_vars) print('after run') splited_rt_outputs = [] for item in rt_outputs: splited_rt_outputs.append(np.split(item, len(mask))) # assert len(rt_outputs) == len(mask), [len(rt_outputs), len(mask)] print(mask) while mask.pop() == False: print(mask) for item in splited_rt_outputs: item.pop() rt_outputs = [] print('cancat') for item in splited_rt_outputs: rt_outputs.append(np.concatenate(item)) rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} inst.postprocess(rt_outputs, phase='pred') print('leave feeder') if inst.task_layer['pred'].epoch_inputs_attrs: reader_outputs = inst.reader['pred'].get_epoch_outputs() else: reader_outputs = None print('epoch postprocess') inst.epoch_postprocess({'reader':reader_outputs}, phase='pred') if __name__ == '__main__': assert len(sys.argv) == 2, "Usage: python mtl_controller.py " conf_path = sys.argv[1] del sys.argv[1] controller = Controller(conf_path) if controller.main_conf['do_train']: controller.train() __all__ = ["Controller"]