trainer.py 28.6 KB
Newer Older
X
xixiaoyao 已提交
1
# -*- coding: utf-8 -*-
X
xixiaoyao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
xixiaoyao 已提交
16
from __future__ import print_function
X
xixiaoyao 已提交
17 18 19
import os
import json
from paddle import fluid
X
xixiaoyao 已提交
20 21
import time
import numpy as np
X
xixiaoyao 已提交
22
import paddlepalm.utils.basic_helper as helper
X
xixiaoyao 已提交
23
from paddlepalm.utils import reader_helper, saver
X
xixiaoyao 已提交
24
from paddlepalm.distribute import gpu_dev_count, data_feeder, decode_fake
X
xixiaoyao 已提交
25
# from paddlepalm.default_settings import *
X
xixiaoyao 已提交
26

X
xixiaoyao 已提交
27
# Module-level debug switch: when True, Trainer.build_forward prints the
# merged input names and their shape/dtype attributes.
DEBUG = False
X
xixiaoyao 已提交
28

X
xixiaoyao 已提交
29

X
xixiaoyao 已提交
30 31
class Trainer(object):

X
xixiaoyao 已提交
32
    def __init__(self, name, mix_ratio=1.0, reuse_head_with=None, \
X
xixiaoyao 已提交
33
                 silent=False):
X
xixiaoyao 已提交
34 35 36

        self._name = name
        self._verbose = not silent
X
xixiaoyao 已提交
37
        self._pred_reader = None
X
xixiaoyao 已提交
38
        self._task_head = None
X
xixiaoyao 已提交
39
        self._pred_head = None
X
xixiaoyao 已提交
40

X
xixiaoyao 已提交
41 42 43
        self._train_init = False
        self._predict_init = False

X
xixiaoyao 已提交
44
        nelf._check_save = lambda: False
X
xixiaoyao 已提交
45

X
xixiaoyao 已提交
46 47 48 49 50 51 52 53
        # if save_predict_model:
        #     self._save_predict_model = True
        #     assert pred_head is not None, "pred_head is required to save predict model."
        #     self._pred_reader = reader.clone(phase='pred')
        # else:
        #     assert pred_head is None, "You should set save_predict_model as True, or the pred_head is invalid." 
        #     self._save_predict_model = False
        #     self._pred_reader = None
X
xixiaoyao 已提交
54

X
xixiaoyao 已提交
55
        # self._save_steps = save_steps
X
xixiaoyao 已提交
56

X
xixiaoyao 已提交
57
        self._task_reuse_scope = name if reuse_head_with is None else reuse_head_with
X
xixiaoyao 已提交
58 59 60 61

        self._feeded_var_names = None
        self._target_vars = None

X
xixiaoyao 已提交
62 63
        self._num_examples = 0

X
xixiaoyao 已提交
64 65
        self._multi_task = False

X
xixiaoyao 已提交
66 67 68 69 70
        # training process management
        self._mix_ratio = mix_ratio
        self._expected_train_steps = None
        self._expected_train_epochs = None
        self._steps_pur_epoch = None
X
xixiaoyao 已提交
71
        self._pred_steps_pur_epoch = None
X
xixiaoyao 已提交
72 73 74 75 76 77 78 79 80 81
        self._cur_train_epoch = 0
        self._cur_train_step = 0
        self._train_finish = False

        self._inputname_to_varname = {}
        self._pred_input_name_list = []
        self._pred_input_varname_list = []
        self._pred_fetch_name_list = []
        self._pred_fetch_var_list = []

X
xixiaoyao 已提交
82 83
        # exe is built when random_init_params is called.
        # self._exe = helper.build_executor(gpu_dev_count>0)
X
xixiaoyao 已提交
84
        self._exe = None
X
xixiaoyao 已提交
85 86 87 88 89 90 91

        self._save_protocol = {
            'input_names': 'self._pred_input_name_list',
            'input_varnames': 'self._pred_input_varname_list',
            'fetch_list': 'self._pred_fetch_name_list'}

        self._lock = False
X
xixiaoyao 已提交
92 93
        self._build_forward = False

X
xixiaoyao 已提交
94
    def build_predict_forward(self, pred_backbone, pred_head, pred_prog=None, pred_init_prog=None):
        """Build the forward graph used for prediction.

        Args:
            pred_backbone: backbone object; its ``inputs_attr`` is merged with
                the head's reader inputs and its ``build(net_inputs)`` is
                called inside the predict program.
            pred_head: head object; stored as ``_pred_head``. Its
                ``inputs_attrs['reader']`` is read here and its ``build`` is
                invoked through ``_build_head``.
            pred_prog: optional ``fluid.Program`` to build into; a new one is
                created when None.
            pred_init_prog: optional startup ``fluid.Program``; a new one is
                created when None.

        Returns:
            The value returned by the head's ``build`` (may be None).
        """
        self._pred_head = pred_head
        pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name)

        pred_input_names, pred_shape_and_dtypes, pred_name_to_position = reader_helper.merge_input_attrs(
            pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False)
        pred_input_attrs = [[i, j, k] for i, (j, k) in zip(pred_input_names, pred_shape_and_dtypes)]
        self._pred_shape_and_dtypes = pred_shape_and_dtypes
        self._pred_name_to_position = pred_name_to_position

        if pred_prog is None:
            pred_prog = fluid.Program()
        self._pred_prog = pred_prog
        if pred_init_prog is None:
            pred_init_prog = fluid.Program()
        self._pred_init_prog = pred_init_prog

        with fluid.program_guard(pred_prog, pred_init_prog):
            pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs)
            pred_bb_output_vars = pred_backbone.build(pred_net_inputs)
            self._pred_net_inputs = pred_net_inputs

            # prepare predict vars for saving inference model
            cur_inputs = helper.decode_inputs(pred_net_inputs, self.name)
            self._pred_input_name_list, self._pred_input_varname_list = \
                zip(*[[k, v.name] for k, v in cur_inputs.items()])

            pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
            scope = self.name + '.'
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(pred_task_inputs, phase='pred', scope=scope)

        if output_vars:
            self._pred_fetch_name_list, self._pred_fetch_list = zip(*output_vars.items())
        else:
            self._pred_fetch_name_list = []
            self._pred_fetch_list = []
        # predict_one_batch reads `_pred_fetch_list` while save() reads
        # `_pred_fetch_var_list`; the original filled only one of them per
        # branch. Keep both names pointing at the same fetch variables.
        self._pred_fetch_var_list = self._pred_fetch_list

        return output_vars
X
xixiaoyao 已提交
141

X
xixiaoyao 已提交
142 143
    def _set_multitask(self):
        self._multi_task = True
X
xixiaoyao 已提交
144

X
xixiaoyao 已提交
145 146
    def build_forward(self, backbone, task_head):
        """Build the training forward graph and return its loss variable.

        Args:
            backbone: backbone object; ``inputs_attr``, ``outputs_attr`` and
                ``build(net_inputs)`` are used.
            task_head: head object stored as ``_task_head``; its
                ``inputs_attrs['reader']`` is read and its ``build`` is
                invoked through ``_build_head``. Its outputs must contain a
                ``'loss'`` entry.

        Returns:
            The variable produced by ``reduce_sum`` over the head's loss;
            also stored as ``_loss_var``.
        """
        self._task_head = task_head
        self._build_forward = True

        task_attr_from_reader = helper.encode_inputs(self._task_head.inputs_attrs['reader'], self.name)

        # merge reader input attrs from backbone and task head
        input_names, shape_and_dtypes, name_to_position = reader_helper.merge_input_attrs(
            backbone.inputs_attr, task_attr_from_reader, insert_taskid=False)
        self._shape_and_dtypes = shape_and_dtypes
        self._name_to_position = name_to_position
        self._input_names = input_names

        if DEBUG:
            # The original printed undefined `joint_*` names here.
            print('----- for debug -----')
            print('joint input names:')
            print(input_names)
            print('joint input shape and dtypes:')
            print(shape_and_dtypes)

        input_attrs = [[i, j, k] for i, (j, k) in zip(input_names, shape_and_dtypes)]

        train_prog = fluid.Program()
        train_init_prog = fluid.Program()
        self._train_prog = train_prog
        self._train_init_prog = train_init_prog

        # `async` is a reserved word from Python 3.7 on, so writing
        # `async=False` at the call site is a SyntaxError there. Passing it
        # through a dict forwards the exact same keyword argument.
        create_kwargs = {'async': False}

        # In multi-task mode no program guard is entered here, so the graph
        # is built into whichever program is current at call time.
        if not self._multi_task:
            with fluid.program_guard(train_prog, train_init_prog):
                net_inputs = reader_helper.create_net_inputs(input_attrs, **create_kwargs)
                bb_output_vars = backbone.build(net_inputs)
        else:
            net_inputs = reader_helper.create_net_inputs(input_attrs, **create_kwargs)
            bb_output_vars = backbone.build(net_inputs)
        self._net_inputs = net_inputs
        assert sorted(bb_output_vars.keys()) == sorted(backbone.outputs_attr.keys())

        task_inputs = {
            'backbone': bb_output_vars,
            'reader': helper.decode_inputs(net_inputs, self.name),
        }

        scope = self.name+'.'
        if not self._multi_task:
            with fluid.program_guard(train_prog, train_init_prog):
                with fluid.unique_name.guard(scope):
                    output_vars = self._build_head(task_inputs, phase='train', scope=scope)
        else:
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(task_inputs, phase='train', scope=scope)

        # Head outputs are exposed under '<trainer name>.<output key>'.
        task_output_vars = {scope+key: val for key, val in output_vars.items()}

        self._fetches = {k: v.name for k, v in task_output_vars.items()}
        self._fetch_names, self._fetch_list = zip(*self._fetches.items())

        if not self._multi_task:
            with fluid.program_guard(train_prog, train_init_prog):
                loss_var = fluid.layers.reduce_sum(task_output_vars[self.name+'.loss'])
        else:
            loss_var = fluid.layers.reduce_sum(task_output_vars[self.name+'.loss'])

        self._loss_var = loss_var
        return loss_var

    def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=None):
        """Build the optimizer, optional weight decay and optional EMA ops.

        NOTE(review): in this revision the statements below sat after
        ``return loss_var`` in build_forward with no ``def`` line, so they
        were unreachable and referenced undefined names. The signature is
        reconstructed from the names the body reads; confirm the defaults
        against callers.

        Args:
            optimizer: object providing ``_set_prog``, ``build`` (returning
                ``(param, grad)`` pairs) and ``get_cur_learning_rate``.
            weight_decay: when not None, each non-excluded parameter is
                reduced by ``param * weight_decay * learning_rate`` after the
                optimizer update.
            use_ema: when truthy, an ExponentialMovingAverage update is added.
            ema_decay: passed to ``fluid.optimizer.ExponentialMovingAverage``;
                only read when ``use_ema`` is truthy.
        """
        # build optimizer
        assert self._train_init_prog is not None, "train graph not foung! You should build_forward first."
        optimizer._set_prog(self._train_prog, self._train_init_prog)
        with fluid.program_guard(self._train_prog, self._train_init_prog):
            param_grads = optimizer.build()

            if weight_decay is not None:

                # Copies of every parameter (via `* 1.0`) with gradients
                # stopped; the decay term is computed from these copies.
                param_list = dict()
                for param in self._train_prog.global_block().all_parameters():
                    param_list[param.name] = param * 1.0
                    param_list[param.name].stop_gradient = True

                def exclude_from_weight_decay(name):
                    # layer-norm parameters and biases are not decayed
                    if "layer_norm" in name:
                        return True
                    return name.endswith(("_bias", "_b", ".b_0"))

                for param, grad in param_grads:
                    if exclude_from_weight_decay(param.name):
                        continue
                    with param.block.program._optimized_guard(
                        [param, grad]), fluid.framework.name_scope("weight_decay"):
                        updated_param = param - param_list[
                            param.name] * weight_decay * optimizer.get_cur_learning_rate()
                        fluid.layers.assign(output=param, input=updated_param)

            if use_ema:
                ema = fluid.optimizer.ExponentialMovingAverage(ema_decay)
                ema.update()
X
xixiaoyao 已提交
282 283 284 285 286 287 288
        # for bid, block in enumerate(self._train_prog.blocks):
        #     print('block id: '+str(bid))
        #     for var in block.vars:
        #         print("%d : %s" % (bid, var))
            
        # print(self._train_prog)

X
xixiaoyao 已提交
289 290 291
    def set_as_aux(self):
        self._as_auxilary = True

X
xixiaoyao 已提交
292
    def fit_reader(self, reader, phase='train'):
        """Bind a data reader to the graph built for the given phase.

        Args:
            reader: reader object; ``_batch_size``, ``num_epochs``,
                ``num_examples`` and ``_iterator()`` are used.
            phase: ``'train'`` or ``'predict'``.

        Raises:
            AssertionError: if neither build_forward nor
                build_predict_forward has prepared input attributes.
            NotImplementedError: for any other ``phase``.
        """
        # getattr keeps this a clean assertion failure even when the
        # attributes were never created.
        train_attrs = getattr(self, '_shape_and_dtypes', None)
        pred_attrs = getattr(self, '_pred_shape_and_dtypes', None)
        assert train_attrs is not None or pred_attrs is not None, "You need to build_forward or build_predict_head first to prepare input features."

        batch_size = reader._batch_size
        self._num_epochs = reader.num_epochs
        if phase == 'train':
            # Floor division: a trailing partial batch is not counted.
            # (Original author's note: unsure whether this should round up
            # instead; needs confirmation.)
            self._steps_pur_epoch = reader.num_examples // batch_size
            shape_and_dtypes = self._shape_and_dtypes
            name_to_position = self._name_to_position
            net_inputs = self._net_inputs
            self._train_batch_size = batch_size
            self._num_examples = reader.num_examples
        elif phase == 'predict':
            # Round up so a trailing partial batch counts as a step. The
            # original took the remainder from the *train* example count and,
            # through operator precedence, produced 0 whenever there was no
            # remainder.
            has_tail = reader.num_examples % batch_size > 0
            self._pred_steps_pur_epoch = reader.num_examples // batch_size + (1 if has_tail else 0)
            shape_and_dtypes = self._pred_shape_and_dtypes
            name_to_position = self._pred_name_to_position
            net_inputs = self._pred_net_inputs
            self._predict_batch_size = batch_size
            self._pred_num_examples = reader.num_examples
        else:
            raise NotImplementedError()

        print('ok!')

        # Create the dataset iterator (the original created it twice and
        # discarded the first one).
        iterator = reader._iterator()
        prefix = self.name

        # Run-time check and adaptation of the yielded data.
        iterator_fn = reader_helper.create_iterator_fn(iterator, prefix, shape_and_dtypes, name_to_position, return_type='dict')
        self._raw_iterator_fn = iterator_fn
        feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs)
        if gpu_dev_count > 1:
            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn)
        else:
            distribute_feeder_fn = iterator_fn

        if phase == 'train':
            self._train_reader = distribute_feeder_fn()
            self._feed_batch_process_fn = feed_batch_process_fn
        elif phase == 'predict':
            self._predict_reader = distribute_feeder_fn()
            self._pred_feed_batch_process_fn = feed_batch_process_fn
X
xixiaoyao 已提交
348

X
xixiaoyao 已提交
349
    def _init_exe_prog(self, for_train=True):
X
xixiaoyao 已提交
350 351 352 353 354 355 356 357 358 359
        if not self._train_init and not self._predict_init:
            on_gpu = gpu_dev_count > 0
            self._exe = helper.build_executor(on_gpu)

        if for_train:
            assert self._train_prog is not None, "train graph not foung! You should build_forward first before you random init parameters."
            self._train_init = True
        else:
            assert self._pred_prog is not None, "predict graph not foung! You should build_predict_head first before you random init parameters."
            self._predict_init = True
X
xixiaoyao 已提交
360 361

    def random_init_params(self):
X
xixiaoyao 已提交
362
        
X
xixiaoyao 已提交
363 364 365
        if not self._train_init:
            self._init_exe_prog()
        
X
xixiaoyao 已提交
366 367
        print('random init params...')
        self._exe.run(self._train_init_prog)
X
xixiaoyao 已提交
368

X
xixiaoyao 已提交
369 370
    def load_ckpt(self, model_path, phase='train'):
        # load pretrain model (or ckpt)
X
xixiaoyao 已提交
371 372
        # assert self._exe is not None, "You need to random_init_params before load checkpoints."
        if phase == 'train' and not self._train_init:
X
xixiaoyao 已提交
373
            self._init_exe_prog(for_train=True)
X
xixiaoyao 已提交
374
        if phase == 'predict' and not self._predict_init:
X
xixiaoyao 已提交
375
            self._init_exe_prog(for_train=False)
X
xixiaoyao 已提交
376 377 378 379 380 381

        if phase == 'train':
            assert self._train_init_prog is not None, "train graph not found! You should build_forward first before load checkpoint."
            saver.init_pretraining_params(
                self._exe,
                model_path,
X
xixiaoyao 已提交
382 383
                main_program=self._train_init_prog,
                strict=True)
X
xixiaoyao 已提交
384 385 386 387 388
        elif phase == 'predict':
            assert self._pred_init_prog is not None, "predict graph not found! You should build_predict_head first before load checkpoint."
            saver.init_pretraining_params(
                self._exe,
                model_path,
X
xixiaoyao 已提交
389 390
                main_program=self._pred_init_prog,
                strict=True)
X
xixiaoyao 已提交
391 392 393 394 395 396 397
        else:
            raise NotImplementedError()
            

    def load_predict_model(self, model_path):
        raise NotImplementedError()

X
xixiaoyao 已提交
398
    def load_pretrain(self, model_path, convert=False):
X
xixiaoyao 已提交
399 400 401 402 403 404
        # load pretrain model (or ckpt)
        assert self._exe is not None, "You need to random_init_params before load pretrain models."

        saver.init_pretraining_params(
            self._exe,
            model_path,
X
xixiaoyao 已提交
405
            convert=convert,
X
xixiaoyao 已提交
406
            main_program=self._train_init_prog)
X
xixiaoyao 已提交
407

X
xixiaoyao 已提交
408
    def set_saver(self, save_path, save_steps, save_type='ckpt'):
X
xixiaoyao 已提交
409

X
xixiaoyao 已提交
410 411
        save_type = save_type.split(',')
        if 'predict' in save_type:
X
xixiaoyao 已提交
412
            assert self._pred_head is not None, "Predict head not found! You should build_predict_head first if you want to save predict model."
X
xixiaoyao 已提交
413
            assert save_path is not None and save_steps is not None, 'save_path and save_steps is required to save model.'
X
xixiaoyao 已提交
414
            self._save_predict = True
X
xixiaoyao 已提交
415 416 417
            if not os.path.exists(save_path):
                os.makedirs(save_path)
        else:
X
xixiaoyao 已提交
418
            self._save_predict = False
X
xixiaoyao 已提交
419 420 421

        if 'ckpt' in save_type:
            if save_path is not None and save_steps is not None:
X
xixiaoyao 已提交
422
                self._save_ckpt = True
X
xixiaoyao 已提交
423 424 425 426
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
            else:
                "WARNING: save_path or save_steps is not set, model will not be saved during training."
X
xixiaoyao 已提交
427
                self._save_ckpt = False
X
xixiaoyao 已提交
428
        else:
X
xixiaoyao 已提交
429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453
            self._save_ckpt = False

        def temp_func():
            if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:
                if self._save_predict:
                    self.save(save_path, suffix='pred.step'+str(self._cur_train_step))
                    print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
                if self._save_ckpt:
                    fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
                    print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
                return True
            else:
                return False

        self._check_save = temp_func
            
    def train(self, save_path=None, save_steps=None, save_type='ckpt', print_steps=5):
        """
        Run the training loop over the reader bound by fit_reader.

        Argument:
            save_path, save_steps, save_type: accepted but not read by this
                method; saving goes through the `_check_save` hook.
            print_steps: a progress line is printed whenever the current
                train step is a multiple of this value; values <= 0 disable it.
        """
        batches = self._train_reader
        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)

        tick = time.time()
        for batch in batches:
            outputs = self.train_one_step(batch)

            # Strip the '<name>.' prefix before handing outputs to the head.
            own_prefix = self.name+'.'
            head_outputs = {
                key[len(own_prefix):]: val
                for key, val in outputs.items()
                if key.startswith(own_prefix)
            }
            self._task_head.batch_postprocess(head_outputs)

            should_report = print_steps > 0 and self._cur_train_step % print_steps == 0
            if should_report:
                mean_loss = np.mean(np.squeeze(outputs[self.name+'.loss'])).tolist()
                elapsed = time.time() - tick

                step_in_epoch = (self._cur_train_step-1) % self._steps_pur_epoch + 1
                print("step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format(
                       step_in_epoch, self._steps_pur_epoch, self._cur_train_epoch,
                       mean_loss, print_steps / elapsed))
                tick = time.time()

            self._check_save()

            # With no epoch count configured, stop after one epoch of steps.
            one_epoch_done = self._cur_train_step == self._steps_pur_epoch
            if self._num_epochs is None and one_epoch_done:
                break
X
xixiaoyao 已提交
510 511 512 513 514 515 516
        # save_path = os.path.join(main_conf['save_path'], 'ckpt',
        #                          "step_" + str(global_step))
        # fluid.io.save_persistables(self.exe, save_path, saver_program)
        # print('checkpoint has been saved at '+save_path)

        # print("ALL tasks train finished, exiting...")

X
xixiaoyao 已提交
517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562
    def get_one_batch(self, phase='train'):
        if phase == 'train':
            return next(self._train_reader)
        elif phase == 'predict':
            return next(self._predict_reader)
        else:
            raise NotImplementedError()
        
    def predict(self, output_dir=None, print_steps=1000):
        """
        Run the predict graph over every batch of the predict reader.

        Argument:
            output_dir: optional directory, created when missing, forwarded
                to the predict head's epoch_postprocess.
            print_steps: a progress line is printed every `print_steps`
                batches; values <= 0 disable it.

        Returns whatever the predict head's epoch_postprocess returns.
        """
        batches = self._predict_reader
        self._distribute_pred_prog = fluid.CompiledProgram(self._pred_prog).with_data_parallel()

        needs_dir = output_dir is not None and not os.path.exists(output_dir)
        if needs_dir:
            os.makedirs(output_dir)

        tick = time.time()
        for step, batch in enumerate(batches, 1):
            batch_outputs = self.predict_one_batch(batch)
            self._pred_head.batch_postprocess(batch_outputs)

            if print_steps > 0 and step % print_steps == 0:
                elapsed = time.time() - tick
                print("batch {}/{}, speed: {:.2f} steps/s".format(
                       step, self._pred_steps_pur_epoch,
                       print_steps / elapsed))
                tick = time.time()

        # NOTE(review): `_pred_reader` is only ever assigned None in the code
        # visible here (fit_reader stores `_predict_reader`) - confirm where
        # it is set before relying on this branch.
        reader_outputs = None
        if self._pred_head.epoch_inputs_attrs:
            reader_outputs = self._pred_reader.get_epoch_outputs()

        return self._pred_head.epoch_postprocess({'reader':reader_outputs}, output_dir=output_dir)

X
xixiaoyao 已提交
563 564 565 566
    def train_one_step(self, batch, executor=None, distribute_train_prog=None):
        """Run one training step and advance the step/epoch counters.

        Args:
            batch: with more than one GPU, a ``(feed, mask)`` pair; otherwise
                a raw batch that is converted by ``_feed_batch_process_fn``.
            executor: optional executor; ``self._exe`` is used when None.
            distribute_train_prog: optional compiled program;
                ``self._distribute_train_prog`` is used when None.

        Returns:
            dict mapping each name in ``_fetch_names`` to its fetched value.
        """
        exe = self._exe if executor is None else executor
        distribute_train_prog = self._distribute_train_prog if distribute_train_prog is None else distribute_train_prog

        if gpu_dev_count > 1:
            feed, mask = batch
            rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=self._fetch_list)
            # fit_reader stores the batch size as `_train_batch_size`; the
            # original read an attribute (`_batch_size`) that nothing in this
            # class sets.
            num_fakes = decode_fake(len(rt_outputs[0]), mask, self._train_batch_size)
            if num_fakes:
                # Drop the trailing `num_fakes` entries of every output.
                # Slicing works for lists and numpy arrays alike, whereas
                # `.pop()` does not exist on arrays.
                rt_outputs = [item[:-num_fakes] for item in rt_outputs]
        else:
            feed = self._feed_batch_process_fn(batch)
            rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=self._fetch_list)

        rt_outputs = dict(zip(self._fetch_names, rt_outputs))
        self._cur_train_step += 1
        self._cur_train_epoch = (self._cur_train_step-1) // self._steps_pur_epoch
        return rt_outputs
X
xixiaoyao 已提交
582

X
xixiaoyao 已提交
583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598
    @property
    def num_epochs(self):
        return self._num_epochs

    @property
    def cur_train_steps(self):
        return self._cur_train_step

    @property
    def cur_train_epoch(self):
        return self._cur_train_epoch

    @property
    def steps_pur_epoch(self):
        return self._steps_pur_epoch

X
xixiaoyao 已提交
599 600 601
    def predict_one_batch(self, batch):
        """Run the predict program on one batch.

        Args:
            batch: with more than one GPU, a ``(feed, mask)`` pair; otherwise
                a raw batch that is converted by
                ``_pred_feed_batch_process_fn``.

        Returns:
            dict mapping each name in ``_pred_fetch_name_list`` to its
            fetched value.
        """
        if gpu_dev_count > 1:
            feed, mask = batch
            # The original used `self.exe` here but `self._exe` in the
            # single-device branch; the executor is stored as `_exe`.
            rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)
            # fit_reader stores the predict batch size as
            # `_predict_batch_size`; `_batch_size` is not set by this class.
            num_fakes = decode_fake(len(rt_outputs[0]), mask, self._predict_batch_size)
            if num_fakes:
                # Drop the trailing `num_fakes` entries of every output.
                # Slicing works for lists and numpy arrays alike, whereas
                # `.pop()` does not exist on arrays.
                rt_outputs = [item[:-num_fakes] for item in rt_outputs]
        else:
            feed = self._pred_feed_batch_process_fn(batch)
            rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list)

        return dict(zip(self._pred_fetch_name_list, rt_outputs))
X
xixiaoyao 已提交
613

X
xixiaoyao 已提交
614 615 616
    def _build_head(self, net_inputs, phase, scope=""):
        if phase == 'train':
            output_vars = self._task_head.build(net_inputs, scope_name=scope)
X
xixiaoyao 已提交
617
        if phase == 'pred':
X
xixiaoyao 已提交
618
            output_vars = self._pred_head.build(net_inputs, scope_name=scope)
X
xixiaoyao 已提交
619 620
        return output_vars
    
X
xixiaoyao 已提交
621 622 623 624 625 626
    def save(self, save_path, suffix=None):
        """Save the predict program as an inference model plus a '__conf__' file.

        Args:
            save_path: base directory.
            suffix: optional sub-directory name joined onto ``save_path``.

        The '__conf__' file is JSON holding, for every key of
        ``_save_protocol``, the current value of the attribute it names.
        """
        dirpath = save_path if suffix is None else os.path.join(save_path, suffix)

        self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]

        prog = self._pred_prog.clone()
        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)

        # `_save_protocol` values are written as 'self.<attr>' in __init__;
        # read the attribute directly rather than exec-ing the string.
        conf = {
            key: getattr(self, attr_expr.rpartition('.')[2])
            for key, attr_expr in self._save_protocol.items()
        }
        with open(os.path.join(dirpath, '__conf__'), 'w') as writer:
            writer.write(json.dumps(conf, indent=1))
        print(self._name + ': predict model saved at ' + dirpath)
X
xixiaoyao 已提交
641

X
xixiaoyao 已提交
642
    
X
xixiaoyao 已提交
643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658
    def _load(self, infer_model_path=None):
        """Load an inference model and its '__conf__' file.

        Args:
            infer_model_path: directory to load from; falls back to
                ``self._save_infermodel_path`` when None.
                NOTE(review): that attribute is not set anywhere in the code
                visible here - confirm it is assigned elsewhere.

        Returns:
            The program returned by ``fluid.io.load_inference_model``.
        """
        if infer_model_path is None:
            infer_model_path = self._save_infermodel_path

        # Use a context manager so the file handle is closed promptly.
        with open(os.path.join(infer_model_path, '__conf__')) as reader:
            conf = json.load(reader)

        # `_save_protocol` values are written as 'self.<attr>' in __init__;
        # set the attribute directly rather than exec-ing an assignment.
        for key, value in conf.items():
            attr_expr = self._save_protocol[key]
            setattr(self, attr_expr.rpartition('.')[2], value)

        pred_prog, self._pred_input_varname_list, self._pred_fetch_var_list = \
            fluid.io.load_inference_model(infer_model_path, self._exe)
        print(self._name+': inference model loaded from ' + infer_model_path)
        return pred_prog

    @property
    def name(self):
        return self._name

    @property
X
xixiaoyao 已提交
659 660
    def num_examples(self):
        return self._num_examples
X
xixiaoyao 已提交
661 662 663 664 665 666 667 668 669 670 671 672

    @property
    def mix_ratio(self):
        if self._mix_ratio is not None:
            return self._mix_ratio
        else:
            raise ValueError("{}: mix_ratio is None".format(self._name))

    @mix_ratio.setter
    def mix_ratio(self, value):
        self._lock = True