# -*- coding: utf-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
import json
from paddle import fluid
import time
import sys
import numpy as np
import paddlepalm.utils.basic_helper as helper
from paddlepalm.utils import reader_helper, saver
from paddlepalm.distribute import gpu_dev_count, data_feeder, decode_fake
# from paddlepalm.default_settings import *

DEBUG = False


class Trainer(object):
    """
    The core unit for running a training/predicting session on a single task. A trainer builds the computation graph, manages the training and evaluation process, and handles model/checkpoint saving and pretrained-model/checkpoint loading.
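
    A minimal end-to-end sketch (`ernie`, `cls_head`, `adam` and `train_reader` are assumed to be Backbone, Head, optimizer and Reader objects built elsewhere; names and paths are illustrative):

        trainer = Trainer('senti_cls')
        loss_var = trainer.build_forward(ernie, cls_head)
        trainer.build_backward(optimizer=adam, weight_decay=0.01)
        trainer.fit_reader(train_reader, phase='train')
        trainer.load_pretrain('./pretrain/ernie/params')
        trainer.set_saver(save_path='./outputs', save_steps=1000, save_type='ckpt')
        trainer.train(print_steps=10)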
    """

    def __init__(self, name, mix_ratio=1.0, reuse_head_with=None):
        """Create a new trainer.

        Args:
            name: string. The name of the trainer (training task).
            mix_ratio: sampling weight of this trainer in multi-task learning mode. Default is 1.0.
            reuse_head_with: the name of another trainer whose task head parameters should be reused. Default is None (do not reuse).

        """

        self._name = name
        self._pred_reader = None
        self._task_head = None
        self._pred_head = None
      
        self._train_reader = None
        self._dist_train_init = False
        self._predict_reader = None
        self._train_iterator = None
        self._predict_iterator = None

        self._train_init = False
        self._predict_init = False
        self._train_init_prog = None
        self._pred_init_prog = None

        self._check_save = lambda: False

        self._task_reuse_scope = name if reuse_head_with is None else reuse_head_with

        self._feeded_var_names = None
        self._target_vars = None
        self._predict_vars = None

        self._num_examples = 0

        self._multi_task = False
        self._as_auxilary = False
        self._task_id = None

        # training process management
        self._mix_ratio = mix_ratio
        self._expected_train_steps = None
        self._expected_train_epochs = None
        self._steps_pur_epoch = None
        self._pred_steps_pur_epoch = None
        self._cur_train_epoch = 0
        self._cur_train_step = 0
        self._train_finish = False

        self._inputname_to_varname = {}
        self._pred_input_name_list = []
        self._pred_input_varname_list = []
        self._pred_fetch_name_list = []
        self._pred_fetch_var_list = []

        # exe is built when _init_exe_prog is called.
        self._exe = None

        self._save_protocol = {
            'input_names': 'self._pred_input_name_list',
            'input_varnames': 'self._pred_input_varname_list',
            'fetch_list': 'self._pred_fetch_name_list'}

        self._lock = False
        self._lock_prog = False
        self._build_forward = False

    def build_forward(self, backbone, task_head):
        """
        Build the forward computation graph for training, which is usually built from the input layer to the loss node.

        Args:
            backbone: a Backbone object with phase == 'train', which is used to extract multi-level text features, e.g., contextual word embeddings and sentence embeddings.
            task_head: a Head object with phase == 'train', which is used to build task-specific output layers.
        
        Return:
            loss_var: a Variable object. The computation graph variable (node) of the loss.
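
        Example (a sketch; `ernie` and `cls_head` are assumed to be Backbone and Head objects created elsewhere with phase == 'train'):

            loss_var = trainer.build_forward(ernie, cls_head)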
        """


        self._task_head = task_head
        self._backbone = backbone

        self._build_forward = True
        
        # create reader, task
        # then check i/o across reader, backbone and task_layer
        task_attrs = []
        pred_task_attrs = []

        task_attr_from_reader = helper.encode_inputs(self._task_head.inputs_attrs['reader'], self.name)

        # merge reader input attrs from backbone and task_instances
        input_names, shape_and_dtypes, name_to_position = reader_helper.merge_input_attrs(backbone.inputs_attr, task_attr_from_reader, insert_taskid=False)
        # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN]
        self._shape_and_dtypes = shape_and_dtypes
        self._name_to_position = name_to_position
        self._input_names = input_names

        if DEBUG:
            print('----- for debug -----')
            print('input names:')
            print(input_names)
            print('input shapes and dtypes:')
            print(shape_and_dtypes)

        input_attrs = [[i, j, k] for i, (j,k) in zip(input_names, shape_and_dtypes)]

        train_prog = fluid.Program()
        train_init_prog = fluid.Program()

        if not self._lock_prog:
            self._train_prog = train_prog
            self._train_init_prog = train_init_prog

            with fluid.program_guard(train_prog, train_init_prog):
                net_inputs = reader_helper.create_net_inputs(input_attrs, is_async=False)
                bb_output_vars = backbone.build(net_inputs)
        else:
            net_inputs = reader_helper.create_net_inputs(input_attrs, is_async=False)
            bb_output_vars = backbone.build(net_inputs)
        self._net_inputs = net_inputs
        assert sorted(bb_output_vars.keys()) == sorted(backbone.outputs_attr.keys())

        task_output_vars = {}
        task_inputs = {'backbone': bb_output_vars}
        task_inputs_from_reader = helper.decode_inputs(net_inputs, self.name)
        task_inputs['reader'] = task_inputs_from_reader

        scope = self.name+'.'
        if not self._lock_prog:
            with fluid.program_guard(train_prog, train_init_prog):
                with fluid.unique_name.guard(scope):
                    output_vars = self._build_head(task_inputs, phase='train', scope=scope)
        else:
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(task_inputs, phase='train', scope=scope)

        output_vars = {self.name+'.'+key: val for key, val in output_vars.items()}
        old = len(task_output_vars) # for debug
        task_output_vars.update(output_vars)
        assert len(task_output_vars) - old == len(output_vars) # for debug

        bb_fetches = {k: v.name for k,v in bb_output_vars.items()}
        task_fetches = {k: v.name for k,v in task_output_vars.items()}
        self._fetches = task_fetches
        self._fetch_names, self._fetch_list = zip(*self._fetches.items())
        if not self._lock_prog:
            with fluid.program_guard(train_prog, train_init_prog):
                loss_var = fluid.layers.reduce_sum(task_output_vars[self.name+'.loss'])
        else:
            loss_var = fluid.layers.reduce_sum(task_output_vars[self.name+'.loss'])

        self._loss_var = loss_var

        if not self._multi_task:
            self._init_exe_prog(for_train=True)

        return loss_var

    def build_predict_forward(self, pred_backbone, pred_head):
        """
        Build the computation graph for evaluation and prediction.

        Arguments:
            - pred_backbone: a Backbone object with phase == 'predict'. For evaluating the model during training, the predict backbone should be consistent with the train backbone.
            - pred_head: a Head object with phase == 'predict'. For evaluating the model during training, the predict head should be consistent with the train head.

        Return:
            - output_vars: dict type. Each value is a computation graph variable (node) declared by the outputs_attr of pred_head.
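
        Example (a sketch; `pred_ernie` and `cls_pred_head` are assumed to be Backbone and Head objects created elsewhere with phase == 'predict'):

            output_vars = trainer.build_predict_forward(pred_ernie, cls_pred_head)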
        """
        self._pred_head = pred_head
        self._pred_backbone = pred_backbone
        pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name)

        pred_input_names, pred_shape_and_dtypes, pred_name_to_position = reader_helper.merge_input_attrs(pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False)
        pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_input_names, pred_shape_and_dtypes)]
        self._pred_shape_and_dtypes = pred_shape_and_dtypes
        self._pred_name_to_position = pred_name_to_position
        self._pred_input_names = pred_input_names

        if not self._lock_prog:
            pred_prog = fluid.Program()
            self._pred_prog = pred_prog
            pred_init_prog = fluid.Program()
            self._pred_init_prog = pred_init_prog

            with fluid.program_guard(pred_prog, pred_init_prog):
                pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs)
                pred_bb_output_vars = pred_backbone.build(pred_net_inputs)
                self._pred_net_inputs = pred_net_inputs
        else:
            pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs)
            pred_bb_output_vars = pred_backbone.build(pred_net_inputs)
            self._pred_net_inputs = pred_net_inputs

        # prepare predict vars for saving inference model
        if not self._lock_prog:
            with fluid.program_guard(pred_prog, pred_init_prog):
                cur_inputs = helper.decode_inputs(pred_net_inputs, self.name)
                self._pred_input_name_list, self._pred_input_varname_list = \
                    zip(*[[k, v.name] for k,v in cur_inputs.items()])

                pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
                scope = self.name + '.'
                with fluid.unique_name.guard(scope):
                    output_vars = self._build_head(pred_task_inputs, phase='predict', scope=scope)
        else:
            cur_inputs = helper.decode_inputs(pred_net_inputs, self.name)
            self._pred_input_name_list, self._pred_input_varname_list = \
                zip(*[[k, v.name] for k,v in cur_inputs.items()])

            pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
            scope = self.name + '.'
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(pred_task_inputs, phase='predict', scope=scope)

        if output_vars is not None:
            self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items())
            self._pred_fetch_list = list(self._pred_fetch_var_list)
        else:
            self._pred_fetch_name_list = []
            self._pred_fetch_var_list = []
            self._pred_fetch_list = []

        # if not self._multi_task:
        self._init_exe_prog(for_train=False)
        self._exe.run(self._pred_init_prog)

        self._predict_vars = output_vars
            
        return output_vars

    def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=None):
        """
        Build the backward computation graph and training strategy.

        Arguments:
            - optimizer: the optimizer object used to minimize the loss variable built by build_forward.
            - weight_decay: optional, default is None (weight decay disabled).
            - use_ema: optional, default is False. Whether to apply the Exponential Moving Average strategy on parameter updates.
            - ema_decay: optional, default is None. Only works with use_ema == True. Controls the decay rate of the EMA strategy.
        """
        # build optimizer
        assert self._loss_var is not None and self._train_init_prog is not None, "train graph not found! You should build_forward first."
        optimizer._set_prog(self._train_prog, self._train_init_prog)
        with fluid.program_guard(self._train_prog, self._train_init_prog):
            param_grads = optimizer._build()

            if weight_decay is not None:
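                # Decoupled weight decay: cache a stop-gradient copy of every parameter,
                # then after the optimizer step update each decayed parameter as
                # param -= learning_rate * weight_decay * param_copy, skipping LayerNorm
                # and bias parameters (see exclude_from_weight_decay below).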

                param_list = dict()

                for param in self._train_prog.global_block().all_parameters():
                    param_list[param.name] = param * 1.0
                    param_list[param.name].stop_gradient = True

                def exclude_from_weight_decay(name):
                    if name.find("layer_norm") > -1:
                        return True
                    bias_suffix = ["_bias", "_b", ".b_0"]
                    for suffix in bias_suffix:
                        if name.endswith(suffix):
                            return True
                    return False

                for param, grad in param_grads:
                    if exclude_from_weight_decay(param.name):
                        continue
                    with param.block.program._optimized_guard(
                        [param, grad]), fluid.framework.name_scope("weight_decay"):
                        updated_param = param - param_list[
                            param.name] * weight_decay * optimizer.get_cur_learning_rate()
                        fluid.layers.assign(output=param, input=updated_param)

            if use_ema:
                ema = fluid.optimizer.ExponentialMovingAverage(ema_decay)
                ema.update()

        self._exe.run(self._train_init_prog)

    def set_as_aux(self):
        """Set the task in this trainer as an auxiliary task.

        CAUTION: This API only works in multi-task learning mode. Each task is set as a target task by default.
        """
        self._as_auxilary = True

    def fit_reader(self, reader, phase='train'):
        """
        Bind a reader and its loaded train/predict data to the trainer.

        Args:
            reader: a Reader object. The running phase of the reader should be consistent with the `phase` argument of this method.
            phase: running phase. Currently supported: train, predict.
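
        Example (a sketch; `cls_reader` is assumed to be a paddlepalm Reader whose data has been loaded beforehand):

            trainer.fit_reader(cls_reader, phase='train')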

        """

        self._check_phase(phase)
        if phase == 'train':
            assert self._shape_and_dtypes is not None, "You need to build_forward first to prepare input features."
        else:
            assert self._pred_shape_and_dtypes is not None, "You need to build_predict_forward first to prepare input features."

        batch_size = reader._batch_size

        self._num_epochs = reader.num_epochs
        if phase == 'train':
            self._train_reader = reader
            self._steps_pur_epoch = reader.num_examples // batch_size
            shape_and_dtypes = self._shape_and_dtypes
            name_to_position = self._name_to_position
            if self._task_id is not None:
                self._net_inputs['__task_id'] = self._task_id
            net_inputs = self._net_inputs
            self._train_batch_size = batch_size
            self._num_examples = reader.num_examples
            reader_helper.check_io(self._backbone.inputs_attr, reader.outputs_attr, in_name='backbone', out_name='reader(train)')
            reader_helper.check_io(self._task_head.inputs_attrs['reader'], reader.outputs_attr, in_name='task_head(reader)', out_name='reader(train)')
            reader_helper.check_io(self._task_head.inputs_attrs['backbone'], self._backbone.outputs_attr, in_name='task_head(backbone, train)', out_name='backbone')
        elif phase == 'predict':
            self._predict_reader = reader
            self._pred_steps_pur_epoch = reader.num_examples // batch_size 
            shape_and_dtypes = self._pred_shape_and_dtypes
            name_to_position = self._pred_name_to_position
            net_inputs = self._pred_net_inputs
            self._predict_batch_size = batch_size
            self._pred_num_examples = reader.num_examples
            reader_helper.check_io(self._pred_backbone.inputs_attr, reader.outputs_attr, in_name='backbone', out_name='reader(predict)')
            reader_helper.check_io(self._pred_head.inputs_attrs['reader'], reader.outputs_attr, in_name='task_head(reader)', out_name='reader(predict)')
            reader_helper.check_io(self._pred_head.inputs_attrs['backbone'], self._pred_backbone.outputs_attr, in_name='task_head(backbone, predict)', out_name='backbone')
        else:
            raise NotImplementedError()
            
        print('ok!')

        # merge dataset iterators and create net input vars
        iterator = reader._iterator()
        prefix = self.name

        # runtime check and adaptation of the data yielded by the reader
        iterator_fn = reader_helper.create_iterator_fn(iterator, prefix, shape_and_dtypes, name_to_position, return_type='dict')
        self._raw_iterator_fn = iterator_fn
        feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
        if gpu_dev_count > 1:
            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn, phase=phase)
        else:
            distribute_feeder_fn = iterator_fn()

        if phase == 'train':
            self._train_iterator = distribute_feeder_fn
            self._feed_batch_process_fn = feed_batch_process_fn
        elif phase == 'predict':
            self._predict_iterator = distribute_feeder_fn
            self._pred_feed_batch_process_fn = feed_batch_process_fn
        return distribute_feeder_fn

    def load_ckpt(self, model_path):
        """
        Load a training checkpoint for further training or predicting.

        Args:
            model_path: the path of saved checkpoint/parameters.
        """
        assert self._train_init_prog is not None or self._pred_init_prog is not None, "model graph not built. You should at least build_forward or build_predict_forward to load its checkpoint."

        # if self._train_init_prog is not None:
        #     saver.init_pretraining_params(
        #         self._exe,
        #         model_path,
        #         convert=False,
        #         main_program=self._train_init_prog,
        #         strict=True)
        # elif self._pred_init_prog is not None:
        #     saver.init_pretraining_params(
        #         self._exe,
        #         model_path,
        #         convert=False,
        #         main_program=self._pred_init_prog,
        #         strict=True)
        if self._train_init_prog is not None:
            print('loading checkpoint into train program')
            saver.init_checkpoint(
                self._exe,
                model_path,
                main_program=self._train_init_prog)
        elif self._pred_init_prog is not None:
            saver.init_checkpoint(
                self._exe,
                model_path,
                main_program=self._pred_init_prog)
        else:
            raise Exception("model not found. You should at least build_forward or build_predict_forward to load its checkpoint.")

    def load_predict_model(self, model_path, convert=False):
        """
        Load saved model parameters for prediction.

        Args:
            model_path: the path of the saved parameters.
        """

        assert self._pred_prog is not None, "predict graph not found. You should at least build_predict_forward to load parameters for prediction."

        saver.init_pretraining_params(
            self._exe,
            model_path,
            convert=convert,
            main_program=self._pred_prog)

    def load_pretrain(self, model_path, convert=False):
        """
        Load a pretrained model (backbone) for training.

        Args:
            model_path: the path of saved pretrained parameters.
        """
        assert self._train_init_prog is not None, "training graph not found. You should at least build_forward to load its pretrained parameters."

        saver.init_pretraining_params(
            self._exe,
            model_path,
            convert=convert,
            main_program=self._train_init_prog)

    def set_saver(self, save_path, save_steps, save_type='ckpt'):
        """
        Create a built-in saver for the trainer. The saver automatically saves a checkpoint or predict model every `save_steps` training steps.

        Args:
            save_path: a string. The path to save checkpoints or predict models.
            save_steps: an integer. The frequency of saving models.
            save_type: a string. The type of saved model. Currently supported: checkpoint (ckpt) and predict model (predict); default is ckpt. If both types are needed, set it to "ckpt,predict".
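
        Example (a sketch; the paths are illustrative):

            trainer.set_saver(save_path='./outputs', save_steps=1000, save_type='ckpt,predict')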

        """

        save_type = save_type.split(',')
        if 'predict' in save_type:
            assert self._pred_head is not None, "Predict head not found! You should build_predict_forward first if you want to save predict model."
            assert save_path is not None and save_steps is not None, 'save_path and save_steps are required to save model.'
            self._save_predict = True
            if not os.path.exists(save_path):
                os.makedirs(save_path)
        else:
W
wangxiao1021 已提交
483
            self._save_predict = False

        if 'ckpt' in save_type:
            if save_path is not None and save_steps is not None:
W
wangxiao1021 已提交
487
                self._save_ckpt = True
X
xixiaoyao 已提交
488 489 490 491
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
            else:
                "WARNING: save_path or save_steps is not set, model will not be saved during training."
                self._save_ckpt = False
        else:
            self._save_ckpt = False

        def temp_func():
            if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:

                if self._save_predict:
                    self._save(save_path, suffix='pred.step'+str(self._cur_train_step))
                    print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
                    sys.stdout.flush()
                if self._save_ckpt:
                    fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
                    print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
                    sys.stdout.flush()
                return True
            else:
                return False

        self._check_save = temp_func
            
    def train(self, print_steps=5):
        """
        Start training.

        Args:
            print_steps: int. Logging frequency of training messages, e.g., current step, loss and speed.
        """
        
        iterator = self._train_iterator
        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)

        time_begin = time.time()
        for feed in iterator:
            rt_outputs = self.train_one_step(feed)

            task_rt_outputs = {k[len(self.name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self.name+'.')}
            self._task_head.batch_postprocess(task_rt_outputs)


            if print_steps > 0 and self._cur_train_step % print_steps == 0:
                loss = rt_outputs[self.name+'.loss']
                loss = np.mean(np.squeeze(loss)).tolist()

                time_end = time.time()
                time_cost = time_end - time_begin

                print("step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format(
                       (self._cur_train_step-1) % self._steps_pur_epoch + 1 , self._steps_pur_epoch, self._cur_train_epoch,
                       loss, print_steps / time_cost))
                sys.stdout.flush()
                time_begin = time.time() 

            if self._num_epochs is None and not self._multi_task and self._cur_train_step == self._steps_pur_epoch:
                break
        
    def predict(self, output_dir=None, print_steps=1000):
        """
        Start predicting.

        Args:
            output_dir: str. The path to save prediction results, default is None. If None, the results are printed to the screen directly.
            print_steps: int. Logging frequency of predicting messages, e.g., current progress and speed.
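
        Example (a sketch; assumes the predict graph has been built and a predict-phase reader fitted beforehand):

            trainer.build_predict_forward(pred_ernie, cls_pred_head)
            trainer.fit_reader(predict_reader, phase='predict')
            results = trainer.predict(output_dir='./outputs/predict')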
        """
        iterator = self._predict_iterator
        self._distribute_pred_prog = fluid.CompiledProgram(self._pred_prog).with_data_parallel()


        if output_dir is not None and not os.path.exists(output_dir):
            os.makedirs(output_dir)

        time_begin = time.time()
        
        cur_predict_step = 0
        for feed in iterator:
            rt_outputs = self.predict_one_batch(feed)
            self._pred_head.batch_postprocess(rt_outputs)

            cur_predict_step += 1

            if print_steps > 0 and cur_predict_step % print_steps == 0:
                time_end = time.time()
                time_cost = time_end - time_begin

                print("batch {}/{}, speed: {:.2f} steps/s".format(
                       cur_predict_step, self._pred_steps_pur_epoch,
                       print_steps / time_cost))
                sys.stdout.flush()
                time_begin = time.time()

        if self._pred_head.epoch_inputs_attrs:
            reader_outputs = self._predict_reader.get_epoch_outputs()
        else:
            reader_outputs = None

        results = self._pred_head.epoch_postprocess({'reader':reader_outputs}, output_dir=output_dir)
        return results

    def _check_phase(self, phase):
        assert phase in ['train', 'predict'], "Supported phases: train, predict."

    def _set_multitask(self):
        self._multi_task = True

    def _set_nomultitask(self):
        self._multi_task = False

    def _set_task_id(self, task_id):
        self._task_id = task_id

    def _init_exe_prog(self, for_train=True):
        if not self._train_init and not self._predict_init:
            on_gpu = gpu_dev_count > 0
            self._exe = helper.build_executor(on_gpu)

        if for_train:
            assert self._train_prog is not None, "train graph not found! You should build_forward first."
            self._train_init = True
        else:
            assert self._pred_prog is not None, "predict graph not found! You should build_predict_forward first."
            self._predict_init = True

    # def random_init_params(self):
    #     """
    #     randomly initialize model parameters.
    #     """
    #     
    #     if not self._train_init:
    #         self._init_exe_prog()
    #     
    #     print('random init params...')
    #     self._exe.run(self._train_init_prog)

    def get_one_batch(self, phase='train'):
        self._check_phase(phase)
        if phase == 'train':
            return next(self._train_reader)
        elif phase == 'predict':
            return next(self._predict_reader)
        else:
            raise NotImplementedError()

    def _set_exe(self, exe):
        self._exe = exe

    def _set_dist_train(self, prog):
        self._distribute_train_prog = prog

    def _set_dist_pred(self, prog):
        self._distribute_pred_prog = prog

    def _set_fetch_list(self, fetch_list):
        self._fetch_list = fetch_list

    def train_one_step(self, batch):

        if not self._dist_train_init:
            self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)
            self._dist_train_init = True

        exe = self._exe
        distribute_train_prog = self._distribute_train_prog
        fetch_list = self._fetch_list

        if gpu_dev_count > 1:
            feed, mask = batch
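            # With multiple GPU cards, the feeder may pad the last batch with fake
            # examples so that every card receives a full batch; decode_fake computes
            # how many padded rows to strip from the fetched results.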
            rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=fetch_list)
            num_fakes = decode_fake(len(rt_outputs[0]), mask, self._train_batch_size)
            if num_fakes:
                rt_outputs = [i[:-num_fakes] for i in rt_outputs]
        
        else:
            feed = self._feed_batch_process_fn(batch)
            rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=fetch_list)

        rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)}
        self._cur_train_step += 1
        self._check_save()
        self._cur_train_epoch = (self._cur_train_step-1) // self._steps_pur_epoch
        return rt_outputs

    def predict_one_batch(self, batch):
        if gpu_dev_count > 1:
            feed, mask = batch
            rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list, use_prune=True)
            num_fakes = decode_fake(len(rt_outputs[0]), mask, self._predict_batch_size)
            if num_fakes:
                rt_outputs = [i[:-num_fakes] for i in rt_outputs]
        else:
            feed = self._pred_feed_batch_process_fn(batch)
            rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list, use_prune=True)

        rt_outputs = {k:v for k,v in zip(self._pred_fetch_name_list, rt_outputs)}
        return rt_outputs

    @property
    def name(self):
        return self._name
    
    @property
    def num_examples(self):
        return self._num_examples

    @property
    def mix_ratio(self):
        return self._mix_ratio

    @mix_ratio.setter
    def mix_ratio(self, value):
        self._mix_ratio = value

    @property
    def num_epochs(self):
        return self._num_epochs

    @property
    def cur_train_step(self):
        return self._cur_train_step

    @property
    def cur_train_epoch(self):
        return self._cur_train_epoch

    @property
    def steps_pur_epoch(self):
        return self._steps_pur_epoch

    def _build_head(self, net_inputs, phase, scope=""):
        self._check_phase(phase)
        if phase == 'train':
            output_vars = self._task_head.build(net_inputs, scope_name=scope)
        if phase == 'predict':
            output_vars = self._pred_head.build(net_inputs, scope_name=scope)
        return output_vars
    
    def _save(self, save_path, suffix=None):
        # dirpath = save_path.rstrip('/').rstrip('\\') + suffix
        if suffix is not None:
            dirpath = os.path.join(save_path, suffix)
        else:
            dirpath = save_path
        self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]

        prog = self._pred_prog.clone()
        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)

        conf = {}
        for k, strv in self._save_protocol.items(): 
            d = None
            v = locals()
            exec('d={}'.format(strv), globals(), v)
            conf[k] = v['d']
        with open(os.path.join(dirpath, '__conf__'), 'w') as writer:
            writer.write(json.dumps(conf, indent=1))
        print(self._name + ': predict model saved at ' + dirpath)
        sys.stdout.flush()

    
    def _load(self, infer_model_path=None):
        if infer_model_path is None:
            infer_model_path = self._save_infermodel_path
        for k,v in json.load(open(os.path.join(infer_model_path, '__conf__'))).items(): 
            strv = self._save_protocol[k]
            exec('{}=v'.format(strv))
        pred_prog, self._pred_input_varname_list, self._pred_fetch_var_list = \
            fluid.io.load_inference_model(infer_model_path, self._exe)
        print(self._name+': inference model loaded from ' + infer_model_path)
        sys.stdout.flush()
        return pred_prog