From f29142d1f7aa5303a6e9667360b9314b4a9d1f0b Mon Sep 17 00:00:00 2001 From: xixiaoyao Date: Thu, 16 Jan 2020 15:24:37 +0800 Subject: [PATCH] add multihead trainer --- "demo/demo2/\\" | 276 +++++++++ demo/demo2/log.txt | 914 ++++++++++++++++++++++++++++++ demo/demo2/run.py | 54 +- paddlepalm/multihead_trainer.py | 85 ++- paddlepalm/trainer.py | 33 +- paddlepalm/utils/reader_helper.py | 31 +- 6 files changed, 1299 insertions(+), 94 deletions(-) create mode 100644 "demo/demo2/\\" create mode 100644 demo/demo2/log.txt diff --git "a/demo/demo2/\\" "b/demo/demo2/\\" new file mode 100644 index 0000000..0217a5c --- /dev/null +++ "b/demo/demo2/\\" @@ -0,0 +1,276 @@ + +from paddle import fluid +from paddle.fluid import layers +from paddlepalm.distribute import gpu_dev_count, cpu_dev_count +from paddlepalm import Trainer +from paddlepalm.utils import reader_helper +import time + +dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count +VERBOSE=False + + +class MultiHeadTrainer(Trainer): + + def __init__(self, trainers, reuse_flags=None): + if reuse_flags is not None: + assert len(reuse_flags) == len(trainers) + + self._trainers = trainers + + self._train_init = False + self._predict_init = False + self._feeded_var_names = None + self._cur_train_step = 0 + self._target_vars = None + + self._inputname_to_varname = {} + self._pred_input_name_list = [] + self._pred_input_varname_list = [] + self._pred_fetch_name_list = [] + self._pred_fetch_var_list = [] + + self._exe = None + + self._save_protocol = { + 'input_names': 'self._pred_input_name_list', + 'input_varnames': 'self._pred_input_varname_list', + 'fetch_list': 'self._pred_fetch_name_list'} + + self._check_save = lambda: False + for t in self._trainers: + t._set_multitask() + + def build_forward(self, backbone, heads): + + if isinstance(heads, list): + head_dict = {k.name: v for k,v in zip(self._trainers, heads)} + elif isinstance(heads, dict): + head_dict = heads + else: + raise ValueError() + + num_heads = len(self._trainers) + assert len(head_dict) == num_heads + + for t in self._trainers: + assert t.name in head_dict, "expected: {}, exists: {}".format(t.name, head_dict.keys()) + + train_prog = fluid.Program() + train_init_prog = fluid.Program() + self._train_prog = train_prog + self._train_init_prog = train_init_prog + + def get_loss(i): + head = head_dict[self._trainers[i].name] + # loss_var = self._trainers[i].build_forward(backbone, head, train_prog, train_init_prog) + loss_var = self._trainers[i].build_forward(backbone, head) + return loss_var + + # task_fns = {} + # for i in range(num_heads): + + # def task_loss(): + # task_id = i + # return lambda: get_loss(task_id) + + # task_fns[i] = task_loss() + + + # task_fns = {i: lambda: get_loss(i) for i in range(num_heads)} + task_fns = {i: lambda i=i: get_loss(i) for i in range(num_heads)} + + with fluid.program_guard(train_prog, train_init_prog): + task_id_var = fluid.data(name="__task_id",shape=[1],dtype='int64') + task_id_var += 0 + # task_id_var = fluid.layers.fill_constant(shape=[1],dtype='int64', value=1) + # print(task_id_var.name) + + loss_var = layers.switch_case( + branch_index=task_id_var, + branch_fns=task_fns + ) + self._task_id_var = task_id_var + self._loss_var = loss_var + self._fetch_list = [loss_var.name] + for b in train_prog.blocks: + for var in b.vars: + pass + # if 'task_id' in var: + # print(var) + # exit() + # print(var) + return loss_var + + def fit_readers(self, reader_dict): + raise NotImplementedError() + + def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs, phase='train'): + + if isinstance(readers, list): + reader_dict = {k.name: v for k,v in zip(self._trainers, readers)} + elif isinstance(readers, dict): + reader_dict = readers + else: + raise ValueError() + + num_heads = len(self._trainers) + assert len(reader_dict) == num_heads + + trainer_dict = {t.name: t for t in self._trainers} + assert sampling_reference in trainer_dict + + trainer_dict[sampling_reference].fit_reader(reader_dict[sampling_reference]) + base_steps_pur_epoch = trainer_dict[sampling_reference]._steps_pur_epoch + + input_names = [] + name_to_pos = [] + joint_shape_and_dtypes = [] + iterators = [] + prefixes = [] + mrs = [] + net_inputs = [] + global_steps = 0 + for t in self._trainers: + assert t.name in reader_dict + assert reader_dict[t.name].num_epochs is None, "{}: num_epochs is not None. \ + To run with multi-head mode, num_epochs of each Trainer should be set as None.".format(t.name) + # print(num_epochs, t.mix_ratio, base_steps_pur_epoch) + max_train_steps = int(num_epochs * t.mix_ratio * base_steps_pur_epoch) + if not t._as_auxilary: + print('{}: expected train steps {}.'.format(t.name, max_train_steps)) + global_steps += max_train_steps + if t.name != sampling_reference: + t.fit_reader(reader_dict[t.name]) + net_inputs.append(t._net_inputs) + prefixes.append(t.name) + mrs.append(t.mix_ratio) + iterators.append(t._raw_iterator_fn()) + input_names.append(t._input_names) + name_to_pos.append(t._name_to_position) + joint_shape_and_dtypes.append(t._shape_and_dtypes) + + print('Estimated overall train steps {}.'.format(global_steps)) + self._overall_train_steps = global_steps + + iterator_fn = reader_helper.create_multihead_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, \ + mrs, input_names, name_to_pos, dev_count=dev_count) + feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs) + + if gpu_dev_count > 1: + distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn) + else: + distribute_feeder_fn = iterator_fn + + if phase == 'train': + self._train_reader = distribute_feeder_fn() + self._feed_batch_process_fn = feed_batch_process_fn + elif phase == 'predict': + self._predict_reader = distribute_feeder_fn() + self._pred_feed_batch_process_fn = feed_batch_process_fn + + def train(self, save_path=None, save_steps=None, save_type='ckpt', print_steps=5): + iterator = self._train_reader + self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name) + + save_type = save_type.split(',') + if 'predict' in save_type: + assert self._pred_head is not None, "Predict head not found! You should build_predict_head first if you want to save predict model." + assert save_path is not None and save_steps is not None, 'save_path and save_steps is required to save model.' + save_predict = True + if not os.path.exists(save_path): + os.makedirs(save_path) + else: + save_predict = False + + if 'ckpt' in save_type: + if save_path is not None and save_steps is not None: + save_ckpt = True + if not os.path.exists(save_path): + os.makedirs(save_path) + else: + "WARNING: save_path or save_steps is not set, model will not be saved during training." + save_ckpt = False + else: + save_ckpt = False + + time_begin = time.time() + for feed in iterator: + # batch, task_id = feed + rt_outputs, task_id = self.train_one_step(feed) + + task_rt_outputs = {k[len(self._trainers[task_id].name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self._trainers[task_id].name+'.')} + self._task_head.batch_postprocess(task_rt_outputs) + + if print_steps > 0 and self._cur_train_step % print_steps == 0: + loss = rt_outputs[self._trainers[task_id].name+'.loss'] + loss = np.mean(np.squeeze(loss)).tolist() + + time_end = time.time() + time_cost = time_end - time_begin + + print("global step: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format( + (self._cur_train_step, self._trainers[task_id]._cur_train_step-1) % self._trainers[task_id]._steps_pur_epoch + 1, self._trainers[task_id]._steps_pur_epoch, self._trainers[task_id]._cur_train_epoch, + loss, print_steps / time_cost)) + time_begin = time.time() + + self._check_save() + + # if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps: + # print(cur_task.name+': train finished!') + # cur_task.save() + + # if (save_predict or save_ckpt) and self._cur_train_step % save_steps == 0: + # if save_predict: + # self.save(save_path, suffix='pred.step'+str(self._cur_train_step)) + # if save_ckpt: + # fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog) + # print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step))) + + if self._num_epochs is None and self._cur_train_step == self._steps_pur_epoch: + break + + + def train_one_step(self, batch): + + if dev_count > 1: + assert isinstance(batch, list) + # for f in batch: + # f['branch'] = np.array([task_id], dtype='int64') + task_id = batch[0]['__task_id'][0] + else: + assert isinstance(batch, dict) + task_id = batch['__task_id'][0] + # batch['branch'] = np.array([task_id], dtype='int64') + + # feed = self._trainers[task_id].get_one_batch() + print(batch) + print(self._distribute_train_prog) + rt_outputs = self._trainers[task_id].train_one_step(batch, self._exe, self._distribute_train_prog, self._fetch_list) + + self._cur_train_steps += 1 + return rt_outputs, task_id + + # if dev_count > 1: + # # feed, mask, task_id = batch + # for f in feed: + # f['branch'] = np.array([task_id], dtype='int64') + # rt_outputs = self.exe.run(self._distribute_train_prog, feed=feed, fetch_list=self._trainers[task_id]._fetch_list) + # num_fakes = decode_fake(len(rt_outputs[0]), mask, self._trainers[task_id]._batch_size) + # for _ in range(num_fakes): + # for item in rt_outputs: + # item.pop() + # else: + # feed, task_id = batch + # feed['branch'] = np.array([task_id], dtype='int64') + # rt_outputs = self._exe.run(self._distribute_train_prog, feed=feed, fetch_list=self._trainers[task_id]._fetch_list) + + def predict_one_batch(self, batch): + raise NotImplementedError() + + def predict(self, output_dir=None, print_steps=1000): + raise NotImplementedError() + + @property + def overall_train_steps(self): + return self._overall_train_steps diff --git a/demo/demo2/log.txt b/demo/demo2/log.txt new file mode 100644 index 0000000..041ba3f --- /dev/null +++ b/demo/demo2/log.txt @@ -0,0 +1,914 @@ +{'token_ids': [[-1, -1], 'int64'], 'label_ids': [[-1], 'int64']} +{'token_ids': [[-1, -1], 'int64']} + +{'token_ids': [[-1, -1], 'int64'], 'label_ids': [[-1], 'int64'], u'input_mask': [[-1, -1, 1], 'float32'], u'position_ids': [[-1, -1], 'int64'], u'task_ids': [[-1, -1], 'int64'], u'segment_ids': [[-1, -1], 'int64']} +{'token_ids': [[-1, -1], 'int64']} +preparing data... +0 +61 +done! +name: "tmp_0" +type { + type: LOD_TENSOR + lod_tensor { + tensor { + data_type: INT64 + dims: 1 + } + lod_level: 0 + } +} + +name: "reduce_sum_0.tmp_0" +type { + type: LOD_TENSOR + lod_tensor { + tensor { + data_type: FP32 + dims: 1 + } + } +} +persistable: false + +name: "reduce_sum_1.tmp_0" +type { + type: LOD_TENSOR + lod_tensor { + tensor { + data_type: FP32 + dims: 1 + } + } +} +persistable: false + +random init params... +Loading pretraining parameters from pretrain/ernie/params... +Warning: cls.cls_out_w not found in pretrain/ernie/params. +Warning: cls.cls_out_b not found in pretrain/ernie/params. +Warning: senti_cls.cls_out_w not found in pretrain/ernie/params. +Warning: senti_cls.cls_out_b not found in pretrain/ernie/params. + +ok! +cls: expected train steps 30. +senti_cls: expected train steps 30. +ok! +Estimated overall train steps 60. +{'__task_id': array([0]), u'token_ids': array([[ 101, 2073, 2515, 5843, 4518, 1998, 8460, 2272, 2013, + 22254, 12848, 3593, 8787, 6177, 2028, 2012, 1037, 2051, + 1012, 100, 26286, 2081, 1996, 1036, 1036, 5843, 1005, + 1005, 2029, 2052, 2031, 2042, 1037, 1036, 1036, 2674, + 5843, 1005, 1005, 1010, 1036, 1036, 5217, 5843, 1005, + 1005, 1010, 1036, 1036, 13493, 5843, 1005, 1005, 4385, + 1012, 102, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0], + [ 101, 100, 12904, 2683, 1010, 100, 2001, 10836, 2011, + 2029, 2111, 1029, 100, 100, 2000, 5676, 1996, 3408, + 1997, 1996, 100, 1997, 100, 2419, 100, 100, 1010, + 2004, 100, 1997, 100, 1010, 2000, 3627, 2004, 4621, + 11590, 2104, 1996, 2516, 1997, 100, 1997, 100, 1012, + 100, 2516, 2001, 4379, 2000, 2010, 3920, 2365, 2021, + 2043, 100, 1005, 1055, 8215, 14153, 2351, 1996, 2516, + 1997, 100, 1997, 100, 1998, 100, 1997, 100, 2150, + 4372, 21077, 1999, 2028, 102, 0, 0, 0, 0, + 0, 0], + [ 101, 100, 100, 2003, 1996, 2905, 1997, 2019, 3883, + 2040, 2003, 4050, 3459, 1999, 2054, 6907, 1997, 5691, + 1029, 2002, 23873, 10874, 2143, 1000, 100, 1000, 1006, + 2325, 1007, 1998, 1996, 100, 7815, 3850, 1000, 100, + 100, 100, 1000, 1010, 2008, 4836, 2006, 100, 100, + 1012, 100, 2003, 1996, 3920, 2905, 1997, 3883, 100, + 100, 1012, 100, 2003, 1996, 7799, 1997, 100, 100, + 100, 100, 2516, 1999, 2432, 1012, 100, 1996, 2168, + 2095, 102, 0, 0, 0, 0, 0, 0, 0, + 0, 0], + [ 101, 100, 2095, 2106, 100, 4553, 2008, 1996, 10138, + 1999, 100, 2001, 10560, 1029, 28845, 1012, 100, 1010, + 2085, 2894, 1999, 100, 1010, 2001, 16839, 9080, 12863, + 2005, 2010, 10759, 1010, 1998, 2626, 2000, 1037, 2767, + 1010, 1000, 100, 8364, 1996, 2617, 1997, 2026, 6712, + 1012, 1000, 100, 1999, 100, 10937, 2002, 4342, 1010, + 2096, 8932, 2013, 100, 2000, 100, 1010, 2008, 1996, + 10138, 2018, 2042, 10560, 1010, 2002, 5228, 2010, 21782, + 1999, 1996, 5530, 1997, 2010, 2797, 3485, 1024, 1000, + 100, 102]]), u'input_mask': array([[[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.]], + + [[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.]], + + [[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.]], + + [[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.]]], dtype=float32), u'position_ids': array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0], + [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 0, 0, 0, + 0, 0, 0], + [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 0, 0, 0, 0, 0, 0, + 0, 0, 0], + [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 80, 81, 82]]), u'task_ids': array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'cls.label_ids': array([0, 0, 0, 3]), u'segment_ids': array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])} + + +hahahahahahhahah +{u'token_ids': array([[ 101, 2073, 2515, 5843, 4518, 1998, 8460, 2272, 2013, + 22254, 12848, 3593, 8787, 6177, 2028, 2012, 1037, 2051, + 1012, 100, 26286, 2081, 1996, 1036, 1036, 5843, 1005, + 1005, 2029, 2052, 2031, 2042, 1037, 1036, 1036, 2674, + 5843, 1005, 1005, 1010, 1036, 1036, 5217, 5843, 1005, + 1005, 1010, 1036, 1036, 13493, 5843, 1005, 1005, 4385, + 1012, 102, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0], + [ 101, 100, 12904, 2683, 1010, 100, 2001, 10836, 2011, + 2029, 2111, 1029, 100, 100, 2000, 5676, 1996, 3408, + 1997, 1996, 100, 1997, 100, 2419, 100, 100, 1010, + 2004, 100, 1997, 100, 1010, 2000, 3627, 2004, 4621, + 11590, 2104, 1996, 2516, 1997, 100, 1997, 100, 1012, + 100, 2516, 2001, 4379, 2000, 2010, 3920, 2365, 2021, + 2043, 100, 1005, 1055, 8215, 14153, 2351, 1996, 2516, + 1997, 100, 1997, 100, 1998, 100, 1997, 100, 2150, + 4372, 21077, 1999, 2028, 102, 0, 0, 0, 0, + 0, 0], + [ 101, 100, 100, 2003, 1996, 2905, 1997, 2019, 3883, + 2040, 2003, 4050, 3459, 1999, 2054, 6907, 1997, 5691, + 1029, 2002, 23873, 10874, 2143, 1000, 100, 1000, 1006, + 2325, 1007, 1998, 1996, 100, 7815, 3850, 1000, 100, + 100, 100, 1000, 1010, 2008, 4836, 2006, 100, 100, + 1012, 100, 2003, 1996, 3920, 2905, 1997, 3883, 100, + 100, 1012, 100, 2003, 1996, 7799, 1997, 100, 100, + 100, 100, 2516, 1999, 2432, 1012, 100, 1996, 2168, + 2095, 102, 0, 0, 0, 0, 0, 0, 0, + 0, 0], + [ 101, 100, 2095, 2106, 100, 4553, 2008, 1996, 10138, + 1999, 100, 2001, 10560, 1029, 28845, 1012, 100, 1010, + 2085, 2894, 1999, 100, 1010, 2001, 16839, 9080, 12863, + 2005, 2010, 10759, 1010, 1998, 2626, 2000, 1037, 2767, + 1010, 1000, 100, 8364, 1996, 2617, 1997, 2026, 6712, + 1012, 1000, 100, 1999, 100, 10937, 2002, 4342, 1010, + 2096, 8932, 2013, 100, 2000, 100, 1010, 2008, 1996, + 10138, 2018, 2042, 10560, 1010, 2002, 5228, 2010, 21782, + 1999, 1996, 5530, 1997, 2010, 2797, 3485, 1024, 1000, + 100, 102]]), u'input_mask': array([[[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.]], + + [[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.]], + + [[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.], + [0.]], + + [[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.]]], dtype=float32), u'position_ids': array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0], + [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 0, 0, 0, + 0, 0, 0], + [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 0, 0, 0, 0, 0, 0, + 0, 0, 0], + [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 80, 81, 82]]), u'task_ids': array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), u'cls.label_ids': array([0, 0, 0, 3]), u'segment_ids': array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])} diff --git a/demo/demo2/run.py b/demo/demo2/run.py index 6be3ea8..826ae33 100644 --- a/demo/demo2/run.py +++ b/demo/demo2/run.py @@ -30,6 +30,7 @@ if __name__ == '__main__': print(predict_cls_reader.outputs_attr) # 不同的backbone会对任务reader有不同的特征要求,例如对于分类任务,基本的输入feature为token_ids和label_ids,但是对于BERT,还要求从输入中额外提取position、segment、input_mask等特征,因此经过register后,reader会自动补充backbone所要求的字段 cls_reader.register_with(ernie) + cls_reader2.register_with(ernie) print(cls_reader.outputs_attr) print(predict_cls_reader.outputs_attr) @@ -57,14 +58,10 @@ if __name__ == '__main__': loss_var = mh_trainer.build_forward(ernie, [cls_head, cls_head2]) - # controller.build_forward() - # Error! a head/backbone can be only build once! Try NOT to call build_forward method for any Trainer! - - # n_steps = cls_reader.num_examples * num_epochs // batch_size - # warmup_steps = int(0.1 * n_steps) - # print(warmup_steps) - # sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) - sched = None + n_steps = cls_reader.num_examples * num_epochs // batch_size + warmup_steps = int(0.1 * n_steps) + print(warmup_steps) + sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) adam = palm.optimizer.Adam(loss_var, lr, sched) @@ -78,44 +75,3 @@ if __name__ == '__main__': mh_trainer.train(print_steps=1) # trainer.save() - # print('prepare to predict...') - # pred_ernie = palm.backbone.ERNIE.from_config(config, phase='pred') - # cls_pred_head = palm.head.Classify(4, 1024, phase='pred') - # trainer.build_predict_forward(pred_ernie, cls_pred_head) - - # predict_cls_reader.load_data(predict_file, 8) - # print(predict_cls_reader.num_examples) - # predict_cls_reader.register_with(pred_ernie) - # trainer.fit_reader(predict_cls_reader, phase='predict') - # print('predicting..') - # trainer.predict(print_steps=20) - - - - - - - - - # controller = palm.Controller([mrqa, match4mrqa, mlm4mrqa]) - - # loss = controller.build_forward(bb, mask_task=[]) - - # n_steps = controller.estimate_train_steps(basetask=mrqa, num_epochs=2, batch_size=8, dev_count=4) - # adam = palm.optimizer.Adam(loss) - # sched = palm.schedualer.LinearWarmup(learning_rate, max_train_steps=n_steps, warmup_steps=0.1*n_steps) - # - # controller.build_backward(optimizer=adam, schedualer=sched, weight_decay=0.001, use_ema=True, ema_decay=0.999) - - # controller.random_init_params() - # controller.load_pretrain('../../pretrain_model/ernie/params') - # controller.train() - - - - - - # controller = palm.Controller(config='config.yaml', task_dir='tasks', for_train=False) - # controller.pred('mrqa', inference_model_dir='output_model/secondrun/mrqa/infer_model') - - diff --git a/paddlepalm/multihead_trainer.py b/paddlepalm/multihead_trainer.py index 8b86e53..6708b71 100644 --- a/paddlepalm/multihead_trainer.py +++ b/paddlepalm/multihead_trainer.py @@ -3,7 +3,9 @@ from paddle import fluid from paddle.fluid import layers from paddlepalm.distribute import gpu_dev_count, cpu_dev_count from paddlepalm import Trainer -from paddlepalm.utils.reader_helper import create_multihead_iterator_fn, create_multihead_feed_batch_process_fn +from paddlepalm.utils import reader_helper +import numpy as np +import time dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count VERBOSE=False @@ -17,6 +19,9 @@ class MultiHeadTrainer(Trainer): self._trainers = trainers + name_maxlen = max([len(i.name) for i in self._trainers]) + self._name_pads = {i.name: name_maxlen-len(i.name) for i in self._trainers} + self._train_init = False self._predict_init = False self._feeded_var_names = None @@ -80,18 +85,30 @@ class MultiHeadTrainer(Trainer): task_fns = {i: lambda i=i: get_loss(i) for i in range(num_heads)} with fluid.program_guard(train_prog, train_init_prog): - head_id_var = fluid.data(name="branch",shape=[1],dtype='int64') + task_id_var = fluid.data(name="__task_id",shape=[1],dtype='int64') + # task_id_var = fluid.layers.fill_constant(shape=[1],dtype='int64', value=1) + # print(task_id_var.name) + loss_var = layers.switch_case( - branch_index=head_id_var, + branch_index=task_id_var, branch_fns=task_fns ) - self._head_id_var = head_id_var + self._task_id_var = task_id_var + self._loss_var = loss_var + self._fetch_list = [loss_var.name] + # for b in train_prog.blocks: + # for var in b.vars: + # pass + # if 'task_id' in var: + # print(var) + # exit() + # print(var) return loss_var def fit_readers(self, reader_dict): raise NotImplementedError() - def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs): + def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs, phase='train'): if isinstance(readers, list): reader_dict = {k.name: v for k,v in zip(self._trainers, readers)} @@ -106,10 +123,13 @@ class MultiHeadTrainer(Trainer): trainer_dict = {t.name: t for t in self._trainers} assert sampling_reference in trainer_dict - trainer_dict[sampling_reference].fit_reader(reader_dict[sampling_reference]) + trainer_dict[sampling_reference].fit_reader(reader_dict[sampling_reference], task_id=self._task_id_var) base_steps_pur_epoch = trainer_dict[sampling_reference]._steps_pur_epoch + self._finish_steps = {} + self._finish = {} input_names = [] + name_to_pos = [] joint_shape_and_dtypes = [] iterators = [] prefixes = [] @@ -124,22 +144,29 @@ class MultiHeadTrainer(Trainer): max_train_steps = int(num_epochs * t.mix_ratio * base_steps_pur_epoch) if not t._as_auxilary: print('{}: expected train steps {}.'.format(t.name, max_train_steps)) + self._finish_steps[t.name] = max_train_steps + self._finish[t.name] = False + else: + self._finish_steps[t.name] = 9999999999 + self._finish[t.name] = True + global_steps += max_train_steps if t.name != sampling_reference: - t.fit_reader(reader_dict[t.name]) + t.fit_reader(reader_dict[t.name], task_id=self._task_id_var) net_inputs.append(t._net_inputs) prefixes.append(t.name) mrs.append(t.mix_ratio) iterators.append(t._raw_iterator_fn()) input_names.append(t._input_names) + name_to_pos.append(t._name_to_position) joint_shape_and_dtypes.append(t._shape_and_dtypes) print('Estimated overall train steps {}.'.format(global_steps)) self._overall_train_steps = global_steps - iterator_fn = create_multihead_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, \ - mrs, input_names, dev_count=dev_count) - feed_batch_process_fn = reader_helper.create_multihead_feed_batch_process_fn(net_inputs) + iterator_fn = reader_helper.create_multihead_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, \ + mrs, input_names, name_to_pos, dev_count=dev_count) + feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs) if gpu_dev_count > 1: distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn) @@ -152,6 +179,15 @@ class MultiHeadTrainer(Trainer): elif phase == 'predict': self._predict_reader = distribute_feeder_fn() self._pred_feed_batch_process_fn = feed_batch_process_fn + + def check_finish(self, task_name, silent=False): + trainers = {t.name:t for t in self._trainers} + if trainers[task_name]._cur_train_step == self._finish_steps[task_name]: + if not silent: + print(task_name+' train finish!') + self._finish[task_name]=True + flags = list(set(self._finish.values())) + return len(flags) == 1 and flags[0] == True def train(self, save_path=None, save_steps=None, save_type='ckpt', print_steps=5): iterator = self._train_reader @@ -180,12 +216,11 @@ class MultiHeadTrainer(Trainer): time_begin = time.time() for feed in iterator: - print(feed) # batch, task_id = feed - rt_outputs, task_id = self.train_one_step(feed, task_id) + rt_outputs, task_id = self.train_one_step(feed) task_rt_outputs = {k[len(self._trainers[task_id].name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self._trainers[task_id].name+'.')} - self._task_head.batch_postprocess(task_rt_outputs) + self._trainers[task_id]._task_head.batch_postprocess(task_rt_outputs) if print_steps > 0 and self._cur_train_step % print_steps == 0: loss = rt_outputs[self._trainers[task_id].name+'.loss'] @@ -194,12 +229,17 @@ class MultiHeadTrainer(Trainer): time_end = time.time() time_cost = time_end - time_begin - print("global step: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format( - (self._cur_train_step, self._trainers[task_id]._cur_train_step-1) % self._trainers[task_id]._steps_pur_epoch + 1, self._trainers[task_id]._steps_pur_epoch, self._trainers[task_id]._cur_train_epoch, + print("global step: {}, {}: step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format( + self._cur_train_step, ' '*self._name_pads[self._trainers[task_id].name]+self._trainers[task_id].name, \ + (self._trainers[task_id]._cur_train_step-1) % self._trainers[task_id]._steps_pur_epoch + 1, \ + self._trainers[task_id]._steps_pur_epoch, self._trainers[task_id]._cur_train_epoch, \ loss, print_steps / time_cost)) time_begin = time.time() self._check_save() + finish = self.check_finish(self._trainers[task_id].name) + if finish: + break # if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps: # print(cur_task.name+': train finished!') @@ -212,26 +252,19 @@ class MultiHeadTrainer(Trainer): # fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog) # print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step))) - if self._num_epochs is None and self._cur_train_step == self._steps_pur_epoch: - break - def train_one_step(self, batch): if dev_count > 1: assert isinstance(batch, list) - # for f in batch: - # f['branch'] = np.array([task_id], dtype='int64') - task_id = batch[0]['__task_id'] + task_id = batch[0]['__task_id'][0] else: assert isinstance(batch, dict) - task_id = batch['__task_id'] - # batch['branch'] = np.array([task_id], dtype='int64') + task_id = batch['__task_id'][0] - # feed = self._trainers[task_id].get_one_batch() - rt_outputs = self._trainers[task_id].train_one_step(batch, self._exe, self._distribute_train_prog) + rt_outputs = self._trainers[task_id].train_one_step(batch, self._exe, self._distribute_train_prog, self._fetch_list) - self._cur_train_steps += 1 + self._cur_train_step += 1 return rt_outputs, task_id # if dev_count > 1: diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py index 2022c5c..29a5917 100644 --- a/paddlepalm/trainer.py +++ b/paddlepalm/trainer.py @@ -41,7 +41,7 @@ class Trainer(object): self._train_init = False self._predict_init = False - nelf._check_save = lambda: False + self._check_save = lambda: False # if save_predict_model: # self._save_predict_model = True @@ -62,6 +62,7 @@ class Trainer(object): self._num_examples = 0 self._multi_task = False + self._as_auxilary = False # training process management self._mix_ratio = mix_ratio @@ -93,6 +94,7 @@ class Trainer(object): def build_predict_forward(self, pred_backbone, pred_head, pred_prog=None, pred_init_prog=None): self._pred_head = pred_head + self._pred_backbone = pred_backbone # self._pred_reader = self._reader.clone(phase='pred') pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name) # pred_task_attr_from_reader = self._pred_head.inputs_attrs['reader'] @@ -145,6 +147,7 @@ class Trainer(object): def build_forward(self, backbone, task_head): # assert not self._multi_task, "you cannot build_forward in trainer when a train is wrapper by MultiHeadTrainer." self._task_head = task_head + self._backbone = backbone # assert self._backbone is not None, "backbone is required for Trainer to build net forward to run with single task mode" self._build_forward = True @@ -239,7 +242,10 @@ class Trainer(object): # for var in block.vars: # print("[debug] : %d, %s" % (_id, var)) self._loss_var = loss_var + print(loss_var) return loss_var + + def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=None): # assert not self._multi_task, "you cannot build_backward in trainer when a train is wrapper by MultiHeadTrainer." # build optimizer assert self._train_init_prog is not None, "train graph not foung! You should build_forward first." @@ -289,7 +295,7 @@ class Trainer(object): def set_as_aux(self): self._as_auxilary = True - def fit_reader(self, reader, phase='train'): + def fit_reader(self, reader, phase='train', task_id=None): # assert not self._multi_task, "you cannot fit_reader in trainer when a train is wrapper by MultiHeadTrainer." # load data @@ -304,9 +310,14 @@ class Trainer(object): self._steps_pur_epoch = reader.num_examples // batch_size shape_and_dtypes = self._shape_and_dtypes name_to_position = self._name_to_position + if task_id is not None: + self._net_inputs['__task_id'] = task_id net_inputs = self._net_inputs self._train_batch_size = batch_size self._num_examples = reader.num_examples + reader_helper.check_io(self._backbone.inputs_attr, reader.outputs_attr, in_name='backbone', out_name='reader(train)') + reader_helper.check_io(self._task_head.inputs_attrs['reader'], reader.outputs_attr, in_name='task_head(reader)', out_name='reader(train)') + reader_helper.check_io(self._task_head.inputs_attrs['backbone'], self._backbone.outputs_attr, in_name='task_head(backbone, train)', out_name='backbone') elif phase == 'predict': tail = self._num_examples % batch_size > 0 self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 if tail else 0 @@ -315,6 +326,9 @@ class Trainer(object): net_inputs = self._pred_net_inputs self._predict_batch_size = batch_size self._pred_num_examples = reader.num_examples + reader_helper.check_io(self._pred_backbone.inputs_attr, reader.outputs_attr, in_name='backbone', out_name='reader(predict)') + reader_helper.check_io(self._pred_head.inputs_attrs['reader'], reader.outputs_attr, in_name='task_head(reader)', out_name='reader(predict)') + reader_helper.check_io(inst._pred_head.inputs_attrs['backbone'], self._pred_backbone.outputs_attr, in_name='task_head(backbone, predict)', out_name='backbone') else: raise NotImplementedError() @@ -450,8 +464,6 @@ class Trainer(object): iterator = self._train_reader self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name) - - # if save_path is not None or save_steps is not None: # assert self._save_predict_model, "If you want to save model, you need set save_predict_model=True when this trainer is built." # if self._save_predict_model: @@ -501,11 +513,8 @@ class Trainer(object): # if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps: # print(cur_task.name+': train finished!') # cur_task.save() - - self._check_save() - - if self._num_epochs is None and self._cur_train_step == self._steps_pur_epoch: + if self._num_epochs is None and not self._multi_task and self._cur_train_step == self._steps_pur_epoch: break # save_path = os.path.join(main_conf['save_path'], 'ckpt', # "step_" + str(global_step)) @@ -560,24 +569,26 @@ class Trainer(object): results = self._pred_head.epoch_postprocess({'reader':reader_outputs}, output_dir=output_dir) return results - def train_one_step(self, batch, executor=None, distribute_train_prog=None): + def train_one_step(self, batch, executor=None, distribute_train_prog=None, fetch_list=None): exe = self._exe if executor is None else executor distribute_train_prog = self._distribute_train_prog if distribute_train_prog is None else distribute_train_prog + fetch_list = self._fetch_list if fetch_list is None else fetch_list if gpu_dev_count > 1: feed, mask = batch - rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=self._fetch_list) + rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=fetch_list) num_fakes = decode_fake(len(rt_outputs[0]), mask, self._batch_size) for _ in range(num_fakes): for item in rt_outputs: item.pop() else: feed = self._feed_batch_process_fn(batch) - rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=self._fetch_list) + rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=fetch_list) rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)} self._cur_train_step += 1 self._cur_train_epoch = (self._cur_train_step-1) // self._steps_pur_epoch + self._check_save() return rt_outputs @property diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py index 0f8f046..9e93dc5 100644 --- a/paddlepalm/utils/reader_helper.py +++ b/paddlepalm/utils/reader_helper.py @@ -56,10 +56,22 @@ def create_feed_batch_process_fn(net_inputs): # return feed_batch_process_fn +def check_io(in_attr, out_attr, strict=False, in_name="left", out_name="right"): + for name, attr in in_attr.items(): + assert name in out_attr, in_name+': '+name+' not found in '+out_name + if attr != out_attr[name]: + if strict: + raise ValueError(name+': shape or dtype not consistent!') + else: + logging.warning('{}: shape or dtype not consistent!\n{}:\n{}\n{}:\n{}'.format(name, in_name, attr, out_name, out_attr[name])) + + def _check_and_adapt_shape_dtype(rt_val, attr, message=""): if not isinstance(rt_val, np.ndarray): + if rt_val is None: + raise Exception(message+": get None value. ") rt_val = np.array(rt_val) - assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)." + assert rt_val.dtype != np.dtype('O'), message+"yielded data is not a valid tensor (number of elements on some dimension may not consistent): {}".format(rt_val) if rt_val.dtype == np.dtype('float64'): rt_val = rt_val.astype('float32') @@ -147,14 +159,12 @@ def create_iterator_fn(iterator, iterator_prefix, shape_and_dtypes, outname_to_p return iterator_fn -def create_multihead_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, names, dev_count=1, keep_one_task=True): +def create_multihead_iterator_fn(iterators, iterator_prefixes, joint_shape_and_dtypes, mrs, names, outname_to_pos, dev_count=1, keep_one_task=True): task_ids = range(len(iterators)) weights = [mr / float(sum(mrs)) for mr in mrs] if not keep_one_task: dev_count = 1 - pos_to_outname = {j:i for i,j in outname_to_pos.items()} - def iterator(): while True: id = np.random.choice(task_ids, p=weights) @@ -171,10 +181,12 @@ def create_multihead_iterator_fn(iterators, iterator_prefixes, joint_shape_and_d task_outname = prefix + '.' + outname if outname in names[id]: + idx = outname_to_pos[id][outname] val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=outname+': ') results[outname] = val if task_outname in names[id]: + idx = outname_to_pos[id][task_outname] val = _check_and_adapt_shape_dtype(val, joint_shape_and_dtypes[id][idx], message=task_outname+': ') results[task_outname] = val @@ -297,7 +309,7 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batc names = [] start = 0 if insert_taskid: - ret.append(([1], 'int64')) + ret.append(([1, 1], 'int64')) names.append('__task_id') start += 1 @@ -318,11 +330,14 @@ def merge_input_attrs(backbone_attr, task_attrs, insert_taskid=True, insert_batc names += sorted(backbone_attr.keys()) ret.extend([backbone_attr[k] for k in names[start:]]) + name_to_position = {} # pos=0 is for task_id, thus we start from 1 + for pos, k in enumerate(names): + name_to_position[k] = pos for task_attr in task_attrs: task_names = sorted(task_attr.keys()) names.extend(task_names) ret.extend([task_attr[k] for k in task_names]) - return names, ret - - + for pos, k in enumerate(task_names, start=len(name_to_position)): + name_to_position[k] = pos + return names, ret, name_to_position -- GitLab