trainer.py 28.7 KB
Newer Older
X
xixiaoyao 已提交
1
# -*- coding: utf-8 -*-
X
xixiaoyao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
xixiaoyao 已提交
16
from __future__ import print_function
X
xixiaoyao 已提交
17 18 19
import os
import json
from paddle import fluid
X
xixiaoyao 已提交
20 21
import time
import numpy as np
X
xixiaoyao 已提交
22
import paddlepalm.utils.basic_helper as helper
X
xixiaoyao 已提交
23
from paddlepalm.utils import reader_helper, saver
X
xixiaoyao 已提交
24
from paddlepalm.distribute import gpu_dev_count, data_feeder, decode_fake
X
xixiaoyao 已提交
25
# from paddlepalm.default_settings import *
X
xixiaoyao 已提交
26

X
xixiaoyao 已提交
27
# Module-level switch: when True, Trainer.build_forward prints the merged
# input names and shape/dtype attributes it computed.
DEBUG = False
X
xixiaoyao 已提交
28

X
xixiaoyao 已提交
29

X
xixiaoyao 已提交
30 31
class Trainer(object):

X
xixiaoyao 已提交
32
    def __init__(self, name, mix_ratio=1.0, reuse_head_with=None, \
X
xixiaoyao 已提交
33
                 silent=False):
X
xixiaoyao 已提交
34 35 36

        self._name = name
        self._verbose = not silent
X
xixiaoyao 已提交
37
        self._pred_reader = None
X
xixiaoyao 已提交
38
        self._task_head = None
X
xixiaoyao 已提交
39
        self._pred_head = None
X
xixiaoyao 已提交
40

X
xixiaoyao 已提交
41 42 43
        self._train_init = False
        self._predict_init = False

X
xixiaoyao 已提交
44 45
        self._check_save = lambda: False

X
xixiaoyao 已提交
46 47 48 49 50 51 52 53
        # if save_predict_model:
        #     self._save_predict_model = True
        #     assert pred_head is not None, "pred_head is required to save predict model."
        #     self._pred_reader = reader.clone(phase='pred')
        # else:
        #     assert pred_head is None, "You should set save_predict_model as True, or the pred_head is invalid." 
        #     self._save_predict_model = False
        #     self._pred_reader = None
X
xixiaoyao 已提交
54

X
xixiaoyao 已提交
55
        # self._save_steps = save_steps
X
xixiaoyao 已提交
56

X
xixiaoyao 已提交
57
        self._task_reuse_scope = name if reuse_head_with is None else reuse_head_with
X
xixiaoyao 已提交
58 59 60 61

        self._feeded_var_names = None
        self._target_vars = None

X
xixiaoyao 已提交
62 63
        self._num_examples = 0

X
xixiaoyao 已提交
64 65
        self._multi_task = False

X
xixiaoyao 已提交
66 67 68 69 70
        # training process management
        self._mix_ratio = mix_ratio
        self._expected_train_steps = None
        self._expected_train_epochs = None
        self._steps_pur_epoch = None
X
xixiaoyao 已提交
71
        self._pred_steps_pur_epoch = None
X
xixiaoyao 已提交
72 73 74 75 76 77 78 79 80 81
        self._cur_train_epoch = 0
        self._cur_train_step = 0
        self._train_finish = False

        self._inputname_to_varname = {}
        self._pred_input_name_list = []
        self._pred_input_varname_list = []
        self._pred_fetch_name_list = []
        self._pred_fetch_var_list = []

X
xixiaoyao 已提交
82 83
        # exe is built when random_init_params is called.
        # self._exe = helper.build_executor(gpu_dev_count>0)
X
xixiaoyao 已提交
84
        self._exe = None
X
xixiaoyao 已提交
85 86 87 88 89 90 91

        self._save_protocol = {
            'input_names': 'self._pred_input_name_list',
            'input_varnames': 'self._pred_input_varname_list',
            'fetch_list': 'self._pred_fetch_name_list'}

        self._lock = False
X
xixiaoyao 已提交
92 93
        self._build_forward = False

X
xixiaoyao 已提交
94
    def build_predict_forward(self, pred_backbone, pred_head, pred_prog=None, pred_init_prog=None):
        """Build the prediction network (backbone followed by prediction head).

        Args:
            pred_backbone: backbone exposing ``inputs_attr`` and ``build``.
            pred_head: prediction head exposing ``inputs_attrs['reader']`` and ``build``.
            pred_prog: fluid.Program to build into; a new one is created if None.
            pred_init_prog: startup fluid.Program; a new one is created if None.

        Returns:
            the dict returned by ``pred_head.build`` (may be None or empty).
        """
        self._pred_head = pred_head

        pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name)

        # merge the input attributes required by the backbone and by the head
        pred_input_names, pred_shape_and_dtypes, pred_name_to_position = reader_helper.merge_input_attrs(
            pred_backbone.inputs_attr, pred_task_attr_from_reader,
            insert_taskid=False, insert_batchsize=False, insert_seqlen=False,
            insert_batchsize_x_seqlen=False)
        pred_input_attrs = [[i, j, k] for i, (j, k) in zip(pred_input_names, pred_shape_and_dtypes)]
        self._pred_shape_and_dtypes = pred_shape_and_dtypes
        self._pred_name_to_position = pred_name_to_position

        if pred_prog is None:
            pred_prog = fluid.Program()
        self._pred_prog = pred_prog
        if pred_init_prog is None:
            pred_init_prog = fluid.Program()
        self._pred_init_prog = pred_init_prog

        with fluid.program_guard(pred_prog, pred_init_prog):
            pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs)
            pred_bb_output_vars = pred_backbone.build(pred_net_inputs)
            self._pred_net_inputs = pred_net_inputs

        # prepare predict vars for saving inference model
        with fluid.program_guard(pred_prog, pred_init_prog):
            cur_inputs = helper.decode_inputs(pred_net_inputs, self.name)
            self._pred_input_name_list, self._pred_input_varname_list = \
                zip(*[[k, v.name] for k, v in cur_inputs.items()])

            pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
            scope = self.name + '.'
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(pred_task_inputs, phase='pred', scope=scope)

        if output_vars:
            fetch_names, fetch_vars = zip(*output_vars.items())
            self._pred_fetch_name_list = list(fetch_names)
            # predict_one_batch fetches _pred_fetch_list while save() exports
            # _pred_fetch_var_list, so both must hold the output variables.
            self._pred_fetch_list = list(fetch_vars)
            self._pred_fetch_var_list = list(fetch_vars)
        else:
            self._pred_fetch_name_list = []
            self._pred_fetch_list = []
            self._pred_fetch_var_list = []

        return output_vars
X
xixiaoyao 已提交
139

X
xixiaoyao 已提交
140 141
    def _set_multitask(self):
        self._multi_task = True
X
xixiaoyao 已提交
142

X
xixiaoyao 已提交
143 144
    def build_forward(self, backbone, task_head):
        """Build the training forward graph: backbone, task head and loss.

        Args:
            backbone: backbone exposing ``inputs_attr``, ``outputs_attr`` and ``build``.
            task_head: task head exposing ``inputs_attrs['reader']`` and ``build``;
                its outputs must contain a ``loss`` entry.

        Returns:
            the loss variable, i.e. ``reduce_sum`` of the head's ``loss`` output.
        """
        self._task_head = task_head
        self._build_forward = True

        task_attr_from_reader = helper.encode_inputs(self._task_head.inputs_attrs['reader'], self.name)

        # merge reader input attrs required by the backbone and the task head
        input_names, shape_and_dtypes, name_to_position = reader_helper.merge_input_attrs(
            backbone.inputs_attr, task_attr_from_reader,
            insert_taskid=False, insert_batchsize=False, insert_seqlen=False,
            insert_batchsize_x_seqlen=False)
        self._shape_and_dtypes = shape_and_dtypes
        self._name_to_position = name_to_position

        if DEBUG:
            print('----- for debug -----')
            print('joint input names:')
            print(input_names)
            print('joint input shape and dtypes:')
            print(shape_and_dtypes)

        input_attrs = [[i, j, k] for i, (j, k) in zip(input_names, shape_and_dtypes)]

        train_prog = fluid.Program()
        train_init_prog = fluid.Program()
        self._train_prog = train_prog
        self._train_init_prog = train_init_prog

        # ``async`` is a reserved word since Python 3.7; passing the keyword
        # through a dict keeps this module parseable while sending the same
        # argument to reader_helper.create_net_inputs.
        net_input_kwargs = {'async': False}

        # Single-task mode builds into this trainer's own programs; in
        # multi-task mode (_set_multitask) it builds into the active program.
        if not self._multi_task:
            with fluid.program_guard(train_prog, train_init_prog):
                net_inputs = reader_helper.create_net_inputs(input_attrs, **net_input_kwargs)
                bb_output_vars = backbone.build(net_inputs)
        else:
            net_inputs = reader_helper.create_net_inputs(input_attrs, **net_input_kwargs)
            bb_output_vars = backbone.build(net_inputs)
        self._net_inputs = net_inputs
        assert sorted(bb_output_vars.keys()) == sorted(backbone.outputs_attr.keys())

        task_inputs = {'backbone': bb_output_vars,
                       'reader': helper.decode_inputs(net_inputs, self.name)}

        scope = self.name + '.'
        if not self._multi_task:
            with fluid.program_guard(train_prog, train_init_prog):
                with fluid.unique_name.guard(scope):
                    output_vars = self._build_head(task_inputs, phase='train', scope=scope)
        else:
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(task_inputs, phase='train', scope=scope)

        # Prefix every head output with "<task name>."; train() strips this
        # prefix again before calling batch_postprocess.
        task_output_vars = {scope + key: val for key, val in output_vars.items()}

        self._fetches = {k: v.name for k, v in task_output_vars.items()}
        self._fetch_names, self._fetch_list = zip(*self._fetches.items())

        loss_key = scope + 'loss'
        if not self._multi_task:
            with fluid.program_guard(train_prog, train_init_prog):
                loss_var = fluid.layers.reduce_sum(task_output_vars[loss_key])
        else:
            loss_var = fluid.layers.reduce_sum(task_output_vars[loss_key])

        self._loss_var = loss_var
        return loss_var

    def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=0.9999):
        """Append the optimizer (backward + update) ops to the train program.

        Args:
            optimizer: paddlepalm optimizer exposing ``_set_prog``, ``build``
                (returning ``(param, grad)`` pairs) and ``get_cur_learning_rate``.
            weight_decay: if not None, each parameter is additionally updated as
                ``param - param_copy * weight_decay * lr``, except parameters
                whose name contains "layer_norm" or ends with a bias suffix.
            use_ema: if True, add an ExponentialMovingAverage update; the EMA
                object is kept on ``self._ema`` (None otherwise).
            ema_decay: decay passed to ExponentialMovingAverage.
        """
        assert getattr(self, '_train_init_prog', None) is not None, "train graph not foung! You should build_forward first."
        optimizer._set_prog(self._train_prog, self._train_init_prog)

        ema = None
        with fluid.program_guard(self._train_prog, self._train_init_prog):
            param_grads = optimizer.build()

            if weight_decay is not None:
                # stop-gradient copy of every parameter, used as the decay term
                param_list = dict()
                for param in self._train_prog.global_block().all_parameters():
                    param_list[param.name] = param * 1.0
                    param_list[param.name].stop_gradient = True

                def exclude_from_weight_decay(name):
                    if name.find("layer_norm") > -1:
                        return True
                    return name.endswith(("_bias", "_b", ".b_0"))

                for param, grad in param_grads:
                    if exclude_from_weight_decay(param.name):
                        continue
                    with param.block.program._optimized_guard(
                        [param, grad]), fluid.framework.name_scope("weight_decay"):
                        updated_param = param - param_list[
                            param.name] * weight_decay * optimizer.get_cur_learning_rate()
                        fluid.layers.assign(output=param, input=updated_param)

            if use_ema:
                ema = fluid.optimizer.ExponentialMovingAverage(ema_decay)
                ema.update()

        # Keep a handle on the EMA object so callers can apply/restore the
        # averaged weights later.
        self._ema = ema

X
xixiaoyao 已提交
288
    def fit_reader(self, reader, phase='train'):
        """Attach a data reader to the train or predict network.

        Args:
            reader: reader exposing ``_batch_size``, ``num_epochs``,
                ``num_examples`` and ``_iterator()``.
            phase: 'train' or 'predict'.

        Raises:
            AssertionError: if neither forward graph has been built.
            NotImplementedError: if ``phase`` is neither 'train' nor 'predict'.
        """
        assert getattr(self, '_shape_and_dtypes', None) is not None or \
            getattr(self, '_pred_shape_and_dtypes', None) is not None, \
            "You need to build_forward or build_predict_head first to prepare input features."

        batch_size = reader._batch_size
        if phase == 'train':
            self._num_epochs = reader.num_epochs
            # NOTE: floor division -- whether a trailing partial batch should
            # count as a training step is still unconfirmed.
            self._steps_pur_epoch = reader.num_examples // batch_size
            shape_and_dtypes = self._shape_and_dtypes
            name_to_position = self._name_to_position
            net_inputs = self._net_inputs
            self._train_batch_size = batch_size
            self._num_examples = reader.num_examples
        elif phase == 'predict':
            # Do not overwrite the training epoch budget with the predict
            # reader's value; only fill it in when nothing was set yet.
            if not hasattr(self, '_num_epochs'):
                self._num_epochs = reader.num_epochs
            # A trailing partial batch counts as one more step (ceil division).
            num_full, remainder = divmod(reader.num_examples, batch_size)
            self._pred_steps_pur_epoch = num_full + (1 if remainder > 0 else 0)
            shape_and_dtypes = self._pred_shape_and_dtypes
            name_to_position = self._pred_name_to_position
            net_inputs = self._pred_net_inputs
            self._predict_batch_size = batch_size
            self._pred_num_examples = reader.num_examples
            # predict() reads epoch-level outputs from this reader.
            self._pred_reader = reader
        else:
            raise NotImplementedError()

        # merge dataset iterators and create net input vars
        iterator = reader._iterator()
        prefix = self.name

        # runtime check and adaptation of the data yielded by the reader
        iterator_fn = reader_helper.create_iterator_fn(iterator, prefix, shape_and_dtypes, name_to_position, return_type='dict')
        self._raw_iterator_fn = iterator_fn
        feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
        if gpu_dev_count > 1:
            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn)
        else:
            distribute_feeder_fn = iterator_fn

        if phase == 'train':
            self._train_reader = distribute_feeder_fn()
            self._feed_batch_process_fn = feed_batch_process_fn
        else:
            self._predict_reader = distribute_feeder_fn()
            self._pred_feed_batch_process_fn = feed_batch_process_fn
X
xixiaoyao 已提交
339

X
xixiaoyao 已提交
340
    def _init_exe_prog(self, for_train=True):
X
xixiaoyao 已提交
341 342 343 344 345 346 347 348 349 350
        if not self._train_init and not self._predict_init:
            on_gpu = gpu_dev_count > 0
            self._exe = helper.build_executor(on_gpu)

        if for_train:
            assert self._train_prog is not None, "train graph not foung! You should build_forward first before you random init parameters."
            self._train_init = True
        else:
            assert self._pred_prog is not None, "predict graph not foung! You should build_predict_head first before you random init parameters."
            self._predict_init = True
X
xixiaoyao 已提交
351 352

    def random_init_params(self):
X
xixiaoyao 已提交
353
        
X
xixiaoyao 已提交
354 355 356
        if not self._train_init:
            self._init_exe_prog()
        
X
xixiaoyao 已提交
357 358
        print('random init params...')
        self._exe.run(self._train_init_prog)
X
xixiaoyao 已提交
359

X
xixiaoyao 已提交
360 361
    def load_ckpt(self, model_path, phase='train'):
        # load pretrain model (or ckpt)
X
xixiaoyao 已提交
362 363
        # assert self._exe is not None, "You need to random_init_params before load checkpoints."
        if phase == 'train' and not self._train_init:
X
xixiaoyao 已提交
364
            self._init_exe_prog(for_train=True)
X
xixiaoyao 已提交
365
        if phase == 'predict' and not self._predict_init:
X
xixiaoyao 已提交
366
            self._init_exe_prog(for_train=False)
X
xixiaoyao 已提交
367 368 369 370 371 372

        if phase == 'train':
            assert self._train_init_prog is not None, "train graph not found! You should build_forward first before load checkpoint."
            saver.init_pretraining_params(
                self._exe,
                model_path,
X
xixiaoyao 已提交
373 374
                main_program=self._train_init_prog,
                strict=True)
X
xixiaoyao 已提交
375 376 377 378 379
        elif phase == 'predict':
            assert self._pred_init_prog is not None, "predict graph not found! You should build_predict_head first before load checkpoint."
            saver.init_pretraining_params(
                self._exe,
                model_path,
X
xixiaoyao 已提交
380 381
                main_program=self._pred_init_prog,
                strict=True)
X
xixiaoyao 已提交
382 383 384 385 386 387 388
        else:
            raise NotImplementedError()
            

    def load_predict_model(self, model_path):
        raise NotImplementedError()

X
xixiaoyao 已提交
389
    def load_pretrain(self, model_path, convert=False):
X
xixiaoyao 已提交
390 391 392 393 394 395
        # load pretrain model (or ckpt)
        assert self._exe is not None, "You need to random_init_params before load pretrain models."

        saver.init_pretraining_params(
            self._exe,
            model_path,
X
xixiaoyao 已提交
396
            convert=convert,
X
xixiaoyao 已提交
397
            main_program=self._train_init_prog)
X
xixiaoyao 已提交
398

X
xixiaoyao 已提交
399
    def set_saver(self, save_path, save_steps, save_type='ckpt'):
X
xixiaoyao 已提交
400

X
xixiaoyao 已提交
401 402
        save_type = save_type.split(',')
        if 'predict' in save_type:
X
xixiaoyao 已提交
403
            assert self._pred_head is not None, "Predict head not found! You should build_predict_head first if you want to save predict model."
X
xixiaoyao 已提交
404
            assert save_path is not None and save_steps is not None, 'save_path and save_steps is required to save model.'
X
xixiaoyao 已提交
405
            self._save_predict = True
X
xixiaoyao 已提交
406 407 408
            if not os.path.exists(save_path):
                os.makedirs(save_path)
        else:
X
xixiaoyao 已提交
409
            self._save_predict = False
X
xixiaoyao 已提交
410 411 412

        if 'ckpt' in save_type:
            if save_path is not None and save_steps is not None:
X
xixiaoyao 已提交
413
                self._save_ckpt = True
X
xixiaoyao 已提交
414 415 416 417
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
            else:
                "WARNING: save_path or save_steps is not set, model will not be saved during training."
X
xixiaoyao 已提交
418
                self._save_ckpt = False
X
xixiaoyao 已提交
419
        else:
X
xixiaoyao 已提交
420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444
            self._save_ckpt = False

        def temp_func():
            if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:
                if self._save_predict:
                    self.save(save_path, suffix='pred.step'+str(self._cur_train_step))
                    print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
                if self._save_ckpt:
                    fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
                    print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
                return True
            else:
                return False

        self._check_save = temp_func
            
    def train(self, save_path=None, save_steps=None, save_type='ckpt', print_steps=5):
        """Run the training loop over the reader attached by fit_reader.

        Args:
            save_path: if this or ``save_steps`` is given, all three save
                arguments are forwarded to ``set_saver`` before training;
                otherwise any earlier ``set_saver`` configuration is used.
            save_steps: see ``save_path``.
            save_type: see ``save_path`` ('ckpt' and/or 'predict').
            print_steps: log loss and speed every ``print_steps`` steps
                (values <= 0 disable logging).
        """
        if save_path is not None or save_steps is not None:
            self.set_saver(save_path, save_steps, save_type=save_type)

        iterator = self._train_reader
        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)

        prefix = self.name + '.'
        loss_key = self.name + '.loss'

        time_begin = time.time()
        for feed in iterator:
            rt_outputs = self.train_one_step(feed)

            # strip the "<task name>." prefix before handing outputs to the head
            task_rt_outputs = {k[len(prefix):]: v for k, v in rt_outputs.items() if k.startswith(prefix)}
            self._task_head.batch_postprocess(task_rt_outputs)

            if print_steps > 0 and self._cur_train_step % print_steps == 0:
                loss = np.mean(np.squeeze(rt_outputs[loss_key])).tolist()
                time_cost = time.time() - time_begin
                print("step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format(
                       (self._cur_train_step-1) % self._steps_pur_epoch + 1, self._steps_pur_epoch, self._cur_train_epoch,
                       loss, print_steps / time_cost))
                time_begin = time.time()

            self._check_save()

            # without an epoch budget, stop once one epoch worth of steps ran
            if self._num_epochs is None and self._cur_train_step == self._steps_pur_epoch:
                break

X
xixiaoyao 已提交
508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553
    def get_one_batch(self, phase='train'):
        if phase == 'train':
            return next(self._train_reader)
        elif phase == 'predict':
            return next(self._predict_reader)
        else:
            raise NotImplementedError()
        
    def predict(self, output_dir=None, print_steps=1000):
        """Run prediction over the reader attached by fit_reader(phase='predict').

        Args:
            output_dir: optional directory (created if missing), forwarded to
                the head's ``epoch_postprocess``.
            print_steps: log progress every ``print_steps`` batches
                (values <= 0 disable logging).

        Returns:
            whatever ``self._pred_head.epoch_postprocess`` returns.
        """
        batches = self._predict_reader
        self._distribute_pred_prog = fluid.CompiledProgram(self._pred_prog).with_data_parallel()

        if output_dir is not None and not os.path.exists(output_dir):
            os.makedirs(output_dir)

        tic = time.time()
        for step, batch in enumerate(batches, 1):
            self._pred_head.batch_postprocess(self.predict_one_batch(batch))

            if print_steps > 0 and step % print_steps == 0:
                elapsed = time.time() - tic
                print("batch {}/{}, speed: {:.2f} steps/s".format(
                       step, self._pred_steps_pur_epoch,
                       print_steps / elapsed))
                tic = time.time()

        if self._pred_head.epoch_inputs_attrs:
            epoch_outputs = self._pred_reader.get_epoch_outputs()
        else:
            epoch_outputs = None

        return self._pred_head.epoch_postprocess({'reader': epoch_outputs}, output_dir=output_dir)

X
xixiaoyao 已提交
554 555 556 557
    def train_one_step(self, batch, executor=None, distribute_train_prog=None):
        """Run one training step and advance the step/epoch counters.

        Args:
            batch: a batch from the train reader; with more than one GPU it is
                a ``(feed, mask)`` pair produced by ``data_feeder``.
            executor: executor to run with; defaults to ``self._exe``.
            distribute_train_prog: compiled program to run; defaults to
                ``self._distribute_train_prog`` (built by ``train``).

        Returns:
            dict mapping each fetch name to its fetched value.
        """
        exe = self._exe if executor is None else executor
        distribute_train_prog = self._distribute_train_prog if distribute_train_prog is None else distribute_train_prog

        if gpu_dev_count > 1:
            feed, mask = batch
            rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=self._fetch_list)
            # decode_fake presumably counts padding samples added by
            # data_feeder; drop their trailing outputs. Slicing works for both
            # lists and numpy arrays (arrays have no pop()).
            num_fakes = decode_fake(len(rt_outputs[0]), mask, self._train_batch_size)
            if num_fakes > 0:
                rt_outputs = [item[:len(item) - num_fakes] for item in rt_outputs]
        else:
            feed = self._feed_batch_process_fn(batch)
            rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=self._fetch_list)

        rt_outputs = dict(zip(self._fetch_names, rt_outputs))
        self._cur_train_step += 1
        self._cur_train_epoch = (self._cur_train_step - 1) // self._steps_pur_epoch
        return rt_outputs
X
xixiaoyao 已提交
573

X
xixiaoyao 已提交
574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589
    @property
    def num_epochs(self):
        return self._num_epochs

    @property
    def cur_train_steps(self):
        return self._cur_train_step

    @property
    def cur_train_epoch(self):
        return self._cur_train_epoch

    @property
    def steps_pur_epoch(self):
        return self._steps_pur_epoch

X
xixiaoyao 已提交
590 591 592
    def predict_one_batch(self, batch):
        """Run the predict program on one batch.

        Args:
            batch: a batch from the predict reader; with more than one GPU it
                is a ``(feed, mask)`` pair produced by ``data_feeder``.

        Returns:
            dict mapping each prediction output name to its fetched value.
        """
        if gpu_dev_count > 1:
            feed, mask = batch
            rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)
            # decode_fake presumably counts padding samples added by
            # data_feeder; drop their trailing outputs. Slicing works for both
            # lists and numpy arrays (arrays have no pop()).
            num_fakes = decode_fake(len(rt_outputs[0]), mask, self._predict_batch_size)
            if num_fakes > 0:
                rt_outputs = [item[:len(item) - num_fakes] for item in rt_outputs]
        else:
            feed = self._pred_feed_batch_process_fn(batch)
            rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)

        return dict(zip(self._pred_fetch_name_list, rt_outputs))
X
xixiaoyao 已提交
604

X
xixiaoyao 已提交
605 606 607
    def _build_head(self, net_inputs, phase, scope=""):
        if phase == 'train':
            output_vars = self._task_head.build(net_inputs, scope_name=scope)
X
xixiaoyao 已提交
608
        if phase == 'pred':
X
xixiaoyao 已提交
609
            output_vars = self._pred_head.build(net_inputs, scope_name=scope)
X
xixiaoyao 已提交
610 611
        return output_vars
    
X
xixiaoyao 已提交
612 613 614 615 616 617
    def save(self, save_path, suffix=None):
        """Export the predict program as an inference model plus a ``__conf__`` file.

        Args:
            save_path: base directory.
            suffix: optional sub-directory name joined onto ``save_path``.
        """
        dirpath = save_path if suffix is None else os.path.join(save_path, suffix)

        self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]

        # build_predict_forward may have filled only _pred_fetch_list; fall
        # back to it so the exported model has its output targets.
        fetch_vars = list(self._pred_fetch_var_list) or list(getattr(self, '_pred_fetch_list', []))

        prog = self._pred_prog.clone()
        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, fetch_vars, self._exe, prog)

        # Each protocol value is an attribute path of the form "self.<attr>";
        # resolve it with getattr instead of exec.
        conf = {}
        for key, attr_path in self._save_protocol.items():
            attr = attr_path[len('self.'):] if attr_path.startswith('self.') else attr_path
            conf[key] = getattr(self, attr)
        with open(os.path.join(dirpath, '__conf__'), 'w') as writer:
            writer.write(json.dumps(conf, indent=1))
        print(self._name + ': predict model saved at ' + dirpath)
X
xixiaoyao 已提交
632

X
xixiaoyao 已提交
633
    
X
xixiaoyao 已提交
634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649
    def _load(self, infer_model_path=None):
        """Load an inference model written by ``save``.

        Restores the attributes named in ``self._save_protocol`` from the
        model's ``__conf__`` file, then loads the program via fluid.

        Args:
            infer_model_path: model directory; defaults to
                ``self._save_infermodel_path``.

        Returns:
            the loaded inference program.
        """
        if infer_model_path is None:
            infer_model_path = self._save_infermodel_path

        # close the conf file deterministically (it was previously leaked)
        with open(os.path.join(infer_model_path, '__conf__')) as reader:
            conf = json.load(reader)
        for key, value in conf.items():
            attr_path = self._save_protocol[key]
            attr = attr_path[len('self.'):] if attr_path.startswith('self.') else attr_path
            setattr(self, attr, value)

        pred_prog, self._pred_input_varname_list, self._pred_fetch_var_list = \
            fluid.io.load_inference_model(infer_model_path, self._exe)
        print(self._name+': inference model loaded from ' + infer_model_path)
        return pred_prog

    @property
    def name(self):
        return self._name

    @property
X
xixiaoyao 已提交
650 651
    def num_examples(self):
        return self._num_examples
X
xixiaoyao 已提交
652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668

    @property
    def mix_ratio(self):
        if self._mix_ratio is not None:
            return self._mix_ratio
        else:
            raise ValueError("{}: mix_ratio is None".format(self._name))

    @mix_ratio.setter
    def mix_ratio(self, value):
        self._mix_ratio = float(value)
        if self._verbose:
            print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio))

    def _set_lock(self):
        self._lock = True