trainer.py 36.5 KB
Newer Older
X
xixiaoyao 已提交
1
# -*- coding: utf-8 -*-
X
xixiaoyao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
xixiaoyao 已提交
16
from __future__ import print_function
X
xixiaoyao 已提交
17 18 19
import os
import json
from paddle import fluid
X
xixiaoyao 已提交
20 21
import time
import numpy as np
X
xixiaoyao 已提交
22
import paddlepalm.utils.basic_helper as helper
X
xixiaoyao 已提交
23
from paddlepalm.utils import reader_helper, saver
W
wangxiao1021 已提交
24
from paddlepalm.distribute import gpu_dev_count, data_feeder, decode_fake
X
xixiaoyao 已提交
25
# from paddlepalm.default_settings import *
X
xixiaoyao 已提交
26

X
xixiaoyao 已提交
27
DEBUG=False
X
xixiaoyao 已提交
28

X
xixiaoyao 已提交
29

X
xixiaoyao 已提交
30
class Trainer(object):
W
wangxiao1021 已提交
31 32 33
    """
    The core unit to start a training/predicting session for single task. A trainer is to build computation graph, manage training and evaluation process, achieve model/checkpoint saving and pretrain_model/checkpoint loading.
    """
X
xixiaoyao 已提交
34

W
wangxiao1021 已提交
35 36 37 38 39 40 41 42 43
    def __init__(self, name, mix_ratio=1.0, reuse_head_with=None):
        """Create a new trainer.

        Args:
            name: string. The name of the trainer(training task).
            mix_ratio: sampling weight of this trainer in multi-task learning mode. Default is 1.0.
            reuse_head_with: reuse parameters of task head with another trainer. Default is None, not reuse with others.

        """
X
xixiaoyao 已提交
44 45

        self._name = name
X
xixiaoyao 已提交
46
        self._pred_reader = None
W
wangxiao1021 已提交
47 48
        self._task_head = None
        self._pred_head = None
W
wangxiao1021 已提交
49
      
W
wangxiao1021 已提交
50 51 52 53 54 55 56
        self._train_reader = None
        self._predict_reader = None
        self._train_iterator = None
        self._predict_iterator = None

        self._train_init = False
        self._predict_init = False
W
wangxiao1021 已提交
57 58
        self._train_init_prog = None
        self._pred_init_prog = None
W
wangxiao1021 已提交
59 60

        self._check_save = lambda: False
X
xixiaoyao 已提交
61

X
xixiaoyao 已提交
62 63 64
        # if save_predict_model:
        #     self._save_predict_model = True
        #     assert pred_head is not None, "pred_head is required to save predict model."
W
wangxiao1021 已提交
65
        #     self._pred_reader = reader.clone(phase='predict')
X
xixiaoyao 已提交
66 67 68 69
        # else:
        #     assert pred_head is None, "You should set save_predict_model as True, or the pred_head is invalid." 
        #     self._save_predict_model = False
        #     self._pred_reader = None
X
xixiaoyao 已提交
70

X
xixiaoyao 已提交
71
        # self._save_steps = save_steps
X
xixiaoyao 已提交
72

X
xixiaoyao 已提交
73
        self._task_reuse_scope = name if reuse_head_with is None else reuse_head_with
X
xixiaoyao 已提交
74 75 76 77

        self._feeded_var_names = None
        self._target_vars = None

X
xixiaoyao 已提交
78 79
        self._num_examples = 0

W
wangxiao1021 已提交
80 81 82 83
        self._multi_task = False
        self._as_auxilary = False
        self._task_id = None

X
xixiaoyao 已提交
84 85 86 87 88
        # training process management
        self._mix_ratio = mix_ratio
        self._expected_train_steps = None
        self._expected_train_epochs = None
        self._steps_pur_epoch = None
W
wangxiao1021 已提交
89
        self._pred_steps_pur_epoch = None
X
xixiaoyao 已提交
90 91 92 93 94 95 96 97 98 99
        self._cur_train_epoch = 0
        self._cur_train_step = 0
        self._train_finish = False

        self._inputname_to_varname = {}
        self._pred_input_name_list = []
        self._pred_input_varname_list = []
        self._pred_fetch_name_list = []
        self._pred_fetch_var_list = []

X
xixiaoyao 已提交
100 101
        # exe is built when random_init_params is called.
        # self._exe = helper.build_executor(gpu_dev_count>0)
X
xixiaoyao 已提交
102
        self._exe = None
X
xixiaoyao 已提交
103 104 105 106 107 108 109

        self._save_protocol = {
            'input_names': 'self._pred_input_name_list',
            'input_varnames': 'self._pred_input_varname_list',
            'fetch_list': 'self._pred_fetch_name_list'}

        self._lock = False
X
xixiaoyao 已提交
110 111
        self._build_forward = False

W
wangxiao1021 已提交
112 113 114
    def build_forward(self, backbone, task_head):
        """
        Build forward computation graph for training, which usually built from input layer to loss node.

        Args:
            backbone: a Backbone object with phase == 'train', which is used to extract multi-level text features, e.g., contextual word embedding and sentence embedding.
            task_head: a Head object with phase == 'train', which is used to build task specific output layers.

        Return:
            loss_var: a Variable object. The computational graph variable(node) of loss.
        """
        self._task_head = task_head
        self._backbone = backbone
        self._build_forward = True

        # reader fields required by the head are encoded with this trainer's name
        task_attr_from_reader = helper.encode_inputs(self._task_head.inputs_attrs['reader'], self.name)

        # merge reader input attrs from backbone and task head
        input_names, shape_and_dtypes, name_to_position = reader_helper.merge_input_attrs(
            backbone.inputs_attr, task_attr_from_reader, insert_taskid=False)
        self._shape_and_dtypes = shape_and_dtypes
        self._name_to_position = name_to_position
        self._input_names = input_names

        if DEBUG:
            print('----- for debug -----')
            print('joint input names:')
            print(input_names)
            print('joint input shape and dtypes:')
            print(shape_and_dtypes)

        input_attrs = [[i, j, k] for i, (j, k) in zip(input_names, shape_and_dtypes)]

        train_prog = fluid.Program()
        train_init_prog = fluid.Program()
        self._train_prog = train_prog
        self._train_init_prog = train_init_prog

        # `async` is a reserved keyword since Python 3.7; passing it through a
        # dict keeps the call identical while leaving this file parseable there.
        net_input_kwargs = {'async': False}

        # In single-task mode the graph is built inside this trainer's own
        # programs. In multi-task mode no program_guard is entered here;
        # presumably the wrapping trainer has already selected the programs.
        if not self._multi_task:
            with fluid.program_guard(train_prog, train_init_prog):
                net_inputs = reader_helper.create_net_inputs(input_attrs, **net_input_kwargs)
                bb_output_vars = backbone.build(net_inputs)
        else:
            net_inputs = reader_helper.create_net_inputs(input_attrs, **net_input_kwargs)
            bb_output_vars = backbone.build(net_inputs)
        self._net_inputs = net_inputs
        assert sorted(bb_output_vars.keys()) == sorted(backbone.outputs_attr.keys())

        task_inputs = {
            'backbone': bb_output_vars,
            'reader': helper.decode_inputs(net_inputs, self.name)}

        scope = self.name + '.'
        if not self._multi_task:
            with fluid.program_guard(train_prog, train_init_prog):
                with fluid.unique_name.guard(scope):
                    output_vars = self._build_head(task_inputs, phase='train', scope=scope)
        else:
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(task_inputs, phase='train', scope=scope)

        # head outputs are exposed under '<trainer name>.<output name>'
        task_output_vars = {self.name + '.' + key: val for key, val in output_vars.items()}

        self._fetches = {k: v.name for k, v in task_output_vars.items()}
        self._fetch_names, self._fetch_list = zip(*self._fetches.items())

        # compute loss
        if not self._multi_task:
            with fluid.program_guard(train_prog, train_init_prog):
                loss_var = fluid.layers.reduce_sum(task_output_vars[self.name + '.loss'])
        else:
            loss_var = fluid.layers.reduce_sum(task_output_vars[self.name + '.loss'])

        self._loss_var = loss_var

        if not self._multi_task:
            self._init_exe_prog(for_train=True)

        return loss_var

W
wangxiao1021 已提交
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
    def build_predict_forward(self, pred_backbone, pred_head):
        """
        Build computation graph for evaluation and prediction.

        Arguments:
            - pred_backbone: a Backbone object with phase == 'predict'. For evaluating model during training, the predict backbone should keep the same with train backbone.
            - pred_head: a Head object with phase == 'predict'. For evaluating model during training, the predict head should keep the same with train head.

        Return:
            - output_vars: dict type. Each value is a computational graph variable(node) argumented by pred_head outputs_attr.
        """
        self._pred_head = pred_head
        self._pred_backbone = pred_backbone
        pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name)

        pred_input_names, pred_shape_and_dtypes, pred_name_to_position = reader_helper.merge_input_attrs(
            pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False)
        pred_input_attrs = [[i, j, k] for i, (j, k) in zip(pred_input_names, pred_shape_and_dtypes)]
        self._pred_shape_and_dtypes = pred_shape_and_dtypes
        self._pred_name_to_position = pred_name_to_position

        pred_prog = fluid.Program()
        self._pred_prog = pred_prog
        pred_init_prog = fluid.Program()
        self._pred_init_prog = pred_init_prog
        with fluid.program_guard(pred_prog, pred_init_prog):
            pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs)
            pred_bb_output_vars = pred_backbone.build(pred_net_inputs)
            self._pred_net_inputs = pred_net_inputs

        # prepare predict vars for saving inference model
        with fluid.program_guard(pred_prog, pred_init_prog):
            cur_inputs = helper.decode_inputs(pred_net_inputs, self.name)
            self._pred_input_name_list, self._pred_input_varname_list = \
                zip(*[[k, v.name] for k, v in cur_inputs.items()])

            pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
            scope = self.name + '.'
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(pred_task_inputs, phase='predict', scope=scope)

        # An empty or missing output dict cannot be unpacked through zip(*...),
        # so it falls back to empty lists.
        if output_vars:
            fetch_names, fetch_vars = zip(*output_vars.items())
        else:
            fetch_names, fetch_vars = [], []
        self._pred_fetch_name_list = fetch_names
        # The original wrote `_pred_fetch_list` on success but `_pred_fetch_var_list`
        # on the fallback path (and in __init__); both names are kept in sync here.
        # NOTE(review): confirm which of the two the rest of the class reads.
        self._pred_fetch_list = fetch_vars
        self._pred_fetch_var_list = fetch_vars

        if not self._multi_task:
            self._init_exe_prog(for_train=False)
            self._exe.run(self._pred_init_prog)

        return output_vars

    def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=None):
        """
        Build backward computation graph and training strategy.

        Arguments:
            - optimizer: an optimizer object; this method calls its `_set_prog`, `_build` and
              (when weight decay is enabled) `get_cur_learning_rate` methods.
            - weight_decay: optional, default is None (disable weight decay).
            - use_ema: optional, default is False. The flag to control whether to apply Exponential Moving Average strategy on parameter updates.
            - ema_decay: optional, default is None. Only works with use_ema == True. Control decay rate of EMA strategy.
              When None, the EMA object is created with the framework's own default decay.

        """
        # build optimizer
        # `_loss_var` only exists after build_forward; getattr lets the assertion
        # (rather than an AttributeError) report a missing train graph.
        loss_var = getattr(self, '_loss_var', None)
        assert loss_var is not None and self._train_init_prog is not None, "train graph not foung! You should build_forward first."
        optimizer._set_prog(self._train_prog, self._train_init_prog)
        with fluid.program_guard(self._train_prog, self._train_init_prog):
            param_grads = optimizer._build()

            if weight_decay is not None:

                # copies of every parameter, excluded from gradient computation,
                # used as the decay term below
                param_copies = dict()
                for param in self._train_prog.global_block().all_parameters():
                    param_copies[param.name] = param * 1.0
                    param_copies[param.name].stop_gradient = True

                def exclude_from_weight_decay(name):
                    """Return True for layer-norm and bias parameters (matched by name)."""
                    if name.find("layer_norm") > -1:
                        return True
                    return name.endswith(("_bias", "_b", ".b_0"))

                for param, grad in param_grads:
                    if exclude_from_weight_decay(param.name):
                        continue
                    with param.block.program._optimized_guard(
                        [param, grad]), fluid.framework.name_scope("weight_decay"):
                        updated_param = param - param_copies[
                            param.name] * weight_decay * optimizer.get_cur_learning_rate()
                        fluid.layers.assign(output=param, input=updated_param)

            if use_ema:
                # Passing None as the decay would override the constructor default
                # with an invalid value, so it is only forwarded when given.
                if ema_decay is None:
                    ema = fluid.optimizer.ExponentialMovingAverage()
                else:
                    ema = fluid.optimizer.ExponentialMovingAverage(ema_decay)
                ema.update()

        self._exe.run(self._train_init_prog)

    def set_as_aux(self):
        """Set the task in this trainer as auxilary task. \nCAUSIOUS: This API only works on multi-task learning mode. Each task is set as target task by default. """
        self._as_auxilary = True

    def fit_reader(self, reader, phase='train'):
        """
        Bind a reader and loaded train/predict data to trainer.

        Args:
            reader: a Reader object. The running phase of the reader should be consistent with `phase` argument of this method.
            phase: running phase. Currently support: train, predict.

        """
        self._check_phase(phase)
        # The shape/dtype attributes only exist once the matching graph has been
        # built; getattr makes the assertion fire instead of an AttributeError.
        if phase == 'train':
            assert getattr(self, '_shape_and_dtypes', None) is not None, "You need to build_forward first to prepare input features."
        else:
            assert getattr(self, '_pred_shape_and_dtypes', None) is not None, "You need to build_predict_forward first to prepare input features."

        # Not sure whether the step count should be rounded up here; needs confirmation.
        batch_size = reader._batch_size

        # NOTE(review): stored for both phases, so fitting a predict reader
        # overwrites the value taken from the train reader - confirm intended.
        self._num_epochs = reader.num_epochs
        if phase == 'train':
            self._train_reader = reader
            self._steps_pur_epoch = reader.num_examples // batch_size
            shape_and_dtypes = self._shape_and_dtypes
            name_to_position = self._name_to_position
            if self._task_id is not None:
                self._net_inputs['__task_id'] = self._task_id
            net_inputs = self._net_inputs
            self._train_batch_size = batch_size
            self._num_examples = reader.num_examples
            reader_helper.check_io(self._backbone.inputs_attr, reader.outputs_attr, in_name='backbone', out_name='reader(train)')
            reader_helper.check_io(self._task_head.inputs_attrs['reader'], reader.outputs_attr, in_name='task_head(reader)', out_name='reader(train)')
            reader_helper.check_io(self._task_head.inputs_attrs['backbone'], self._backbone.outputs_attr, in_name='task_head(backbone, train)', out_name='backbone')
        elif phase == 'predict':
            self._predict_reader = reader
            self._pred_steps_pur_epoch = reader.num_examples // batch_size
            shape_and_dtypes = self._pred_shape_and_dtypes
            name_to_position = self._pred_name_to_position
            net_inputs = self._pred_net_inputs
            self._predict_batch_size = batch_size
            self._pred_num_examples = reader.num_examples
            reader_helper.check_io(self._pred_backbone.inputs_attr, reader.outputs_attr, in_name='backbone', out_name='reader(predict)')
            reader_helper.check_io(self._pred_head.inputs_attrs['reader'], reader.outputs_attr, in_name='task_head(reader)', out_name='reader(predict)')
            reader_helper.check_io(self._pred_head.inputs_attrs['backbone'], self._pred_backbone.outputs_attr, in_name='task_head(backbone, predict)', out_name='backbone')
        else:
            raise NotImplementedError()

        print('ok!')

        # merge dataset iterators and create net input vars
        # (the original requested the iterator twice; once is enough)
        iterator = reader._iterator()
        prefix = self.name

        # run runtime checks and adaptation on the yielded data
        iterator_fn = reader_helper.create_iterator_fn(iterator, prefix, shape_and_dtypes, name_to_position, return_type='dict')
        self._raw_iterator_fn = iterator_fn
        feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
        if gpu_dev_count > 1:
            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase)
        else:
            distribute_feeder_fn = iterator_fn()

        if phase == 'train':
            self._train_iterator = distribute_feeder_fn
            self._feed_batch_process_fn = feed_batch_process_fn
        elif phase == 'predict':
            self._predict_iterator = distribute_feeder_fn
            self._pred_feed_batch_process_fn = feed_batch_process_fn

W
wangxiao1021 已提交
433

W
wangxiao1021 已提交
434 435 436 437 438 439 440
    def load_ckpt(self, model_path):
        """
        load training checkpoint for further training or predicting.

        Args:
            model_path: the path of saved checkpoint/parameters.
        """
X
xixiaoyao 已提交
441
        # load pretrain model (or ckpt)
W
wangxiao1021 已提交
442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471
        # assert self._exe is not None, "You need to random_init_params before load checkpoints."
        # if phase == 'train' and not self._train_init:
        #     self._init_exe_prog(for_train=True)
        #     self._exe.run(self._train_init_prog)
        # if phase == 'predict' and not self._predict_init:
        #     self._init_exe_prog(for_train=False)
        #     self._exe.run(self._pred_init_prog)

        assert self._train_init_prog is not None or self._pred_init_prog is not None, "model graph not built. You should at least build_forward or build_predict_forward to load its checkpoint."

        # if phase == 'train':
        #     assert self._train_init_prog is not None, "train graph not found! You should build_forward first before load checkpoint."
        if self._train_init_prog is not None:
            saver.init_pretraining_params(
                self._exe,
                model_path,
                convert=False,
                main_program=self._train_init_prog,
                strict=True)
        # elif phase == 'predict':
        elif self._pred_init_prog is not None:
            # assert self._pred_init_prog is not None, "predict graph not found! You should build_predict_head first before load checkpoint."
            saver.init_pretraining_params(
                self._exe,
                model_path,
                convert=False,
                main_program=self._pred_init_prog,
                strict=True)
        else:
            raise Exception("model not found. You should at least build_forward or build_predict_forward to load its checkpoint.")
W
wangxiao1021 已提交
472

W
wangxiao1021 已提交
473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488
    def load_predict_model(self, model_path, convert=False):
        """
        load pretrain models(backbone) for training.

        Args:
            model_path: the path of saved pretrained parameters.
        """

        assert self._pred_prog is not None, "training graph not found. You should at least build_forward to load its pretrained parameters."

        saver.init_pretraining_params(
            self._exe,
            model_path,
            convert=convert,
            main_program=self._pred_prog)
        # raise NotImplementedError()
W
wangxiao1021 已提交
489 490 491 492 493 494 495 496 497 498 499

    def load_pretrain(self, model_path, convert=False):
        """
        load pretrain models(backbone) for training.

        Args:
            model_path: the path of saved pretrained parameters.
        """
        # load pretrain model (or ckpt)
        # assert self._exe is not None, "You need to random_init_params before load pretrain models."
        assert self._train_init_prog is not None, "training graph not found. You should at least build_forward to load its pretrained parameters."
X
xixiaoyao 已提交
500 501 502 503

        saver.init_pretraining_params(
            self._exe,
            model_path,
W
wangxiao1021 已提交
504
            convert=convert,
X
xixiaoyao 已提交
505
            main_program=self._train_init_prog)
X
xixiaoyao 已提交
506

W
wangxiao1021 已提交
507
    def set_saver(self, save_path, save_steps, save_type='ckpt', is_multi=False):
W
wangxiao1021 已提交
508 509 510 511 512 513 514
        """
        create a build-in saver into trainer. A saver will automatically save checkpoint or predict model every `save_steps` training steps.

        Args:
            save_path: a string. the path to save checkpoints or predict models.
            save_steps: an integer. the frequency to save models.
            save_type: a string. The type of saved model. Currently support checkpoint(ckpt) and predict model(predict), default is ckpt. If both two types are needed to save, you can set as "ckpt,predict".
X
xixiaoyao 已提交
515

W
wangxiao1021 已提交
516
        """
W
wangxiao1021 已提交
517
        
X
xixiaoyao 已提交
518

X
xixiaoyao 已提交
519 520
        save_type = save_type.split(',')
        if 'predict' in save_type:
W
wangxiao1021 已提交
521
            assert self._pred_head is not None, "Predict head not found! You should build_predict_head first if you want to save predict model."
X
xixiaoyao 已提交
522
            assert save_path is not None and save_steps is not None, 'save_path and save_steps is required to save model.'
W
wangxiao1021 已提交
523
            self._save_predict = True
X
xixiaoyao 已提交
524 525 526
            if not os.path.exists(save_path):
                os.makedirs(save_path)
        else:
W
wangxiao1021 已提交
527
            self._save_predict = False
X
xixiaoyao 已提交
528 529 530

        if 'ckpt' in save_type:
            if save_path is not None and save_steps is not None:
W
wangxiao1021 已提交
531
                self._save_ckpt = True
X
xixiaoyao 已提交
532 533 534 535
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
            else:
                "WARNING: save_path or save_steps is not set, model will not be saved during training."
W
wangxiao1021 已提交
536
                self._save_ckpt = False
X
xixiaoyao 已提交
537
        else:
W
wangxiao1021 已提交
538 539 540 541
            self._save_ckpt = False

        def temp_func():
            if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:
W
wangxiao1021 已提交
542

W
wangxiao1021 已提交
543
                if self._save_predict:
W
wangxiao1021 已提交
544 545 546 547 548 549
                    if is_multi:
                        self._save(save_path, suffix='-pred.step'+str(self._cur_train_step))
                        print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
                    else:
                        self._save(save_path, suffix='pred.step'+str(self._cur_train_step))
                        print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
W
wangxiao1021 已提交
550
                if self._save_ckpt:
W
wangxiao1021 已提交
551 552 553 554 555 556
                    if is_multi:
                        fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
                        print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
                    else:
                        fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
                        print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
W
wangxiao1021 已提交
557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572
                return True
            else:
                return False

        self._check_save = temp_func
            
    def train(self, print_steps=5):
        """
        Run the training loop over the trainer's training iterator.

        The training program is compiled for data-parallel execution, then every
        batch is executed through `train_one_step`. Fetched outputs whose key
        starts with "<trainer name>." are passed (with that prefix removed) to the
        task head's `batch_postprocess`.

        The loop ends when the iterator is exhausted, or, when `num_epochs` is
        None and the trainer is not in multi-task mode, as soon as the current
        step count equals the number of steps per epoch.

        Args:
            print_steps: int. Logging frequency of training message, e.g., current
                step, loss and speed. A value <= 0 disables logging.
        """
        batches = self._train_iterator
        compiled = fluid.CompiledProgram(self._train_prog)
        self._distribute_train_prog = compiled.with_data_parallel(loss_name=self._loss_var.name)

        prefix = self.name + '.'
        log_template = "step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s"

        window_start = time.time()
        for batch in batches:
            fetched = self.train_one_step(batch)

            # Keep only this task's outputs and strip the "<name>." prefix.
            head_inputs = {}
            for key, value in fetched.items():
                if key.startswith(prefix):
                    head_inputs[key[len(prefix):]] = value
            self._task_head.batch_postprocess(head_inputs)

            should_log = print_steps > 0 and self._cur_train_step % print_steps == 0
            if should_log:
                mean_loss = np.mean(np.squeeze(fetched[prefix + 'loss'])).tolist()
                elapsed = time.time() - window_start
                # 1-based position of the current step inside its epoch.
                step_in_epoch = (self._cur_train_step - 1) % self._steps_pur_epoch + 1
                print(log_template.format(
                    step_in_epoch, self._steps_pur_epoch, self._cur_train_epoch,
                    mean_loss, print_steps / elapsed))
                # The speed is measured per logging window, so restart the clock.
                window_start = time.time()

            single_task_one_epoch = self._num_epochs is None and not self._multi_task
            if single_task_one_epoch and self._cur_train_step == self._steps_pur_epoch:
                break
    def predict(self, output_dir=None, print_steps=1000):
        """
        Run prediction over the trainer's predict iterator.

        The prediction program is compiled for data-parallel execution, every
        batch is executed through `predict_one_batch`, and its outputs are passed
        to the prediction head's `batch_postprocess`. After the last batch the
        head's `epoch_postprocess` is called and its result returned.

        Args:
            output_dir: str. The path to save prediction results, default is None.
                The directory is created when it does not exist, and the value is
                forwarded to the head's `epoch_postprocess`. What the head does
                with None (the original documentation says results are printed to
                screen) is implemented by the head, not here.
            print_steps: int. Logging frequency of predicting message, e.g., current
                progress and speed. A value <= 0 disables logging.

        Returns:
            Whatever the prediction head's `epoch_postprocess` returns.
        """
        batches = self._predict_iterator
        compiled = fluid.CompiledProgram(self._pred_prog)
        self._distribute_pred_prog = compiled.with_data_parallel()

        needs_dir = output_dir is not None and not os.path.exists(output_dir)
        if needs_dir:
            os.makedirs(output_dir)

        log_template = "batch {}/{}, speed: {:.2f} steps/s"
        window_start = time.time()
        for step, batch in enumerate(batches, 1):
            batch_outputs = self.predict_one_batch(batch)
            self._pred_head.batch_postprocess(batch_outputs)

            if print_steps > 0 and step % print_steps == 0:
                elapsed = time.time() - window_start
                print(log_template.format(
                    step, self._pred_steps_pur_epoch,
                    print_steps / elapsed))
                window_start = time.time()

        # Epoch-level reader outputs are collected only when the head declares
        # epoch inputs.
        reader_outputs = self._predict_reader.get_epoch_outputs() if self._pred_head.epoch_inputs_attrs else None
        return self._pred_head.epoch_postprocess({'reader': reader_outputs}, output_dir=output_dir)

    def _check_phase(self, phase):
        assert phase in ['train', 'predict'], "Supported phase: train, predict,"

    def _set_multitask(self):
        """Flag this trainer as running in multi-task mode.

        `train` reads this flag: when it is set, the loop does not stop early
        after one epoch worth of steps.
        """
        self._multi_task = True

    def _set_task_id(self, task_id):
        """Store `task_id` on this trainer as `_task_id`."""
        self._task_id = task_id

    def _init_exe_prog(self, for_train=True):
        if not self._train_init and not self._predict_init:
            on_gpu = gpu_dev_count > 0
            self._exe = helper.build_executor(on_gpu)

        if for_train:
            assert self._train_prog is not None, "train graph not found! You should build_forward first before you random init parameters."
            self._train_init = True
        else:
            assert self._pred_prog is not None, "predict graph not found! You should build_predict_head first before you random init parameters."
            self._predict_init = True

    # def random_init_params(self):
    #     """
    #     randomly initialize model parameters.
    #     """
    #     
    #     if not self._train_init:
    #         self._init_exe_prog()
    #     
    #     print('random init params...')
    #     self._exe.run(self._train_init_prog)

    def get_one_batch(self, phase='train'):
        self._check_phase(phase)
        if phase == 'train':
            return next(self._train_reader)
        elif phase == 'predict':
            return next(self._predict_reader)
        else:
            raise NotImplementedError()

    def _set_exe(self, exe):
        """Replace the executor used by `train_one_step` and `predict_one_batch`."""
        self._exe = exe

    def _set_dist_train(self, prog):
        """Set the program that `train_one_step` hands to the executor.

        Note that `train` overwrites this attribute with a program it compiles
        itself.
        """
        self._distribute_train_prog = prog

    def _set_fetch_list(self, fetch_list):
        """Set the fetch list that `train_one_step` passes to the executor.

        The fetched values are zipped with `self._fetch_names` there, so the two
        are expected to be in the same order.
        """
        self._fetch_list = fetch_list

    # def train_one_step(self, batch, executor=None, distribute_train_prog=None, fetch_list=None):
    def train_one_step(self, batch):
        """Execute one training step and update the step/epoch counters.

        With more than one GPU device, `batch` is unpacked as a `(feed, mask)`
        pair and `decode_fake` is asked how many trailing entries to drop from
        every fetched output (presumably padding samples added to fill all
        devices -- confirm against `paddlepalm.distribute`). Otherwise `batch`
        is converted to a feed with the trainer's feed-processing function.

        Args:
            batch: one item produced by the training iterator.

        Returns:
            dict mapping each name in `self._fetch_names` to its fetched value.
        """
        executor = self._exe
        program = self._distribute_train_prog
        fetches = self._fetch_list

        multi_device = gpu_dev_count > 1
        if multi_device:
            feed, mask = batch
        else:
            feed = self._feed_batch_process_fn(batch)

        outputs = executor.run(program, feed=feed, fetch_list=fetches)

        if multi_device:
            fake_count = decode_fake(len(outputs[0]), mask, self._train_batch_size)
            if fake_count:
                outputs = [out[:-fake_count] for out in outputs]

        named_outputs = dict(zip(self._fetch_names, outputs))

        # Order matters: the save check sees the incremented step count but the
        # epoch index from before this step.
        self._cur_train_step += 1
        self._check_save()
        self._cur_train_epoch = (self._cur_train_step - 1) // self._steps_pur_epoch
        return named_outputs
    def predict_one_batch(self, batch):
        """Execute the prediction program on one batch.

        With more than one GPU device, `batch` is unpacked as a `(feed, mask)`
        pair and `decode_fake` is asked how many trailing entries to drop from
        every fetched output (presumably padding samples -- confirm against
        `paddlepalm.distribute`). Otherwise `batch` is converted to a feed with
        the trainer's predict feed-processing function.

        Args:
            batch: one item produced by the predict iterator.

        Returns:
            dict mapping each name in `self._pred_fetch_name_list` to its
            fetched value.
        """
        multi_device = gpu_dev_count > 1
        if multi_device:
            feed, mask = batch
        else:
            feed = self._pred_feed_batch_process_fn(batch)

        outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)

        if multi_device:
            fake_count = decode_fake(len(outputs[0]), mask, self._predict_batch_size)
            if fake_count:
                outputs = [out[:-fake_count] for out in outputs]

        return dict(zip(self._pred_fetch_name_list, outputs))



    @property
    def name(self):
        """The name of this trainer (training task)."""
        return self._name
    @property
    def num_examples(self):
        """Value of `_num_examples` (presumably the size of the loaded dataset -- set elsewhere, confirm)."""
        return self._num_examples

    @property
    def mix_ratio(self):
        """Sampling weight of this trainer in multi-task learning mode."""
        return self._mix_ratio

    @mix_ratio.setter
    def mix_ratio(self, value):
        """Set the sampling weight; the value is stored without validation."""
        self._mix_ratio = value

    @property
    def num_epochs(self):
        """Configured number of epochs; `train` treats None specially (single-task training stops after one epoch of steps)."""
        return self._num_epochs

    @property
    def cur_train_step(self):
        """Training step counter; incremented by one on every `train_one_step` call."""
        return self._cur_train_step

    @property
    def cur_train_epoch(self):
        """Zero-based epoch index, computed in `train_one_step` as `(cur_train_step - 1) // steps_pur_epoch`."""
        return self._cur_train_epoch

    @property
    def steps_pur_epoch(self):
        """Number of training steps that make up one epoch, as used by `train` and `train_one_step`."""
        return self._steps_pur_epoch
    def _build_head(self, net_inputs, phase, scope=""):
W
wangxiao1021 已提交
797
        self._check_phase(phase)
X
xixiaoyao 已提交
798 799
        if phase == 'train':
            output_vars = self._task_head.build(net_inputs, scope_name=scope)
W
wangxiao1021 已提交
800
        if phase == 'predict':
X
xixiaoyao 已提交
801
            output_vars = self._pred_head.build(net_inputs, scope_name=scope)
X
xixiaoyao 已提交
802 803
        return output_vars
    
W
wangxiao1021 已提交
804
    def _save(self, save_path, suffix=None):
        """Save the inference model and a `__conf__` JSON file describing it.

        Args:
            save_path: base directory for the saved model.
            suffix: optional sub-directory name joined onto `save_path`; when
                None the model is written directly into `save_path`.

        Side effects:
            - normalizes `self._pred_input_varname_list` to a list of `str`;
            - writes the inference model (a clone of the predict program) via
              `fluid.io.save_inference_model`;
            - writes `<dirpath>/__conf__`, a JSON object with one entry per key
              of `self._save_protocol`.
        """
        dirpath = save_path if suffix is None else os.path.join(save_path, suffix)

        self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]

        prog = self._pred_prog.clone()
        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)

        conf = {}
        for k, strv in self._save_protocol.items():
            # Each protocol value is a Python expression string, evaluated here
            # in this method's scope so it can reference `self`. This must stay
            # a plain loop: a comprehension would hide `self` from `locals()`.
            # NOTE(review): `eval` is only acceptable because the protocol is
            # defined by the class itself; never feed it external input.
            conf[k] = eval(strv, globals(), locals())

        with open(os.path.join(dirpath, '__conf__'), 'w') as writer:
            writer.write(json.dumps(conf, indent=1))
        print(self._name + ': predict model saved at ' + dirpath)
    def _load(self, infer_model_path=None):
        """Restore saved settings and load an inference model from disk.

        Reads `<infer_model_path>/__conf__` (JSON) and, for every entry, assigns
        the stored value to the target named by `self._save_protocol[key]`.
        Then loads the inference model with `fluid.io.load_inference_model`,
        which also replaces `self._pred_input_varname_list` and
        `self._pred_fetch_var_list`.

        Args:
            infer_model_path: directory holding the saved model; defaults to
                `self._save_infermodel_path` when None.

        Returns:
            The loaded prediction program.

        Raises:
            KeyError: if `__conf__` contains a key missing from
                `self._save_protocol`.
        """
        if infer_model_path is None:
            infer_model_path = self._save_infermodel_path

        # Read the config through a context manager so the file handle is
        # closed deterministically instead of being left to the garbage
        # collector.
        with open(os.path.join(infer_model_path, '__conf__')) as reader:
            conf = json.load(reader)

        for k, v in conf.items():
            strv = self._save_protocol[k]
            # The executed statement is "<target>=v": it refers to the local
            # names `v` (and typically `self`), so neither may be renamed.
            exec('{}=v'.format(strv))

        pred_prog, self._pred_input_varname_list, self._pred_fetch_var_list = \
            fluid.io.load_inference_model(infer_model_path, self._exe)
        print(self._name+': inference model loaded from ' + infer_model_path)
        return pred_prog