# coding:utf-8
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import copy
import logging
import os
import time

import numpy as np
import paddle.fluid as fluid
from tb_paddle import SummaryWriter

import paddlehub as hub
from paddlehub.common.logger import logger
from paddlehub.common.paddle_helper import dtype_map, clone_program
from paddlehub.common.utils import mkdir, to_list
from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
from paddlehub.finetune.config import RunConfig


class RunState(object):
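    """Accumulates statistics for a span of run steps: step count, example
    count, per-fetch-target results, elapsed time and running speed.
    `length` is the number of fetch targets tracked in `run_results`."""
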
    def __init__(self, length):
        self.run_time_begin = time.time()
        self.run_step = 0
        self.run_examples = 0
        self.run_results = [0] * length
        self.run_time_used = 0
        self.run_speed = 0.0

    def __add__(self, other):
        self.run_step += other.run_step
        self.run_examples += other.run_examples
        for index in range(len(self.run_results)):
            self.run_results[index] += other.run_results[index]
        return self

    def update(self):
        self.run_time_used = time.time() - self.run_time_begin
        self.run_speed = self.run_step / self.run_time_used
        return self


class RunEnv(object):
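    """Holds the execution state of one phase (programs, reader, loss, labels,
    metrics and step counters); BasicTask keeps one RunEnv per phase."""
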
    def __init__(self):
        self.current_epoch = 0
        self.current_step = 0
        self.main_program = None
        self.start_program = None
        self.main_program_compiled = None
        self.py_reader = None
        self.reader = None
        self.loss = None
        self.labels = None
        self.metrics = None
        self.is_initialized = False
        self.UNG = copy.deepcopy(fluid.unique_name.generator)

    def __setattr__(self, key, value):
        self.__dict__[key] = value

    def __getattr__(self, key):
        return self.__dict__[key]


class BasicTask(object):
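    """Base class for PaddleHub fine-tuning tasks.

    Subclasses implement `_build_net` and, for trainable tasks, `_add_label`,
    `_add_loss`, `_add_metrics` and `_calculate_metrics`; this class drives
    program construction, checkpointing, training, evaluation and prediction.

    A minimal usage sketch (`MyTask` stands for a hypothetical subclass):

        task = MyTask(feed_list=feed_list, data_reader=reader, config=RunConfig())
        task.finetune_and_eval()           # train, evaluating on the dev set
        results = task.predict(data=data)  # predict with the best saved model
    """
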
    def __init__(self,
                 feed_list,
                 data_reader,
                 main_program=None,
                 startup_program=None,
                 config=None,
                 metrics_choices="default"):

        # base item
        self._base_data_reader = data_reader
        self._base_feed_list = feed_list

        # metrics item
        self.best_score = -999
        if metrics_choices == "default":
            metrics_choices = ["acc"]
        elif metrics_choices is None:
            metrics_choices = []
        if isinstance(metrics_choices, list):
            self.metrics_choices = metrics_choices
        else:
            self.metrics_choices = [metrics_choices]

        if main_program is None:
            self._base_main_program = clone_program(
                fluid.default_main_program(), for_test=False)

        else:
            self._base_main_program = clone_program(
                main_program, for_test=False)
        if startup_program is None:
            self._base_startup_program = clone_program(
                fluid.default_startup_program(), for_test=False)
        else:
            self._base_startup_program = clone_program(
                startup_program, for_test=False)
        self.is_checkpoint_loaded = False
        self._base_compiled_program = None

        # run config
        self.config = config if config else RunConfig()
        self.place = self.places[0]
        self.device_count = len(self.places)

        if self.config.use_data_parallel:
            if not self.config.use_pyreader and self.config.batch_size < self.device_count:
                logger.warning(
                    "Batch size({}) is less than the count of devices({}), which is not allowed in current Paddle versions"
                    .format(self.config.batch_size, self.device_count))
                logger.warning("Batch size automatically adjusted to {}".format(
                    self.device_count))
                self.config._batch_size = self.device_count

        self.exe = fluid.Executor(place=self.place)
        self.build_strategy = fluid.BuildStrategy()

        # log item
        if not os.path.exists(self.config.checkpoint_dir):
            mkdir(self.config.checkpoint_dir)
        tb_log_dir = os.path.join(self.config.checkpoint_dir, "visualization")
        self.tb_writer = SummaryWriter(tb_log_dir)

        # run environment
        self._phases = []
        self._envs = {}
        self._predict_data = None

        # cache flag to avoid reloading the best model on repeated predict calls
        self.is_best_model_loaded = False

        # set default phase
        self.enter_phase("train")

    @contextlib.contextmanager
    def phase_guard(self, phase):
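        """Context manager that enters `phase` on entry and restores the previous phase on exit."""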
        self.enter_phase(phase)
        yield
        self.exit_phase()

    def enter_phase(self, phase):
        if phase not in ["train", "val", "dev", "test", "predict", "inference"]:
            raise RuntimeError("Unknown phase: %s" % phase)
        if phase in ["val", "dev"]:
            phase = "dev"
        elif phase in ["predict", "inference"]:
            phase = "predict"
        self._phases.append(phase)

    def exit_phase(self):
        self._phases = self._phases[:-1]

    def init_if_necessary(self):
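        """Ensure parameters are initialized: restore from a checkpoint if possible, otherwise run the startup program."""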
        if not self.is_checkpoint_loaded:
            if not self.load_checkpoint():
                self.exe.run(self._base_startup_program)
            self.is_checkpoint_loaded = True
            self.is_best_model_loaded = False

    def init_if_load_best_model(self):
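        """Load the parameters of the best saved model, falling back to normal initialization if none exists."""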
        if not self.is_best_model_loaded:
            best_model_path = os.path.join(self.config.checkpoint_dir,
                                           "best_model")
            logger.info("Load the best model from %s" % best_model_path)
            if os.path.exists(best_model_path):
                self.load_parameters(best_model_path)
                self.is_checkpoint_loaded = False
                self.is_best_model_loaded = True
            else:
                self.init_if_necessary()
        else:
            logger.info("The best model has been loaded")

    def _build_env(self):
        if self.env.is_initialized:
            return

        self._build_env_start_event()
        self.env.is_initialized = True
        self.env.main_program = clone_program(
            self._base_main_program, for_test=False)

        self.env.startup_program = fluid.Program()
        with fluid.program_guard(self.env.main_program,
                                 self._base_startup_program):
            with fluid.unique_name.guard(self.env.UNG):
                self.env.outputs = self._build_net()
                if self.is_train_phase or self.is_test_phase:
                    self.env.labels = self._add_label()
                    self.env.loss = self._add_loss()
                    self.env.metrics = self._add_metrics()

        if self.is_predict_phase or self.is_test_phase:
            self.env.main_program = clone_program(
                self.env.main_program, for_test=True)
            hub.common.paddle_helper.set_op_attr(
                self.env.main_program, is_test=True)

        if self.config.use_pyreader:
            t_program = fluid.Program()
            with fluid.program_guard(t_program, self.env.startup_program):
                self.env.py_reader = fluid.layers.py_reader(
                    capacity=64,
                    shapes=[var.shape for var in self.feed_var_list],
                    dtypes=[dtype_map[var.dtype] for var in self.feed_var_list],
                    lod_levels=[var.lod_level for var in self.feed_var_list],
                    use_double_buffer=False)

                feed_var_list = self.feed_var_list
                py_vars = fluid.layers.read_file(self.env.py_reader)
                py_vars = to_list(py_vars)
                input_dict = {
                    feed_var_list[index].name: py_var
                    for index, py_var in enumerate(py_vars)
                }

                hub.connect_program(
                    pre_program=t_program,
                    next_program=self.env.main_program,
                    input_dict=input_dict,
                    need_log=False)

            self.env.main_program = t_program
            if not self.is_predict_phase:
                self.env.loss = self.env.main_program.global_block().vars[
                    self.env.loss.name]
                metrics_name = [var.name for var in self.env.metrics]
                self.env.metrics = [
                    self.env.main_program.global_block().vars[name]
                    for name in metrics_name
                ]

            outputs_name = [var.name for var in self.env.outputs]
            self.env.outputs = [
                self.env.main_program.global_block().vars[name]
                for name in outputs_name
            ]

        if self.config.enable_memory_optim:
            for var_name in self.fetch_list:
                var = self.env.main_program.global_block().vars[var_name]
                var.persistable = True

        # remove existing root logger handlers to avoid duplicated logs caused by paddle-fluid's logger usage
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)

        if self.is_train_phase:
            with fluid.program_guard(self.env.main_program,
                                     self._base_startup_program):
                with fluid.unique_name.guard(self.env.UNG):
                    self.scheduled_lr, self.max_train_steps = self.config.strategy.execute(
                        self.loss, self._base_data_reader, self.config,
                        self.device_count)

        if self.is_train_phase:
            loss_name = self.env.loss.name
        else:
            loss_name = None

        share_vars_from = self._base_compiled_program

        if not self.config.use_data_parallel:
            self.env.main_program_compiled = None
        else:
            self.env.main_program_compiled = fluid.CompiledProgram(
                self.env.main_program).with_data_parallel(
                    loss_name=loss_name,
                    share_vars_from=share_vars_from,
                    build_strategy=self.build_strategy)

        self.exe.run(self.env.startup_program)

        self._build_env_end_event()

    @property
    def places(self):
        if self.config.use_cuda:
            _places = fluid.framework.cuda_places()
        else:
            _places = fluid.framework.cpu_places()

        if not self.config.use_data_parallel:
            return [_places[0]]
        return _places

    @property
    def return_numpy(self):
        return True

    @property
    def is_train_phase(self):
        return self.phase in ["train"]

    @property
    def is_test_phase(self):
        return self.phase in ["val", "dev", "test"]

    @property
    def is_predict_phase(self):
        return self.phase in ["predict", "inference"]

    @property
    def phase(self):
        return self._phases[-1]

    @property
    def env(self):
        phase = self.phase
        if phase in ["val", "dev", "test"]:
            phase = "dev"
        if phase not in self._envs:
            self._envs[phase] = RunEnv()
        return self._envs[phase]

    @property
    def py_reader(self):
        if not self.env.is_initialized:
            self._build_env()
        return self.env.py_reader

    @property
    def current_step(self):
        if not self.env.is_initialized:
            self._build_env()
        return self.env.current_step

    @property
    def current_epoch(self):
        if not self.env.is_initialized:
            self._build_env()
        return self.env.current_epoch

    @property
    def main_program(self):
        if not self.env.is_initialized:
            self._build_env()
        return self.env.main_program

    @property
    def startup_program(self):
        if not self.env.is_initialized:
            self._build_env()
        return self.env.startup_program

    @property
    def main_program_compiled(self):
        if not self.env.is_initialized:
            self._build_env()
        return self.env.main_program_compiled

    @property
    def main_program_to_be_run(self):
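        """Return the data-parallel compiled program when enabled, otherwise the plain main program."""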
        if self.config.use_data_parallel:
            if self._base_compiled_program is None:
                self._base_compiled_program = self.env.main_program_compiled
            return self.main_program_compiled
        return self.main_program

    @property
    def reader(self):
        if self.is_predict_phase:
            data = self._predict_data
        else:
            data = None
        self.env.reader = self._base_data_reader.data_generator(
            batch_size=self.config.batch_size, phase=self.phase, data=data)
        return self.env.reader

    @property
    def loss(self):
        if self.is_predict_phase:
            raise RuntimeError("loss is not available in predict phase")

        if not self.env.is_initialized:
            self._build_env()
        return self.env.loss

    @property
    def labels(self):
        if self.is_predict_phase:
            raise RuntimeError("labels are not available in predict phase")

        if not self.env.is_initialized:
            self._build_env()
        return self.env.labels

    @property
    def outputs(self):
        if not self.env.is_initialized:
            self._build_env()
        return self.env.outputs

    @property
    def metrics(self):
        if self.is_predict_phase:
            raise RuntimeError("metrics are not available in predict phase")

        if not self.env.is_initialized:
            self._build_env()
        return self.env.metrics

    @property
    def unique_name_generator(self):
        return self.env.UNG

    @property
    def feed_list(self):
        feed_list = [varname for varname in self._base_feed_list]
        if self.is_train_phase or self.is_test_phase:
            feed_list += [label.name for label in self.labels]
        return feed_list

    @property
    def feed_var_list(self):
        var_map = self.main_program.global_block().vars
        return [var_map[varname] for varname in self.feed_list]

    @property
    def fetch_list(self):
        if self.is_train_phase or self.is_test_phase:
            return [metric.name for metric in self.metrics] + [self.loss.name]
        return [output.name for output in self.outputs]

    def _build_env_start_event(self):
        pass

    def _build_env_end_event(self):
        if not self.is_predict_phase:
            self.env.score_scalar = {}

    def _finetune_start_event(self):
        logger.info("PaddleHub finetune start")

    def _finetune_end_event(self, run_states):
        logger.info("PaddleHub finetune finished.")

    def _predict_start_event(self):
        logger.info("PaddleHub predict start")

    def _predict_end_event(self, run_states):
        logger.info("PaddleHub predict finished.")

    def _eval_start_event(self):
        logger.info("Evaluation on {} dataset start".format(self.phase))

    def _eval_end_event(self, run_states):
        eval_scores, eval_loss, run_speed = self._calculate_metrics(run_states)
        if 'train' in self._envs:
            self.tb_writer.add_scalar(
                tag="Loss_{}".format(self.phase),
                scalar_value=eval_loss,
                global_step=self._envs['train'].current_step)

        log_scores = ""
        for metric in eval_scores:
            if 'train' in self._envs:
                self.tb_writer.add_scalar(
                    tag="{}_{}".format(metric, self.phase),
                    scalar_value=eval_scores[metric],
                    global_step=self._envs['train'].current_step)
            log_scores += "%s=%.5f " % (metric, eval_scores[metric])
        logger.info(
            "[%s dataset evaluation result] loss=%.5f %s[step/sec: %.2f]" %
            (self.phase, eval_loss, log_scores, run_speed))

        eval_scores_items = eval_scores.items()
        if len(eval_scores_items):
            # the first metric is chosen as the main evaluation metric
            main_metric, main_value = list(eval_scores_items)[0]
        else:
            logger.warning(
                "No metrics have been implemented; loss will be used for evaluation."
            )
            # The larger, the better
            main_metric, main_value = "negative loss", -eval_loss
        if self.phase in ["dev", "val"] and main_value > self.best_score:
            self.best_score = main_value
            model_saved_dir = os.path.join(self.config.checkpoint_dir,
                                           "best_model")
            logger.info("best model saved to %s [best %s=%.5f]" %
                        (model_saved_dir, main_metric, main_value))

            fluid.io.save_persistables(
                executor=self.exe,
                dirname=model_saved_dir,
                main_program=self.main_program)

    def _log_interval_event(self, run_states):
        scores, avg_loss, run_speed = self._calculate_metrics(run_states)
        self.tb_writer.add_scalar(
            tag="Loss_{}".format(self.phase),
            scalar_value=avg_loss,
            global_step=self._envs['train'].current_step)
        log_scores = ""
        for metric in scores:
            self.tb_writer.add_scalar(
                tag="{}_{}".format(metric, self.phase),
                scalar_value=scores[metric],
                global_step=self._envs['train'].current_step)
            log_scores += "%s=%.5f " % (metric, scores[metric])
        logger.info("step %d / %d: loss=%.5f %s[step/sec: %.2f]" %
                    (self.current_step, self.max_train_steps, avg_loss,
                     log_scores, run_speed))

    def _save_ckpt_interval_event(self):
        self.save_checkpoint()

    def _eval_interval_event(self):
        self.eval(phase="dev")

    def _run_step_event(self, run_state):
        if self.is_predict_phase:
            return run_state.run_results

    def _build_net(self):
        raise NotImplementedError

    def _add_loss(self):
        raise NotImplementedError

    def _add_label(self):
        raise NotImplementedError

    def _add_metrics(self):
        # Metrics such as acc and auc can be computed with fluid.layers;
        # the others can be computed in the _calculate_metrics function
        raise NotImplementedError

    def _calculate_metrics(self, run_states):
        # NOTE: to customize the metrics, make sure the first value returned is
        # a dict; its first key is used as the main metric for selecting the best model
        raise NotImplementedError

    # NOTE: the current checkpoint mechanism is incomplete;
    # it cannot restore the dataset's training status
    def save_checkpoint(self):
        save_checkpoint(
            checkpoint_dir=self.config.checkpoint_dir,
            current_epoch=self.current_epoch,
            global_step=self.current_step,
            best_score=self.best_score,
            exe=self.exe,
            main_program=self.main_program)

    def load_checkpoint(self):
        is_load_successful, self.env.current_epoch, self.env.current_step, self.best_score = load_checkpoint(
            self.config.checkpoint_dir,
            self.exe,
            main_program=self.main_program)

        return is_load_successful

    def load_parameters(self, dirname):
        def if_exist(var):
            path = os.path.join(dirname, var.name)
            return os.path.exists(path)

        fluid.io.load_vars(
            self.exe, dirname, self.main_program, predicate=if_exist)

    def save_parameters(self, dirname):
        fluid.io.save_params(
            self.exe, dirname=dirname, main_program=self.main_program)

    def finetune_and_eval(self):
        return self.finetune(do_eval=True)

    def finetune(self, do_eval=False):
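        """Fine-tune the model for up to `config.num_epoch` epochs; if `do_eval`, evaluate on the dev set every `eval_interval` steps."""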
        # Start to finetune
        with self.phase_guard(phase="train"):
            self.init_if_necessary()
            self._finetune_start_event()
            run_states = []
            if self.current_epoch <= self.config.num_epoch:
                while self.current_epoch <= self.config.num_epoch:
                    self.config.strategy.step()
                    run_states = self._run(do_eval=do_eval)
                    self.env.current_epoch += 1

                # Final evaluation
                if self._base_data_reader.get_dev_examples() != []:
                    self.eval(phase="dev")
                if self._base_data_reader.get_test_examples() != []:
                    self.eval(phase="test", load_best_model=True)
                # Save checkpoint after finetune
                self.save_checkpoint()

            self._finetune_end_event(run_states)
            return run_states

    def eval(self, phase="dev", load_best_model=False):
        # Warning: DO NOT use eval(load_best_model=True) inside finetune_and_eval;
        # it would prevent the trainer from resuming training from the checkpoint
        # after eval. More importantly, the model should be evaluated on its
        # current parameters during training.
        with self.phase_guard(phase=phase):
            if load_best_model:
                self.init_if_load_best_model()
            else:
                self.init_if_necessary()
            self._eval_start_event()
            run_states = self._run()
            self._eval_end_event(run_states)
            return run_states

    def predict(self, data, load_best_model=True):
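        """Run prediction over `data`; by default the parameters of the best saved model are loaded first."""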
        with self.phase_guard(phase="predict"):
            if load_best_model:
                self.init_if_load_best_model()
            else:
                self.init_if_necessary()
            self._predict_data = data
            self._predict_start_event()
            run_states = self._run()
            self._predict_end_event(run_states)
            self._predict_data = None
        return run_states

    def _run(self, do_eval=False):
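        """Run one pass over the reader, dispatching to the py_reader- or DataFeeder-based loop."""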
        with fluid.program_guard(self.main_program, self.startup_program):
            if self.config.use_pyreader:
                return self._run_with_py_reader(do_eval=do_eval)
            return self._run_with_data_feeder(do_eval=do_eval)

    def _run_with_data_feeder(self, do_eval=False):
        data_feeder = fluid.DataFeeder(
            feed_list=self.feed_list, place=self.place)

        global_run_states = []
        period_run_states = []

        for run_step, batch in enumerate(self.reader(), start=1):
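            # under data parallelism, skip tail batches smaller than the device count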
            if self.config.use_data_parallel and len(batch) < self.device_count:
                continue
            step_run_state = RunState(len(self.fetch_list))
            step_run_state.run_step = 1
            num_batch_examples = len(batch)

            if self.return_numpy:
                fetch_result = self.exe.run(
                    self.main_program_to_be_run,
                    feed=data_feeder.feed(batch),
                    fetch_list=self.fetch_list)
            else:
                fetch_result = self.exe.run(
                    self.main_program_to_be_run,
                    feed=data_feeder.feed(batch),
                    fetch_list=self.fetch_list,
                    return_numpy=False)
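                # with return_numpy=False the results are LoDTensors; convert to numpy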
                fetch_result = [np.array(x) for x in fetch_result]

            for index, result in enumerate(fetch_result):
                step_run_state.run_results[index] = result
            step_run_state.run_examples += num_batch_examples
            step_run_state.update()
            period_run_states += [step_run_state]
            self.env.current_step += 1
            if self.is_train_phase:
                if self.current_step % self.config.log_interval == 0:
                    self._log_interval_event(period_run_states)
                    global_run_states += period_run_states
                    period_run_states = []

                if self.config.save_ckpt_interval and self.current_step % self.config.save_ckpt_interval == 0:
                    self._save_ckpt_interval_event()

                if do_eval and self.current_step % self.config.eval_interval == 0:
                    self._eval_interval_event()

            self._run_step_event(step_run_state)

        global_run_states += period_run_states
        return global_run_states

    def _run_with_py_reader(self, do_eval=False):
        fallback_tried = False
        use_data_parallel_backup = self.config.use_data_parallel
        while True:
            global_run_states = []
            period_run_states = []
            self.py_reader.decorate_paddle_reader(self.reader)
            self.py_reader.start()
            try:
                while True:
                    num_batch_examples = self.config.batch_size * self.device_count
                    step_run_state = RunState(len(self.fetch_list))
                    step_run_state.run_step = 1

                    if self.return_numpy:
                        fetch_result = self.exe.run(
                            self.main_program_to_be_run,
                            fetch_list=self.fetch_list)
                    else:
                        fetch_result = self.exe.run(
                            self.main_program_to_be_run,
                            fetch_list=self.fetch_list,
                            return_numpy=False)
                        fetch_result = [np.array(x) for x in fetch_result]

                    for index, result in enumerate(fetch_result):
                        step_run_state.run_results[index] = result
                    step_run_state.run_examples += num_batch_examples
                    step_run_state.update()
                    period_run_states += [step_run_state]
                    self.env.current_step += 1
                    if self.is_train_phase:
                        if self.current_step % self.config.log_interval == 0:
                            self._log_interval_event(period_run_states)
                            global_run_states += period_run_states
                            period_run_states = []

                        if self.config.save_ckpt_interval and self.current_step % self.config.save_ckpt_interval == 0:
                            self._save_ckpt_interval_event()

                        if do_eval and self.current_step % self.config.eval_interval == 0:
                            self._eval_interval_event()

                    self._run_step_event(step_run_state)
            except fluid.core.EOFException:
                global_run_states += period_run_states
                self.py_reader.reset()
                '''
                When use_data_parallel and use_pyreader are both enabled and the
                dataset is very small, the reader may throw an EOFException before
                any result has been fetched. In that case, temporarily disable
                use_data_parallel and rerun to obtain the result.
                '''
                if fallback_tried:
                    self.config._use_data_parallel = use_data_parallel_backup
                elif len(global_run_states) == 0:
                    fallback_tried = True
                    self.config._use_data_parallel = False
                    continue
                break

        return global_run_states