# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import json
import os

from paddle import fluid

import paddlepalm.utils.basic_helper as helper
from paddlepalm.utils import reader_helper
# from paddlepalm.default_settings import *

# Module-level debug switch. It is checked in Trainer.build_forward, which
# prints extra information about the merged reader inputs when it is True.
DEBUG = False


class Trainer(object):
    """Builds and manages the training (and optional inference) graph of one task.

    A Trainer ties together a data reader, a backbone network and a task head.
    ``build_forward`` creates the fluid programs, ``build_backward`` attaches
    the optimizer, ``load_data`` prepares the batch iterator, and ``save`` /
    ``_load`` persist and restore the inference model with a small JSON
    side-car file named ``__conf__``.
    """

    def __init__(self, name, reader, task_head,
                 save_predict_model=False, pred_head=None, save_path=None, save_steps=-1,
                 mix_ratio=1.0, reuse_head_with=None,
                 silent=False):
        """Create a trainer.

        Args:
            name: task name; also used as the prefix/scope for task variables.
            reader: training reader. When ``save_predict_model`` is set it is
                cloned with ``phase='pred'`` to obtain the prediction reader.
            task_head: head used to build the training outputs.
            save_predict_model: whether an inference model will be saved.
                Requires ``save_path`` and ``pred_head``.
            pred_head: head used to build the prediction outputs.
            save_path: directory for the inference model; created if missing.
            save_steps: -1, or a positive number when ``save_predict_model``
                is set.
            mix_ratio: stored and exposed through the ``mix_ratio`` property.
            reuse_head_with: name of the scope whose head is reused; defaults
                to ``name``.
            silent: suppress informational prints when True.

        Raises:
            AssertionError: if the save-related arguments are inconsistent.
        """
        self._name = name
        self._verbose = not silent
        self._reader = reader
        self._pred_reader = None
        self._task_head = task_head
        self._pred_head = pred_head

        # Always define the flag: build_forward reads it unconditionally.
        self._save_predict_model = bool(save_predict_model)
        if self._save_predict_model:
            assert save_path is not None, "save_path is required when save_predict_model is set."
            assert save_steps == -1 or save_steps > 0, "save_steps should be -1 (only save the last step of this task) or larger than 0"
            assert pred_head is not None, "pred_head is required to save predict model."
            self._pred_reader = reader.clone(phase='pred')
            if save_path is not None and not os.path.exists(save_path):
                os.makedirs(save_path)
        else:
            assert save_path is None, "You should set save_predict_model as True, or the save_path is invalid."
            assert save_steps == -1 or save_steps == 0, "You should set save_predict_model as True, or the save_steps is invalid."
            assert pred_head is None, "You should set save_predict_model as True, or the pred_head is invalid."

        self._save_steps = save_steps
        # NOTE(review): save()/_load() read _save_infermodel_path, which was
        # never assigned before; save_path is assumed to be the intended
        # location -- confirm.
        self._save_infermodel_path = save_path
        # NOTE(review): initialised so the property getter does not fail before
        # the setter is used; presumably mirrors save_steps -- confirm.
        self._save_infermodel_every_n_steps = save_steps

        self._task_reuse_scope = name if reuse_head_with is None else reuse_head_with

        self._feeded_var_names = None
        self._target_vars = None

        self._num_examples = 0

        # training process management
        self._mix_ratio = mix_ratio
        self._expected_train_steps = None
        self._expected_train_epochs = None
        self._steps_pur_epoch = None
        self._cur_train_epoch = 0
        self._cur_train_step = 0
        self._train_finish = False

        self._inputname_to_varname = {}
        # Heads per run phase (train, eval, pred); the key is the phase.
        # _postprocess and _epoch_postprocess look the head up here.
        self._task_layer = {'train': task_head, 'eval': None, 'pred': pred_head}

        self._pred_input_name_list = []
        self._pred_input_varname_list = []
        self._pred_fetch_name_list = []
        self._pred_fetch_var_list = []

        self._exe = fluid.Executor(fluid.CPUPlace())

        # Maps each key of the ``__conf__`` file to the instance attribute
        # (written as a ``self.<attr>`` expression) that holds its value.
        self._save_protocol = {
            'input_names': 'self._pred_input_name_list',
            'input_varnames': 'self._pred_input_varname_list',
            'fetch_list': 'self._pred_fetch_name_list'}

        self._lock = False
        self._build_forward = False

    def build_forward(self, backbone, pred_backbone=None):
        """Build the forward training graph and, if requested, the predict graph.

        Args:
            backbone: backbone used for training; must provide ``inputs_attr``,
                ``outputs_attr`` and ``build(net_inputs)``.
            pred_backbone: backbone used for the predict program; required when
                the trainer was created with ``save_predict_model=True``.

        Returns:
            The loss variable: ``reduce_sum`` over the task head output stored
            under ``'<name>.loss'``.
        """
        self._build_forward = True
        if self._save_predict_model:
            assert pred_backbone is not None, \
                "pred_backbone is required when save_predict_model is set."

        # Reader attributes requested by the head(s), encoded with the task name.
        task_attr_from_reader = helper.encode_inputs(self._task_head.inputs_attrs['reader'], self.name)
        pred_task_attr_from_reader = None
        if self._save_predict_model:
            pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name)

        # Merge the reader input attrs required by the backbone and the head.
        input_names, shape_and_dtypes, name_to_position = \
            reader_helper.merge_input_attrs(backbone.inputs_attr, task_attr_from_reader)
        pred_input_attrs = None
        if self._save_predict_model:
            # NOTE(review): this merges with backbone.inputs_attr rather than
            # pred_backbone.inputs_attr, as the original code did -- confirm.
            pred_input_names, pred_shape_and_dtypes, _ = reader_helper.merge_input_attrs(
                backbone.inputs_attr, pred_task_attr_from_reader,
                insert_taskid=False, insert_batchsize=False,
                insert_seqlen=False, insert_batchsize_x_seqlen=False)
            pred_input_attrs = [[i, j, k] for i, (j, k) in zip(pred_input_names, pred_shape_and_dtypes)]

        # Kept on the instance because load_data needs them to adapt batches.
        self._shape_and_dtypes = shape_and_dtypes
        self._name_to_position = name_to_position

        if DEBUG:
            print('----- for debug -----')
            print('joint input names:')
            print(input_names)
            print('joint input shape and dtypes:')
            print(shape_and_dtypes)

        input_attrs = [[i, j, k] for i, (j, k) in zip(input_names, shape_and_dtypes)]

        train_prog = fluid.Program()
        train_init_prog = fluid.Program()
        self._prog = train_prog
        self._train_prog = train_prog
        self._train_init_prog = train_init_prog
        with fluid.program_guard(train_prog, train_init_prog):
            # ``async`` is a reserved word since Python 3.7, so the keyword
            # argument has to be passed through a dict to keep this file
            # importable on both old and new interpreters.
            net_inputs = reader_helper.create_net_inputs(input_attrs, **{'async': False})

            # build backbone
            bb_output_vars = backbone.build(net_inputs)
            assert sorted(bb_output_vars.keys()) == sorted(backbone.outputs_attr.keys())

        pred_prog = None
        pred_init_prog = None
        pred_net_inputs = None
        pred_bb_output_vars = None
        if self._save_predict_model:
            pred_prog = fluid.Program()
            pred_init_prog = fluid.Program()
            with fluid.program_guard(pred_prog, pred_init_prog):
                pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs)
                pred_bb_output_vars = pred_backbone.build(pred_net_inputs)

        task_inputs = {
            'backbone': bb_output_vars,
            'reader': helper.decode_inputs(net_inputs, self.name),
        }

        scope = self.name + '.'
        with fluid.program_guard(train_prog, train_init_prog):
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(task_inputs, phase='train', scope=scope)
        # Head outputs are exposed under '<task name>.<output name>'.
        task_output_vars = {self.name + '.' + key: val for key, val in output_vars.items()}

        # prepare predict vars for saving inference model
        if self._save_predict_model:
            with fluid.program_guard(pred_prog, pred_init_prog):
                cur_inputs = helper.decode_inputs(pred_net_inputs, self.name)
                self._pred_input_name_list, self._pred_input_varname_list = \
                    zip(*[[k, v.name] for k, v in cur_inputs.items()])

                pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
                scope = self.name + '.'
                with fluid.unique_name.guard(scope):
                    self._build_head(pred_task_inputs, phase='pred', scope=scope)

        with fluid.program_guard(train_prog, train_init_prog):
            loss_var = fluid.layers.reduce_sum(task_output_vars[self.name + '.loss'])
        return loss_var

    def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=0.9999):
        """Attach the optimizer (and optional weight decay / EMA) to the train program.

        Must be called after ``build_forward``, which creates the train programs.

        Args:
            optimizer: object providing ``_set_prog``, ``build`` (returning
                ``(param, grad)`` pairs) and ``get_cur_learning_rate``.
            weight_decay: if not None, coefficient of the extra update
                ``param -= param * weight_decay * lr`` applied to every
                parameter that is not excluded by name.
            use_ema: whether to add an exponential moving average update.
            ema_decay: decay passed to ``ExponentialMovingAverage``.
        """
        optimizer._set_prog(self._train_prog)
        with fluid.program_guard(self._train_prog, self._train_init_prog):
            param_grads = optimizer.build()

            if weight_decay is not None:
                # Gradient-free copies of the parameters, keyed by name, used
                # as the source of the decay term.
                param_list = {}
                for param in self._prog.global_block().all_parameters():
                    param_list[param.name] = param * 1.0
                    param_list[param.name].stop_gradient = True

                def exclude_from_weight_decay(name):
                    # Layer-norm parameters and biases are not decayed.
                    if "layer_norm" in name:
                        return True
                    return name.endswith(("_bias", "_b", ".b_0"))

                for param, grad in param_grads:
                    if exclude_from_weight_decay(param.name):
                        continue
                    with param.block.program._optimized_guard(
                            [param, grad]), fluid.framework.name_scope("weight_decay"):
                        updated_param = param - param_list[
                            param.name] * weight_decay * optimizer.get_cur_learning_rate()
                        fluid.layers.assign(output=param, input=updated_param)

            if use_ema:
                ema = fluid.optimizer.ExponentialMovingAverage(ema_decay)
                ema.update()

    def load_data(self, input_file, file_format, batch_size, num_epochs=None, shuffle_train=True):
        """Load the training data and return a batch iterator factory.

        Relies on ``self._shape_and_dtypes`` and ``self._name_to_position``,
        which are set by ``build_forward``.

        Returns:
            The object returned by ``reader_helper.create_iterator_fn``.
        """
        print("preparing data...", end='')
        self._reader._load_data(input_file=input_file, batch_size=batch_size,
                                num_epochs=num_epochs, file_format=file_format,
                                shuffle_train=shuffle_train)
        self._num_examples = self._reader.num_examples
        print('ok!')

        iterator = self._reader._iterator()
        prefix = self.name

        # Run runtime checks and adaptation on the data yielded by the reader.
        iterator_fn = reader_helper.create_iterator_fn(
            iterator, prefix, self._shape_and_dtypes, self._name_to_position)
        return iterator_fn

    def random_init_params(self):
        """Call ``helper.build_executor()``; its return value is discarded.

        NOTE(review): presumably a placeholder for parameter initialisation --
        the helper's result is never used here; confirm the intended behavior.
        """
        helper.build_executor()

    def _build_head(self, net_inputs, phase, scope=""):
        """Build the head for ``phase`` ('train' or 'pred') and return its outputs.

        For the 'pred' phase the output names and variables are also recorded
        in ``_pred_fetch_name_list`` / ``_pred_fetch_var_list``.

        Raises:
            ValueError: if ``phase`` is neither 'train' nor 'pred'.
        """
        if phase == 'train':
            return self._task_head.build(net_inputs, scope_name=scope)
        if phase == 'pred':
            output_vars = self._pred_head.build(net_inputs, scope_name=scope)
            if output_vars is not None:
                self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items())
            else:
                self._pred_fetch_name_list = []
                self._pred_fetch_var_list = []
            return output_vars
        raise ValueError("{}: unsupported phase '{}'".format(self._name, phase))

    def _postprocess(self, rt_outputs, phase):
        """Delegate per-step postprocessing to the head registered for ``phase``."""
        return self._task_layer[phase].postprocess(rt_outputs)

    def _epoch_postprocess(self, epoch_inputs, phase):
        """Delegate per-epoch postprocessing to the head registered for ``phase``."""
        return self._task_layer[phase].epoch_postprocess(epoch_inputs)

    @staticmethod
    def _protocol_attr_name(expr):
        """Turn a ``_save_protocol`` value such as ``'self._x'`` into ``'_x'``."""
        prefix = 'self.'
        return expr[len(prefix):] if expr.startswith(prefix) else expr

    def save(self, suffix=''):
        """Save the inference model and its ``__conf__`` file.

        The target directory is ``self._save_infermodel_path + suffix``.
        """
        assert self._save_infermodel_path is not None, \
            "save_path is required to save the inference model."
        dirpath = self._save_infermodel_path + suffix
        self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]

        # NOTE(review): this exports a clone of the *default* main program, not
        # the predict program created in build_forward -- confirm it is intended.
        prog = fluid.default_main_program().clone()
        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)

        # Collect the attributes named by the save protocol (plain attribute
        # access instead of exec on generated source).
        conf = {}
        for k, strv in self._save_protocol.items():
            conf[k] = getattr(self, self._protocol_attr_name(strv))
        with open(os.path.join(dirpath, '__conf__'), 'w') as writer:
            writer.write(json.dumps(conf, indent=1))
        print(self._name + ': inference model saved at ' + dirpath)

    def _load(self, infer_model_path=None):
        """Load an inference model and restore the attributes stored in ``__conf__``.

        Args:
            infer_model_path: directory to load from; defaults to
                ``self._save_infermodel_path``.

        Returns:
            The inference program returned by ``fluid.io.load_inference_model``.
        """
        if infer_model_path is None:
            infer_model_path = self._save_infermodel_path
        # Use a context manager so the file handle is closed deterministically.
        with open(os.path.join(infer_model_path, '__conf__')) as reader:
            conf = json.load(reader)
        for k, v in conf.items():
            setattr(self, self._protocol_attr_name(self._save_protocol[k]), v)
        pred_prog, self._pred_input_varname_list, self._pred_fetch_var_list = \
            fluid.io.load_inference_model(infer_model_path, self._exe)
        print(self._name + ': inference model loaded from ' + infer_model_path)
        return pred_prog

    @property
    def name(self):
        """Task name given at construction."""
        return self._name

    @property
    def num_examples(self):
        """Number of examples reported by the reader after ``load_data``."""
        return self._num_examples

    @property
    def mix_ratio(self):
        """Mix ratio of this task; raises ValueError while it is None."""
        if self._mix_ratio is not None:
            return self._mix_ratio
        else:
            raise ValueError("{}: mix_ratio is None".format(self._name))

    @mix_ratio.setter
    def mix_ratio(self, value):
        self._mix_ratio = float(value)
        if self._verbose:
            print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio))

    @property
    def save_infermodel_every_n_steps(self):
        """Stored save interval for the inference model."""
        return self._save_infermodel_every_n_steps

    @save_infermodel_every_n_steps.setter
    def save_infermodel_every_n_steps(self, val):
        self._save_infermodel_every_n_steps = val

    @property
    def expected_train_steps(self):
        """Expected number of training steps (None until set)."""
        return self._expected_train_steps

    @expected_train_steps.setter
    def expected_train_steps(self, value):
        # Also derives the expected epochs, so steps_pur_epoch must be set first.
        self._expected_train_steps = value
        self._expected_train_epochs = value / float(self._steps_pur_epoch)

    @property
    def expected_train_epochs(self):
        """Expected epochs, derived when ``expected_train_steps`` is set."""
        return self._expected_train_epochs

    @property
    def cur_train_epoch(self):
        """Current training epoch counter."""
        return self._cur_train_epoch

    @property
    def cur_train_step(self):
        """Current training step counter."""
        return self._cur_train_step

    @property
    def steps_pur_epoch(self):
        """Number of steps per epoch (None until set)."""
        return self._steps_pur_epoch

    @steps_pur_epoch.setter
    def steps_pur_epoch(self, value):
        self._steps_pur_epoch = value

    @property
    def train_finish(self):
        """Whether training of this task is marked as finished."""
        return self._train_finish

    def tasklayer_reuse_with(self, task):
        """Reuse the head scope of another trainer.

        Raises:
            AssertionError: if ``task`` is not a Trainer.
            Exception: if the trainer has been locked via ``_set_lock``.
        """
        assert isinstance(task, Trainer), "task should be an instance of Trainer."
        if self._lock:
            raise Exception('you can only set tasklayer reuses BEFORE Controller created.')
        self._task_reuse_scope = task.name

    def _set_lock(self):
        """Lock the trainer so the head reuse scope can no longer be changed."""
        self._lock = True