multihead_trainer.py 12.5 KB
Newer Older
W
wangxiao1021 已提交
1 2 3

from paddle import fluid
from paddle.fluid import layers
X
Xiaoyao Xi 已提交
4
from paddlepalm.distribute import gpu_dev_count, cpu_dev_count, data_feeder, decode_fake
W
wangxiao1021 已提交
5 6 7 8
from paddlepalm import Trainer
from paddlepalm.utils import reader_helper
import numpy as np
import time
W
wangxiao1021 已提交
9
import sys
W
wangxiao1021 已提交
10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54

# Device count used by this module: the reported GPU count when more than one
# GPU is available, otherwise 1.
dev_count = gpu_dev_count if gpu_dev_count > 1 else 1
VERBOSE = False


class MultiHeadTrainer(Trainer):
    """
    The core unit to start a multi-task training/predicting session. A MultiHeadTrainer is built based on several Trainers. Beyond the inheritance of Trainer, it additionally achieves model backbone reuse across tasks, trainer sampling for multi-task learning, and multi-head inference for effective evaluation and prediction. 
    """
    
    def __init__(self, trainers):
        """Set up a multi-head trainer on top of existing single-task trainers.

        Args:
            trainers: a list of Trainer objects, one per task. Each of them is
                switched to multi-task mode through its `_set_multitask()`.

        """
        Trainer.__init__(self, '')

        self._trainers = trainers

        # Per-task left padding so that task names are right-aligned in the
        # training log printed by `train`.
        name_lengths = {t.name: len(t.name) for t in self._trainers}
        widest = max(name_lengths.values())
        self._name_pads = {name: widest - length for name, length in name_lengths.items()}

        # Session state.
        self._exe = None
        self._train_init = False
        self._predict_init = False
        self._cur_train_step = 0
        self._feeded_var_names = None
        self._target_vars = None

        # Prediction inputs/outputs bookkeeping.
        self._inputname_to_varname = {}
        self._pred_input_name_list = []
        self._pred_input_varname_list = []
        self._pred_fetch_name_list = []
        self._pred_fetch_var_list = []

        # Maps saved field names to the attribute expressions (as strings)
        # that hold their values.
        self._save_protocol = {
            'input_names': 'self._pred_input_name_list',
            'input_varnames': 'self._pred_input_varname_list',
            'fetch_list': 'self._pred_fetch_name_list'}

        # Default save hook: a no-op that returns False.
        self._check_save = lambda: False

        for trainer in self._trainers:
            trainer._set_multitask()

    # def build_forward(self, backbone, heads):
    def build_forward(self):
        """
        Build the joint forward computation graph for multi-task training.

        Every wrapped trainer must already have a backbone and a task head, and
        all trainers must share the same backbone. The per-task loss builders
        are registered as the branches of a `switch_case` that is selected by
        the `__task_id` input variable.

        Return:
            - loss_var: a Variable object. The output of the `switch_case` over the per-task losses.
        """
        shared_backbone = self._trainers[0]._backbone
        heads_by_name = {}
        for trainer in self._trainers:
            assert trainer._task_head is not None and trainer._backbone is not None, "You should build forward for the {} task".format(trainer._name)
            assert trainer._backbone == shared_backbone, "The backbone for each task must be the same"
            heads_by_name[trainer._name] = trainer._task_head

        main_prog = fluid.Program()
        startup_prog = fluid.Program()
        self._train_prog = main_prog
        self._train_init_prog = startup_prog

        def build_task_loss(index):
            # Builds the forward graph of one task and returns its loss.
            trainer = self._trainers[index]
            task_head = heads_by_name[trainer.name]
            # `_lock_prog` is raised only for the duration of the nested
            # build_forward call.
            trainer._lock_prog = True
            task_loss = trainer.build_forward(shared_backbone, task_head)
            trainer._lock_prog = False
            return task_loss

        # One zero-argument branch per task; the default argument pins the
        # index at definition time.
        branch_fns = {}
        for index in range(len(self._trainers)):
            branch_fns[index] = lambda index=index: build_task_loss(index)

        with fluid.program_guard(main_prog, startup_prog):
            task_id_var = fluid.data(name="__task_id",shape=[1],dtype='int64')
            loss_var = layers.switch_case(branch_index=task_id_var, branch_fns=branch_fns)

        self._task_id_var = task_id_var
        self._loss_var = loss_var
        self._fetch_list = [loss_var.name]

        if not self._multi_task:
            self._init_exe_prog(for_train=True)
        return loss_var

    def fit_readers(self, reader_dict):
        """Not implemented for MultiHeadTrainer.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs, phase='train'):
        """
        Bind readers and loaded train/predict data to trainers. The `num_epochs` argument only
            works on `sampling_reference` task(trainer), and num_epochs of other tasks are inferred from
            their `mix_ratio`.

        This method reads `self._task_id_var`, which in this class is assigned by `build_forward`.

        Args:
            readers: a dict, list or tuple of Reader objects. For dict case, each key is a trainer's name, and the mapped value is the reader object to bind to the trainer. For list/tuple case, the readers are paired with the trainers by position, in the order of the trainers held by this object.
            sampling_reference: a trainer name. The task(trainer) selected as baseline for task sampling.
            num_epochs: training epochs of the sampling_reference task (trainer).
            phase: checked by `self._check_phase`. With 'train' the resulting reader is stored for training, with 'predict' it is stored for prediction.

        Raises:
            ValueError: if `readers` is neither a dict nor a list/tuple.
            AssertionError: if the number of readers differs from the number of trainers,
                if `sampling_reference` or any trainer name has no matching entry, or if a
                reader has `num_epochs` set.
        """
        self._check_phase(phase)

        if isinstance(readers, (list, tuple)):
            reader_dict = {t.name: r for t, r in zip(self._trainers, readers)}
        elif isinstance(readers, dict):
            reader_dict = readers
        else:
            raise ValueError(
                "readers should be a dict or a list of Reader objects, but received {}.".format(type(readers).__name__))

        num_heads = len(self._trainers)
        assert len(reader_dict) == num_heads, "received number of readers is not consistent with trainers."

        trainer_dict = {t.name: t for t in self._trainers}
        assert sampling_reference in trainer_dict, \
            "sampling_reference {} is not the name of any trainer.".format(sampling_reference)

        # The reference task is bound first: its steps-per-epoch is the unit
        # from which the expected step count of every task is derived below.
        ref_trainer = trainer_dict[sampling_reference]
        ref_trainer._set_task_id(self._task_id_var)
        ref_trainer.fit_reader(reader_dict[sampling_reference])
        base_steps_pur_epoch = ref_trainer._steps_pur_epoch

        self._finish_steps = {}
        self._finish = {}
        input_names = []
        name_to_pos = []
        joint_shape_and_dtypes = []
        iterators = []
        prefixes = []
        mrs = []
        net_inputs = []
        global_steps = 0
        for t in self._trainers:
            assert t.name in reader_dict, "{}: no reader is provided for this trainer.".format(t.name)
            assert reader_dict[t.name].num_epochs is None, (
                "{}: num_epochs is not None. "
                "To run with multi-head mode, num_epochs of each Trainer should be set as None.").format(t.name)
            max_train_steps = int(num_epochs * t.mix_ratio * base_steps_pur_epoch)
            if not t._as_auxilary:
                print('{}: expected train steps {}.'.format(t.name, max_train_steps))
                sys.stdout.flush()
                self._finish_steps[t.name] = max_train_steps
                self._finish[t.name] = False
            else:
                # Auxiliary tasks start out flagged as finished, so they never
                # hold back the end of training.
                self._finish_steps[t.name] = 9999999999
                self._finish[t.name] = True

            global_steps += max_train_steps
            # The reference trainer has already been bound above.
            if t.name != sampling_reference:
                t._set_task_id(self._task_id_var)
                t.fit_reader(reader_dict[t.name])
            net_inputs.append(t._net_inputs)
            prefixes.append(t.name)
            mrs.append(t.mix_ratio)
            iterators.append(t._raw_iterator_fn())
            input_names.append(t._input_names)
            name_to_pos.append(t._name_to_position)
            joint_shape_and_dtypes.append(t._shape_and_dtypes)

        print('Estimated overall train steps {}.'.format(global_steps))
        sys.stdout.flush()
        self._overall_train_steps = global_steps

        iterator_fn = reader_helper.create_multihead_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, \
            mrs, input_names, name_to_pos, dev_count=dev_count)
        feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)

        if gpu_dev_count > 1:
            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase, is_multi=True)
        else:
            distribute_feeder_fn = iterator_fn()

        if phase == 'train':
            self._train_reader = distribute_feeder_fn
            self._feed_batch_process_fn = feed_batch_process_fn
        elif phase == 'predict':
            self._predict_reader = distribute_feeder_fn
            self._pred_feed_batch_process_fn = feed_batch_process_fn

    def _check_finish(self, task_name, silent=False):
        """Flag `task_name` as finished once it has run its expected number of steps.

        Args:
            task_name: name of the trainer whose step counter is checked.
            silent: if True, the "train finish!" message is not printed.

        Returns:
            True if every task recorded in `self._finish` is flagged as
            finished, False otherwise.
        """
        trainers = {t.name: t for t in self._trainers}
        # `>=` rather than `==`: a task whose expected step count is 0, or
        # whose counter has already passed the target, would never compare
        # equal and training would not terminate.
        reached = trainers[task_name]._cur_train_step >= self._finish_steps[task_name]
        # Only act on the first transition so the message is printed once.
        if reached and not self._finish.get(task_name, False):
            if not silent:
                print(task_name+' train finish!')
                sys.stdout.flush()
            self._finish[task_name] = True
        return bool(self._finish) and all(self._finish.values())
        
    def train(self, print_steps=5):
        """
        Run the multi-task training loop.

        Batches are drawn from `self._train_reader`; each one is dispatched by
        `train_one_step` to the trainer it belongs to. The loop stops as soon as
        `_check_finish` reports that all tasks are finished, or when the reader
        is exhausted.

        Args:
            print_steps: int. Logging frequency of training message, e.g., current step, loss and speed.
                A value <= 0 disables the message.
        """
        batches = self._train_reader
        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)
        # All sub-trainers run the same executor, compiled program and fetch list.
        for trainer in self._trainers:
            trainer._set_exe(self._exe)
            trainer._set_dist_train(self._distribute_train_prog)
            trainer._set_fetch_list(self._fetch_list)

        tick = time.time()
        for feed in batches:
            rt_outputs, task_id = self.train_one_step(feed)

            cur_trainer = self._trainers[task_id]
            task_name = cur_trainer.name
            prefix = task_name + '.'

            # Keep only the outputs of the current task, with the task prefix removed.
            task_rt_outputs = {}
            for key, value in rt_outputs.items():
                if key.startswith(prefix):
                    task_rt_outputs[key[len(prefix):]] = value
            cur_trainer._task_head.batch_postprocess(task_rt_outputs)

            if print_steps > 0 and self._cur_train_step % print_steps == 0:
                loss = np.mean(np.squeeze(rt_outputs[prefix + 'loss'])).tolist()
                elapsed = time.time() - tick

                print("global step: {}, {}: step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format(
                    self._cur_train_step,
                    ' ' * self._name_pads[task_name] + task_name,
                    (cur_trainer._cur_train_step - 1) % cur_trainer._steps_pur_epoch + 1,
                    cur_trainer._steps_pur_epoch,
                    cur_trainer._cur_train_epoch,
                    loss,
                    print_steps / elapsed))
                sys.stdout.flush()
                tick = time.time()

            self._check_save()
            if self._check_finish(task_name):
                break


    def train_one_step(self, batch):
        """Run one training step on the trainer selected by the batch's `__task_id`.

        Args:
            batch: the feed of one step. When `dev_count` > 1 it must be a tuple
                and the task id is read from `batch[0][0]`; otherwise it must be
                a dict and the task id is read from the dict itself.

        Returns:
            A `(rt_outputs, task_id)` pair: the outputs returned by the selected
            trainer's `train_one_step` and the task id read from the batch.
        """
        multi_device = dev_count > 1
        assert isinstance(batch, tuple if multi_device else dict)
        feed = batch[0][0] if multi_device else batch
        task_id = feed['__task_id'][0]

        rt_outputs = self._trainers[task_id].train_one_step(batch)

        self._cur_train_step += 1
        self._check_save()
        return rt_outputs, task_id

    def predict_one_batch(self, batch):
        """Not implemented for MultiHeadTrainer.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def predict(self, output_dir=None, print_steps=1000):
        """Not implemented for MultiHeadTrainer.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    @property
    def overall_train_steps(self):
        """Estimated total number of train steps over all tasks.

        The value is assigned by `fit_readers_with_mixratio`; accessing it
        before that method has run raises AttributeError.
        """
        total_steps = self._overall_train_steps
        return total_steps