trainer.py 25.8 KB
Newer Older
X
xixiaoyao 已提交
1
# -*- coding: utf-8 -*-
X
xixiaoyao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
xixiaoyao 已提交
16
from __future__ import print_function
X
xixiaoyao 已提交
17 18 19
import os
import json
from paddle import fluid
X
xixiaoyao 已提交
20 21
import time
import numpy as np
X
xixiaoyao 已提交
22
import paddlepalm.utils.basic_helper as helper
X
xixiaoyao 已提交
23
from paddlepalm.utils import reader_helper, saver
X
xixiaoyao 已提交
24
from paddlepalm.distribute import gpu_dev_count, data_feeder, decode_fake
X
xixiaoyao 已提交
25
# from paddlepalm.default_settings import *
X
xixiaoyao 已提交
26

X
xixiaoyao 已提交
27
# verbose-debugging switch; enables extra prints in build_forward
DEBUG=False
X
xixiaoyao 已提交
28

X
xixiaoyao 已提交
29

X
xixiaoyao 已提交
30 31 32 33 34
class Trainer(object):
    """Single-task trainer: wires a reader, a backbone and a task head into
    train/predict fluid programs and drives the training loop."""

    def __init__(self, name, reader, task_head, \
                 mix_ratio=1.0, reuse_head_with=None, \
                 silent=False):
        """
        Args:
            name: unique name of this trainer/task; used as the variable
                scope prefix and to encode/decode reader input names.
            reader: Reader instance that yields training data.
            task_head: task head (output layer + loss) built on top of the
                backbone at train time.
            mix_ratio: sampling weight of this task in multi-task training.
            reuse_head_with: name of another task whose head parameters
                should be reused; defaults to this task's own name.
            silent: when True, suppress informational prints.
        """
        self._name = name
        self._verbose = not silent
        self._reader = reader
        self._pred_reader = None
        self._task_head = task_head
        self._pred_head = None

        # head parameters are shared within this scope name
        self._task_reuse_scope = name if reuse_head_with is None else reuse_head_with

        self._feeded_var_names = None
        self._target_vars = None

        # populated by load_data()
        self._num_examples = 0

        # training process management
        self._mix_ratio = mix_ratio
        self._expected_train_steps = None
        self._expected_train_epochs = None
        self._steps_pur_epoch = None
        self._cur_train_epoch = 0
        self._cur_train_step = 0
        self._train_finish = False

        self._inputname_to_varname = {}

        # predict-graph bookkeeping, filled by build_predict_head()
        self._pred_input_name_list = []
        self._pred_input_varname_list = []
        self._pred_fetch_name_list = []
        self._pred_fetch_var_list = []

        # exe is built when random_init_params is called.
        self._exe = None

        # maps __conf__ json keys to the attribute (as an expression string)
        # that save()/_load() serialize and restore
        self._save_protocol = {
            'input_names': 'self._pred_input_name_list',
            'input_varnames': 'self._pred_input_varname_list',
            'fetch_list': 'self._pred_fetch_name_list'}

        # once locked (see _set_lock), the reuse scope may not change
        self._lock = False
        self._build_forward = False
X
xixiaoyao 已提交
92 93 94
    def build_predict_head(self, pred_head, pred_backbone, pred_prog=None, pred_init_prog=None):
        """Build the prediction graph: reader inputs -> backbone -> predict head.

        Also records the input/fetch variable names needed later by save()
        for exporting an inference model.

        Args:
            pred_head: task head used at predict time.
            pred_backbone: backbone network used at predict time.
            pred_prog: fluid main Program for the predict graph; a fresh
                Program is created when None.
            pred_init_prog: fluid startup Program for the predict graph; a
                fresh Program is created when None.

        Returns:
            dict of output variables produced by the predict head (may be
            None when the head produces nothing).
        """
        self._pred_head = pred_head
        # self._pred_reader = self._reader.clone(phase='pred')
        # prefix reader input names with the task name so they can be told
        # apart from backbone inputs
        pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name)
        # pred_task_attr_from_reader = self._pred_head.inputs_attrs['reader']

        # _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred')
        # _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred')
        # _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone')

        # merge the input attrs required by the backbone and by the head
        pred_input_names, pred_shape_and_dtypes, _ = reader_helper.merge_input_attrs(pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
        pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_input_names, pred_shape_and_dtypes)]

        if pred_prog is None:
            pred_prog = fluid.Program()
        self._pred_prog = pred_prog
        if pred_init_prog is None:
            pred_init_prog = fluid.Program()
        self._pred_init_prog = pred_init_prog
        with fluid.program_guard(pred_prog, pred_init_prog):
            pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs)
            # pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_')
            pred_bb_output_vars = pred_backbone.build(pred_net_inputs)

        # prepare predict vars for saving inference model
        with fluid.program_guard(pred_prog, pred_init_prog):
            # strip the task-name prefix back off before feeding the head
            cur_inputs = helper.decode_inputs(pred_net_inputs, self.name)
            # self.pred_input = cur_inputs
            self._pred_input_name_list, self._pred_input_varname_list = \
                zip(*[[k, v.name] for k,v in cur_inputs.items()])

            pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
            scope = self.name + '.'
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(pred_task_inputs, phase='pred', scope=scope)

        if output_vars is not None:
            self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items())
        else:
            self._pred_fetch_name_list = []
            self._pred_fetch_var_list = []

        # compile for (multi-device) prediction
        self._distribute_pred_prog = fluid.CompiledProgram(self._pred_prog).with_data_parallel()
        return output_vars
X
xixiaoyao 已提交
135 136 137


    def build_forward(self, backbone, pred_backbone=None, train_prog=None, train_init_prog=None, pred_prog=None, pred_init_prog=None):
        """Build the training forward graph: reader inputs -> backbone ->
        task head -> loss.

        Args:
            backbone: backbone network used for training.
            pred_backbone: unused here; kept for interface compatibility.
            train_prog: fluid main Program; a fresh one is created when None.
            train_init_prog: fluid startup Program; a fresh one is created
                when None.
            pred_prog, pred_init_prog: unused here; kept for interface
                compatibility.

        Returns:
            the scalar loss variable of this task (also stored as
            self._loss_var for random_init_params).
        """
        self._build_forward = True

        # prefix reader input names with the task name so they can be told
        # apart from backbone inputs
        task_attr_from_reader = helper.encode_inputs(self._task_head.inputs_attrs['reader'], self.name)

        # merge reader input attrs required by the backbone and the head
        input_names, shape_and_dtypes, name_to_position = reader_helper.merge_input_attrs(backbone.inputs_attr, task_attr_from_reader, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
        self._shape_and_dtypes = shape_and_dtypes
        self._name_to_position = name_to_position

        if DEBUG:
            # FIX: the original printed undefined names (joint_input_names,
            # joint_shape_and_dtypes), raising NameError whenever DEBUG was on.
            print('----- for debug -----')
            print('joint input names:')
            print(input_names)
            print('joint input shape and dtypes:')
            print(shape_and_dtypes)

        input_attrs = [[i, j, k] for i, (j, k) in zip(input_names, shape_and_dtypes)]

        if train_prog is None:
            train_prog = fluid.Program()
        if train_init_prog is None:
            train_init_prog = fluid.Program()
        self._prog = train_prog
        self._train_prog = train_prog
        self._train_init_prog = train_init_prog
        with fluid.program_guard(train_prog, train_init_prog):
            # NOTE(review): the original passed `async=False`, which is a
            # syntax error on python3 (`async` became a keyword). Rely on the
            # helper's default synchronous behavior instead -- TODO confirm
            # the default of reader_helper.create_net_inputs.
            net_inputs = reader_helper.create_net_inputs(input_attrs)
            self._net_inputs = net_inputs

            # build backbone
            bb_output_vars = backbone.build(net_inputs)
            assert sorted(bb_output_vars.keys()) == sorted(backbone.outputs_attr.keys())

        task_output_vars = {}
        task_inputs = {'backbone': bb_output_vars}
        # strip the task-name prefix back off before feeding the head
        task_inputs_from_reader = helper.decode_inputs(net_inputs, self.name)
        task_inputs['reader'] = task_inputs_from_reader

        scope = self.name + '.'
        with fluid.program_guard(train_prog, train_init_prog):
            with fluid.unique_name.guard(scope):
                output_vars = self._build_head(task_inputs, phase='train', scope=scope)
        output_vars = {self.name + '.' + key: val for key, val in output_vars.items()}
        old = len(task_output_vars)  # for debug
        task_output_vars.update(output_vars)
        assert len(task_output_vars) - old == len(output_vars)  # for debug

        task_fetches = {k: v.name for k, v in task_output_vars.items()}
        self._fetches = task_fetches
        self._fetch_names, self._fetch_list = zip(*self._fetches.items())

        # compute loss: the head must emit a '<task>.loss' variable
        with fluid.program_guard(train_prog, train_init_prog):
            loss_var = fluid.layers.reduce_sum(task_output_vars[self.name + '.loss'])

        # keep a handle: random_init_params needs the loss name to compile
        # the data-parallel program
        self._loss_var = loss_var

        if DEBUG:
            # FIX: this full variable dump ran unconditionally before; keep
            # it behind the DEBUG switch.
            for _id, block in enumerate(self._train_prog.blocks):
                for var in block.vars:
                    print("[debug] : %d, %s" % (_id, var))

        return loss_var

    def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=0.9999):
        """Append the backward pass and optimizer ops to the train program.

        Args:
            optimizer: a paddlepalm optimizer wrapper; its build() inserts
                the backward/update ops and returns (param, grad) pairs.
            weight_decay: when not None, decoupled L2 weight decay applied
                to all parameters except biases and layer-norm parameters.
            use_ema: whether to maintain an exponential moving average of
                the parameters.
            ema_decay: decay rate of the EMA.
        """
        # build optimizer
        assert self._train_init_prog is not None, "train graph not foung! You should build_forward first."
        optimizer._set_prog(self._train_prog, self._train_init_prog)
        with fluid.program_guard(self._train_prog, self._train_init_prog):
            param_grads = optimizer.build()

            if weight_decay is not None:

                # snapshot each parameter's pre-update value; stop_gradient
                # keeps the copies out of the backward pass
                param_list = dict()

                for param in self._prog.global_block().all_parameters():
                    param_list[param.name] = param * 1.0
                    param_list[param.name].stop_gradient = True

                def exclude_from_weight_decay(name):
                    # biases and layer-norm parameters are conventionally
                    # exempt from weight decay
                    if name.find("layer_norm") > -1:
                        return True
                    bias_suffix = ["_bias", "_b", ".b_0"]
                    for suffix in bias_suffix:
                        if name.endswith(suffix):
                            return True
                    return False

                for param, grad in param_grads:
                    if exclude_from_weight_decay(param.name):
                        continue
                    # decoupled weight decay:
                    #   param -= old_param * weight_decay * current_lr
                    with param.block.program._optimized_guard(
                        [param, grad]), fluid.framework.name_scope("weight_decay"):
                        updated_param = param - param_list[
                            param.name] * weight_decay * optimizer.get_cur_learning_rate()
                        fluid.layers.assign(output=param, input=updated_param)


            # loss.persistable = True
            if use_ema:
                ema = fluid.optimizer.ExponentialMovingAverage(ema_decay)
                ema.update()
X
xixiaoyao 已提交
272 273 274 275 276 277 278
    def load_data(self, input_file, file_format, batch_size, num_epochs=None, shuffle_train=True):
        """Load the dataset into the reader and return a feeding generator.

        The returned generator yields feedable batches (multi-device
        batches when more than one GPU is visible).
        """
        print("preparing data...", end='')
        self._reader._load_data(
            input_file=input_file, batch_size=batch_size,
            num_epochs=num_epochs, file_format=file_format,
            shuffle_train=shuffle_train)
        self._num_examples = self._reader.num_examples
        # NOTE: the tail (partial) batch is dropped by the floor division;
        # whether this should round up was an open question from the
        # original author.
        self._steps_pur_epoch = self._num_examples // batch_size
        print('ok!')

        # merge the dataset iterator with the graph's input slots; yielded
        # data is runtime-checked and adapted by the iterator fn
        batch_iterator = self._reader._iterator()
        iterator_fn = reader_helper.create_iterator_fn(
            batch_iterator, self.name, self._shape_and_dtypes,
            self._name_to_position, return_type='dict')
        self._feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(self._net_inputs)
        if gpu_dev_count > 1:
            feeder = data_feeder(iterator_fn, self._feed_batch_process_fn)
        else:
            feeder = iterator_fn
        return feeder()
X
xixiaoyao 已提交
298 299

    def random_init_params(self):
        """Create the executor, compile the train program for data-parallel
        execution, and run the startup program to randomly initialize
        parameters."""
        assert self._train_init_prog is not None, "train graph not foung! You should build_forward first before you random init parameters."
        # FIX: the original referenced an undefined local `loss_var` here
        # (guaranteed NameError). The loss variable produced by
        # build_forward is expected on the instance instead.
        loss_var = getattr(self, '_loss_var', None)
        assert loss_var is not None, "loss variable not found! You should build_forward first before you random init parameters."
        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=loss_var.name)
        on_gpu = gpu_dev_count > 0
        self._exe = helper.build_executor(on_gpu)
        print('random init params...')
        self._exe.run(self._train_init_prog)
X
xixiaoyao 已提交
306

X
xixiaoyao 已提交
307 308
    def load_ckpt(self, model_path, phase='train'):
        """Load checkpoint parameters into the graph of the given phase
        ('train' or 'predict')."""
        assert self._exe is not None, "You need to random_init_params before load checkpoints."

        # pick the startup program matching the requested phase, then load
        if phase == 'train':
            assert self._train_init_prog is not None, "train graph not found! You should build_forward first before load checkpoint."
            target_prog = self._train_init_prog
        elif phase == 'predict':
            assert self._pred_init_prog is not None, "predict graph not found! You should build_predict_head first before load checkpoint."
            target_prog = self._pred_init_prog
        else:
            raise NotImplementedError()

        saver.init_pretraining_params(
            self._exe,
            model_path,
            main_program=target_prog)

    def load_predict_model(self, model_path):
        # not implemented yet; see _load() for the current inference-model
        # loading path
        raise NotImplementedError()

X
xixiaoyao 已提交
330 331 332 333 334 335 336 337
    def load_pretrain(self, model_path):
        """Initialize train-graph parameters from a pretrained model at
        `model_path`."""
        # the executor must exist before parameters can be loaded
        assert self._exe is not None, "You need to random_init_params before load pretrain models."
        saver.init_pretraining_params(
            self._exe, model_path, main_program=self._train_init_prog)
X
xixiaoyao 已提交
338

X
xixiaoyao 已提交
339 340
    def set_predict_head(self):
        # placeholder; predict heads are attached via build_predict_head()
        pass
X
xixiaoyao 已提交
341

X
xixiaoyao 已提交
342
    def train(self, iterator, save_path=None, save_steps=None, save_type='ckpt', print_steps=5):
X
xixiaoyao 已提交
343 344 345 346
        """
        Argument:
            save_type: ckpt, predict, pretrain
        """
X
xixiaoyao 已提交
347

X
xixiaoyao 已提交
348 349
        save_type = save_type.split(',')
        if 'predict' in save_type:
X
xixiaoyao 已提交
350
            assert self._pred_head is not None, "Predict head not found! You should build_predict_head first if you want to save predict model."
X
xixiaoyao 已提交
351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422
            assert save_path is not None and save_steps is not None, 'save_path and save_steps is required to save model.'
            save_predict = True
            if not os.path.exists(save_path):
                os.makedirs(save_path)
        else:
            save_predict = False

        if 'ckpt' in save_type:
            if save_path is not None and save_steps is not None:
                save_ckpt = True
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
            else:
                "WARNING: save_path or save_steps is not set, model will not be saved during training."
                save_ckpt = False
        else:
            save_ckpt = False

        # if save_path is not None or save_steps is not None:
        #     assert self._save_predict_model, "If you want to save model, you need set save_predict_model=True when this trainer is built."
        # if self._save_predict_model:
        #     if save_path is None and save_steps is None:
        #         print('Warning: model will not be saved for this run. If you want to save model, set save_path and save_steps.')
        #     else:
        #         assert save_path is not None, "argument save_path is required to save models."
        #         assert save_steps == -1 or save_steps > 0, "argument save_steps should be -1 (only save the last step of this task) or larger than 0"
        #         if save_path is not None and not os.path.exists(save_path):
        #             os.makedirs(save_path)
        # else:
        #     assert save_path is None, "You should set save_predict_model as True, or the argument save_path is invalid."
        #     assert save_steps is None, "You should set save_predict_model as True, or the argument save_steps is invalid."

        time_begin = time.time()
        for feed in iterator:
            rt_outputs = self.train_one_step(feed)
            # if gpu_dev_count > 1:
            #     feed, mask = feed
            # rt_outputs = self.exe.run(self._train_prog, feed=feed, fetch_list=self._fetch_list)
            # print(rt_outputs)
            # print(len(rt_outputs))
            # if gpu_dev_count > 1:
            #     while mask.pop() == False:
            #         rt_outputs.pop()

            # rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)}

            task_rt_outputs = {k[len(self.name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self.name+'.')}
            self._task_head.postprocess(task_rt_outputs)

            self._cur_train_step += 1
            self._cur_train_epoch = (self._cur_train_step-1) // self._steps_pur_epoch

            # if self._save_predict_model and self._cur_train_step % save_steps == 0:
            #     self.save(save_path, suffix='.step'+str(self._cur_train_steps))

            if print_steps > 0 and self._cur_train_step % print_steps == 0:
                loss = rt_outputs[self.name+'.loss']
                loss = np.mean(np.squeeze(loss)).tolist()

                time_end = time.time()
                time_cost = time_end - time_begin

                print("step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format(
                       (self._cur_train_step-1) % self._steps_pur_epoch + 1, self._steps_pur_epoch, self._cur_train_epoch,
                       loss, print_steps / time_cost))
                time_begin = time.time()

            # if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
            #     print(cur_task.name+': train finished!')
            #     cur_task.save()

            if (save_predict or save_ckpt) and self._cur_train_step % save_steps == 0:
X
xixiaoyao 已提交
423 424
                if save_predict:
                    self.save(save_path, suffix='pred.step'+str(self._cur_train_step))
X
xixiaoyao 已提交
425
                if save_ckpt:
X
xixiaoyao 已提交
426 427
                    fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
                    print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
X
xixiaoyao 已提交
428 429 430 431 432 433 434 435 436 437 438 439

        # save_path = os.path.join(main_conf['save_path'], 'ckpt',
        #                          "step_" + str(global_step))
        # fluid.io.save_persistables(self.exe, save_path, saver_program)
        # print('checkpoint has been saved at '+save_path)

        # print("ALL tasks train finished, exiting...")

    def train_one_step(self, batch):
        """Run one training step on `batch`.

        Returns:
            dict mapping fetch names to fetched values.
        """
        if gpu_dev_count > 1:
            feed, mask = batch
            # FIX: the original called `self.exe`, an attribute that does
            # not exist (the executor is stored as `self._exe`).
            rt_outputs = self._exe.run(self._distribute_train_prog, feed=feed, fetch_list=self._fetch_list)
            # drop results produced by fake (padding) batches on idle devices
            while mask.pop() == False:
                rt_outputs.pop()
        else:
            feed = self._feed_batch_process_fn(batch)
            rt_outputs = self._exe.run(self._distribute_train_prog, feed=feed, fetch_list=self._fetch_list)

        rt_outputs = {k: v for k, v in zip(self._fetch_names, rt_outputs)}
        return rt_outputs
X
xixiaoyao 已提交
448 449 450 451

    def predict_one_batch(self, batch):
        """Run one predict step on `batch`.

        Returns:
            dict mapping fetch names to fetched values.
        """
        if gpu_dev_count > 1:
            feed, mask = batch
            # FIX: the original called `self.exe`, an attribute that does
            # not exist (the executor is stored as `self._exe`).
            rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._fetch_list)
            # drop results produced by fake (padding) batches on idle devices
            while mask.pop() == False:
                rt_outputs.pop()
        else:
            feed = self._feed_batch_process_fn(batch)
            rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._fetch_list)

        # NOTE(review): fetches come from the TRAIN graph (self._fetch_names
        # / self._fetch_list) even though the predict program is run; this
        # looks like it should use self._pred_fetch_* -- confirm with the
        # callers before changing. Kept as-is.
        rt_outputs = {k: v for k, v in zip(self._fetch_names, rt_outputs)}
        # FIX: the original built this dict but never returned it.
        return rt_outputs
X
xixiaoyao 已提交
461
        
X
xixiaoyao 已提交
462

X
xixiaoyao 已提交
463 464 465
    def _build_head(self, net_inputs, phase, scope=""):
        if phase == 'train':
            output_vars = self._task_head.build(net_inputs, scope_name=scope)
X
xixiaoyao 已提交
466
        if phase == 'pred':
X
xixiaoyao 已提交
467
            output_vars = self._pred_head.build(net_inputs, scope_name=scope)
X
xixiaoyao 已提交
468 469 470 471 472 473 474 475
        return output_vars

    def _postprocess(self, rt_outputs, phase):
        return self._task_layer[phase].postprocess(rt_outputs)

    def _epoch_postprocess(self, epoch_inputs, phase):
        return self._task_layer[phase].epoch_postprocess(epoch_inputs)
    
X
xixiaoyao 已提交
476 477 478 479 480 481
    def save(self, save_path, suffix=None):
        """Save the inference (predict) model plus a `__conf__` json that
        records the attributes listed in self._save_protocol.

        Args:
            save_path: target directory.
            suffix: optional subdirectory name (e.g. 'pred.step100').
        """
        if suffix is not None:
            dirpath = os.path.join(save_path, suffix)
        else:
            dirpath = save_path
        self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]

        prog = self._pred_prog.clone()
        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)

        conf = {}
        for k, attr_expr in self._save_protocol.items():
            # FIX: replaced the fragile exec()/locals() indirection with a
            # direct attribute lookup; protocol values look like
            # 'self._pred_input_name_list'.
            conf[k] = getattr(self, attr_expr[len('self.'):])
        with open(os.path.join(dirpath, '__conf__'), 'w') as writer:
            writer.write(json.dumps(conf, indent=1))
        print(self._name + ': predict model saved at ' + dirpath)
X
xixiaoyao 已提交
496

X
xixiaoyao 已提交
497
    
X
xixiaoyao 已提交
498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513
    def _load(self, infer_model_path=None):
        """Load an inference model saved by save(); restores the attributes
        recorded in `__conf__` and returns the loaded predict program."""
        if infer_model_path is None:
            infer_model_path = self._save_infermodel_path
        for k, v in json.load(open(os.path.join(infer_model_path, '__conf__'))).items():
            # FIX: replaced the exec() indirection with setattr; protocol
            # values look like 'self._pred_input_name_list'.
            attr_expr = self._save_protocol[k]
            setattr(self, attr_expr[len('self.'):], v)
        pred_prog, self._pred_input_varname_list, self._pred_fetch_var_list = \
            fluid.io.load_inference_model(infer_model_path, self._exe)
        print(self._name + ': inference model loaded from ' + infer_model_path)
        return pred_prog

    @property
    def name(self):
        # unique task/trainer name set at construction time
        return self._name

    @property
    def num_examples(self):
        # number of training examples; populated by load_data()
        return self._num_examples
X
xixiaoyao 已提交
516

X
xixiaoyao 已提交
517 518 519
    # @property
    # def _pred_input(self):
    #     return zip(*[self._pred_input_name_list, self._pred_input_varname_list])
X
xixiaoyao 已提交
520

X
xixiaoyao 已提交
521 522 523 524 525
    # @_pred_input.setter
    # def _pred_input(self, val):
    #     assert isinstance(val, dict)
    #     self._pred_input_name_list, self._pred_input_varname_list = \
    #         zip(*[[k, v.name] for k,v in val.items()])
X
xixiaoyao 已提交
526

X
xixiaoyao 已提交
527 528 529
    # @property
    # def _pred_fetch_list(self):
    #     return [self._pred_fetch_name_list, self._pred_fetch_var_list]
X
xixiaoyao 已提交
530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556

    @property
    def mix_ratio(self):
        if self._mix_ratio is not None:
            return self._mix_ratio
        else:
            raise ValueError("{}: mix_ratio is None".format(self._name))

    @mix_ratio.setter
    def mix_ratio(self, value):
        self._mix_ratio = float(value)
        if self._verbose:
            print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio))

    @property
    def save_infermodel_every_n_steps(self):
        return self._save_infermodel_every_n_steps

    @save_infermodel_every_n_steps.setter
    def save_infermodel_every_n_steps(self, val):
        self._save_infermodel_every_n_steps = val

    @property
    def expected_train_steps(self):
        return self._expected_train_steps

    @expected_train_steps.setter
X
xixiaoyao 已提交
557
    def expected_train_steps(self, value):
X
xixiaoyao 已提交
558 559 560 561 562 563 564 565 566 567 568 569 570 571 572
        self._expected_train_steps = value
        self._expected_train_epochs = value / float(self._steps_pur_epoch)

    @property
    def expected_train_epochs(self):
        return self._expected_train_epochs

    @property
    def cur_train_epoch(self):
        return self._cur_train_epoch

    @property
    def cur_train_step(self):
        return self._cur_train_step

X
xixiaoyao 已提交
573 574 575 576 577 578 579 580
    # @cur_train_step.setter
    # def _cur_train_step(self, value):
    #     self._cur_train_step = value
    #     if self._cur_train_step > self._steps_pur_epoch:
    #         self._cur_train_epoch += 1
    #         self._cur_train_step = 1
    #     if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps:
    #         self._train_finish = True
X
xixiaoyao 已提交
581 582 583 584 585 586

    @property
    def steps_pur_epoch(self):
        return self._steps_pur_epoch

    @steps_pur_epoch.setter
X
xixiaoyao 已提交
587
    def steps_pur_epoch(self, value):
X
xixiaoyao 已提交
588 589 590 591 592 593 594 595 596 597 598 599 600 601 602
        self._steps_pur_epoch = value

    @property
    def train_finish(self):
        return self._train_finish

    def tasklayer_reuse_with(self, task):
        assert isinstance(task, Task)
        if self._lock:
            raise Exception('you can only set tasklayer reuses BEFORE Controller created.')
        self._task_reuse_scope = task.name
    
    def _set_lock(self):
        # once locked, tasklayer_reuse_with refuses further scope changes
        self._lock = True