base_task.py 38.2 KB
Newer Older
K
kinghuin 已提交
1 2
# coding:utf-8
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16 17 18 19
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

Z
Zeyu Chen 已提交
20
import os
W
wuzewu 已提交
21
import contextlib
22
import time
W
wuzewu 已提交
23
import copy
K
kinghuin 已提交
24 25
import inspect
from functools import partial
K
kinghuin 已提交
26
from collections import OrderedDict
K
kinghuin 已提交
27 28 29 30 31
import six
if six.PY2:
    from inspect import getargspec as get_args
else:
    from inspect import getfullargspec as get_args
S
Steffy-zxf 已提交
32
import numpy as np
W
wuzewu 已提交
33
import paddle.fluid as fluid
K
kinghuin 已提交
34
from tb_paddle import SummaryWriter
W
wuzewu 已提交
35 36

import paddlehub as hub
S
Steffy-zxf 已提交
37
from paddlehub.common.paddle_helper import dtype_map, clone_program
38 39
from paddlehub.common.utils import mkdir
from paddlehub.common.dir import tmp_dir
W
wuzewu 已提交
40 41 42 43 44 45
from paddlehub.common.logger import logger
from paddlehub.finetune.checkpoint import load_checkpoint, save_checkpoint
from paddlehub.finetune.config import RunConfig


class RunState(object):
    """
    RunState accumulates the results of running steps.

    Args:
        length (int): the number of fetch results tracked per step
    """

    def __init__(self, length):
        self.run_time_begin = time.time()
        self.run_step = 0
        self.run_examples = 0
        self.run_results = [0] * length
        self.run_time_used = 0
        self.run_speed = 0.0

    def __add__(self, other):
        """
        Accumulate other's step/example counters and fetch results into this
        state IN PLACE, and return self (so `a + b` modifies `a`).
        """
        self.run_step += other.run_step
        self.run_examples += other.run_examples
        # Index-based in-place `+=` is kept on purpose: the elements may be
        # array-like and callers may rely on in-place accumulation.
        for index in range(len(self.run_results)):
            self.run_results[index] += other.run_results[index]
        return self

    def update(self):
        """
        Refresh the elapsed time and steps-per-second speed, and return self.

        With a coarse clock the elapsed time can be zero (or negative if the
        clock moves backwards). In that case the speed is reported as 0.0
        instead of raising ZeroDivisionError.
        """
        self.run_time_used = time.time() - self.run_time_begin
        if self.run_time_used > 0:
            self.run_speed = self.run_step / self.run_time_used
        else:
            self.run_speed = 0.0
        return self


W
wuzewu 已提交
74
class RunEnv(object):
    """
    RunEnv saves the running environment of the train/dev/predict phase, including program, reader, metrics and so on.
    """

    def __init__(self):
        self.current_epoch = 0
        self.current_step = 0
        self.main_program = None
        self.start_program = None
        self.main_program_compiled = None
        self.py_reader = None
        self.reader = None
        self.loss = None
        self.labels = None
        self.metrics = None
        self.is_inititalized = False
        # Private copy of fluid's unique-name generator for this environment;
        # it is used with fluid.unique_name.guard when the phase is built.
        self.UNG = copy.deepcopy(fluid.unique_name.generator)

    def __setattr__(self, key, value):
        self.__dict__[key] = value

    def __getattr__(self, key):
        # __getattr__ is only invoked when normal lookup fails. A missing
        # attribute must raise AttributeError (not KeyError), otherwise
        # hasattr(), getattr(obj, name, default) and copy.deepcopy break.
        try:
            return self.__dict__[key]
        except KeyError:
            raise AttributeError("'%s' object has no attribute '%s'" %
                                 (type(self).__name__, key))


K
kinghuin 已提交
100
class TaskHooks():
    """
    TaskHooks keeps, for every task event, an ordered mapping of hook name to
    handler function, and validates the handlers that get registered.
    """

    def __init__(self):
        # event name -> OrderedDict(hook name -> handler), in registration order
        self._registered_hooks = {
            "build_env_start_event": OrderedDict(),
            "build_env_end_event": OrderedDict(),
            "finetune_start_event": OrderedDict(),
            "finetune_end_event": OrderedDict(),
            "predict_start_event": OrderedDict(),
            "predict_end_event": OrderedDict(),
            "eval_start_event": OrderedDict(),
            "eval_end_event": OrderedDict(),
            "log_interval_event": OrderedDict(),
            "save_ckpt_interval_event": OrderedDict(),
            "eval_interval_event": OrderedDict(),
            "run_step_event": OrderedDict(),
        }
        # event name -> number of parameters (as reported by get_args) that a
        # handler registered for that event must declare
        self._hook_params_num = {
            "build_env_start_event": 1,
            "build_env_end_event": 1,
            "finetune_start_event": 1,
            "finetune_end_event": 2,
            "predict_start_event": 1,
            "predict_end_event": 2,
            "eval_start_event": 1,
            "eval_end_event": 2,
            "log_interval_event": 2,
            "save_ckpt_interval_event": 1,
            "eval_interval_event": 1,
            "run_step_event": 2,
        }

    def _check_params_num(self, hook_type, func):
        """
        Check that func declares the number of parameters hook_type expects.

        Args:
            hook_type (str): an existing event name
            func (func): the handler function

        Raises:
            ValueError: if the parameter count does not match
        """
        args_num = len(get_args(func).args)
        expected_num = self._hook_params_num[hook_type]
        if args_num != expected_num:
            raise ValueError(
                "The number of parameters to the hook hook_type:%s should be %i"
                % (hook_type, expected_num))

    @staticmethod
    def _get_source(func):
        """
        Return the source code of func, or a placeholder when inspect cannot
        retrieve it (e.g. builtins, functools.partial objects or code typed in
        an interactive session).
        """
        try:
            return inspect.getsource(func)
        except (IOError, OSError, TypeError):
            return "<source code unavailable: %r>" % (func, )

    def add(self, hook_type, name=None, func=None):
        """
        add the handler function to specific event.

        Args:
            hook_type (str): the specific event name
            name (str): the handler function name, default None, in which case
                "hook_<id(func)>" is used
            func (func): the handler function, default None

        Raises:
            TypeError: if func is not callable or name is not a non-empty string
            ValueError: if hook_type is unknown, name is already registered, or
                func declares the wrong number of parameters
        """
        if func is None or not callable(func):
            raise TypeError(
                "The hook function is empty or it is not a function")
        if name is None:
            name = "hook_%s" % id(func)

        # check validity
        if not isinstance(name, str) or name.strip() == "":
            raise TypeError("The hook name must be a non-empty string")
        if hook_type not in self._registered_hooks:
            raise ValueError("hook_type: %s does not exist" % (hook_type, ))
        if name in self._registered_hooks[hook_type]:
            raise ValueError(
                "name: %s has existed in hook_type:%s, use modify method to modify it"
                % (name, hook_type))
        self._check_params_num(hook_type, func)
        self._registered_hooks[hook_type][name] = func

    def delete(self, hook_type, name):
        """
        delete the handler function of specific event.

        Args:
            hook_type (str): the specific event name
            name (str): the handler function name

        Raises:
            ValueError: if no such hook is registered
        """
        if not self.exist(hook_type, name):
            raise ValueError(
                "No hook_type: %s exists or name: %s does not exist in hook_type: %s"
                % (hook_type, name, hook_type))
        del self._registered_hooks[hook_type][name]

    def modify(self, hook_type, name, func):
        """
        modify the handler function of specific event.

        Args:
            hook_type (str): the specific event name
            name (str): the handler function name
            func (func): the new handler function

        Raises:
            TypeError: if name is not a string or func is not callable
            ValueError: if no such hook is registered, or func declares the
                wrong number of parameters (same rule as add)
        """
        if not (isinstance(name, str) and callable(func)):
            raise TypeError(
                "The hook name must be a string, and the hook function must be a function"
            )
        if not self.exist(hook_type, name):
            raise ValueError(
                "No hook_type: %s exists or name: %s does not exist in hook_type: %s"
                % (hook_type, name, hook_type))
        self._check_params_num(hook_type, func)
        self._registered_hooks[hook_type][name] = func

    def exist(self, hook_type, name):
        """
        check if the handler function of specific event exists.

        Args:
            hook_type (str): the specific event name
            name (str): the handler function name

        Returns:
            bool: True or False
        """
        return hook_type in self._registered_hooks \
            and name in self._registered_hooks[hook_type]

    def info(self, show_default=False):
        """
        get the hooks information, including the source code.

        Args:
            show_default (bool): show the information of Paddlehub default hooks or not, default False

        Returns:
            str: the formatted string of the hooks information
        """
        # formatted output the source code
        parts = []
        for hook_type, hooks in self._registered_hooks.items():
            shown = [(name, func) for name, func in hooks.items()
                     if show_default or name != "default"]
            if not shown:
                continue
            parts.append("hook_type: %s{\n" % hook_type)
            for name, func in shown:
                parts.append(" name: %s{\n" % name)
                for line in self._get_source(func).split("\n"):
                    parts.append("  %s\n" % line)
                parts.append(" }\n")
            parts.append("}\n")
        ret = "".join(parts)
        if not ret:
            ret = "Not any customized hooks have been defined, you can set show_default=True to see the default hooks information"
        return ret

    def __getitem__(self, hook_type):
        return self._registered_hooks[hook_type]

    def __repr__(self):
        return self.info(show_default=False)
K
kinghuin 已提交
255 256


K
kinghuin 已提交
257
class BaseTask(object):
258 259 260 261 262 263 264 265 266 267 268 269
    """
    BaseTask is the base class of all the task. It will complete the building of all the running environment.

    Args:
        feed_list (list): the inputs name
        data_reader (object): data reader for the task
        main_program (object): the customized main_program, default None
        startup_program (object): the customized startup_program, default None
        config (object): the config for the task, default None
        metrics_choices (list): metrics used to the task, default ["acc"]
    """

W
wuzewu 已提交
270
    def __init__(self,
                 feed_list,
                 data_reader,
                 main_program=None,
                 startup_program=None,
                 config=None,
                 metrics_choices="default"):
        # base item
        self._base_data_reader = data_reader
        self._base_feed_list = feed_list

        # metrics item
        self.best_score = -999
        if metrics_choices == "default":
            metrics_choices = ["acc"]
        elif metrics_choices is None:
            metrics_choices = []
        if isinstance(metrics_choices, list):
            self.metrics_choices = metrics_choices
        else:
            self.metrics_choices = [metrics_choices]

        # Work on private copies of the (given or default) programs.
        if main_program is None:
            main_program = fluid.default_main_program()
        self._base_main_program = clone_program(main_program, for_test=False)
        if startup_program is None:
            startup_program = fluid.default_startup_program()
        self._base_startup_program = clone_program(
            startup_program, for_test=False)
        self.is_checkpoint_loaded = False
        self._base_compiled_program = None

        # run config
        self.config = config if config else RunConfig()
        self.place = self.places[0]
        self.device_count = len(self.places)

        if self.config.use_data_parallel:
            if not self.config.use_pyreader and self.config.batch_size < self.device_count:
                logger.warning(
                    "Batch size({}) is less than the count of devices({}), which is not allowed in current Paddle versions"
                    .format(self.config.batch_size, self.device_count))
                logger.warning("Batch size automatically adjusted to {}".format(
                    self.device_count))
                self.config._batch_size = self.device_count

        self.exe = fluid.Executor(place=self.place)
        self.build_strategy = fluid.BuildStrategy()

        # run environment
        self._phases = []
        self._envs = {}
        self._predict_data = None
        self._tb_writer = None

        # event hooks: register the "_default_<event>" method of this task as
        # the "default" handler of every event, and install a "_<event>"
        # dispatcher on BaseTask that runs all handlers of that event.
        self._hooks = TaskHooks()
        for hook_type in self._hooks._registered_hooks:
            self._hooks.add(hook_type, "default",
                            getattr(self, "_default_%s" % hook_type))
            setattr(BaseTask, "_%s" % hook_type,
                    self.create_event_function(hook_type))

        # accelerate predict
        self.is_best_model_loaded = False
        self._predictor = None

        # set default phase
        self.enter_phase("train")
W
wuzewu 已提交
345 346 347

    @contextlib.contextmanager
    def phase_guard(self, phase):
W
wuzewu 已提交
348 349 350 351 352
        self.enter_phase(phase)
        yield
        self.exit_phase()

    def enter_phase(self, phase):
W
wuzewu 已提交
353 354
        if phase not in ["train", "val", "dev", "test", "predict", "inference"]:
            raise RuntimeError()
K
kinghuin 已提交
355 356 357 358
        if phase in ["val", "dev"]:
            phase = "dev"
        elif phase in ["predict", "inference"]:
            phase = "predict"
W
wuzewu 已提交
359
        self._phases.append(phase)
W
wuzewu 已提交
360 361

    def exit_phase(self):
W
wuzewu 已提交
362 363
        self._phases = self._phases[:-1]

W
wuzewu 已提交
364 365 366 367
    def init_if_necessary(self):
        if not self.is_checkpoint_loaded:
            if not self.load_checkpoint():
                self.exe.run(self._base_startup_program)
K
kinghuin 已提交
368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383
            self.is_checkpoint_loaded = True
            self.is_best_model_loaded = False

    def init_if_load_best_model(self):
        """
        Load parameters from <checkpoint_dir>/best_model once. If that
        directory does not exist, fall back to init_if_necessary().
        """
        if self.is_best_model_loaded:
            logger.info("The best model has been loaded")
            return

        best_model_path = os.path.join(self.config.checkpoint_dir,
                                       "best_model")
        logger.info("Load the best model from %s" % best_model_path)
        if not os.path.exists(best_model_path):
            self.init_if_necessary()
            return

        self.load_parameters(best_model_path)
        self.is_checkpoint_loaded = False
        self.is_best_model_loaded = True
W
wuzewu 已提交
384

W
wuzewu 已提交
385
    def _build_env(self):
        """
        building the program and strategy for specific running phase.

        Builds the current phase's environment at most once (guarded by
        env.is_inititalized). It fires build_env_start_event before building
        and build_env_end_event after. Train and test phases also get labels,
        loss and metrics. Predict and test programs are re-cloned with
        for_test=True. The train phase additionally runs the configured
        strategy, which sets self.scheduled_lr and self.max_train_steps.
        """
        if self.env.is_inititalized:
            return

        self._build_env_start_event()
        # Set before building: properties used below (self.loss,
        # self.fetch_list) only call _build_env while this flag is False.
        self.env.is_inititalized = True
        # Each phase builds on its own copy of the base main program.
        self.env.main_program = clone_program(
            self._base_main_program, for_test=False)

        # NOTE(review): env.startup_program is a fresh, empty Program. The
        # program_guard below adds initializers to self._base_startup_program,
        # so the exe.run(self.env.startup_program) at the end appears to run
        # an empty program. Confirm this is intended.
        self.env.startup_program = fluid.Program()
        with fluid.program_guard(self.env.main_program,
                                 self._base_startup_program):
            with fluid.unique_name.guard(self.env.UNG):
                self.env.outputs = self._build_net()
                if self.is_train_phase or self.is_test_phase:
                    self.env.labels = self._add_label()
                    self.env.loss = self._add_loss()
                    self.env.metrics = self._add_metrics()

        # Evaluation/inference programs are cloned for test and their ops are
        # switched to is_test=True.
        if self.is_predict_phase or self.is_test_phase:
            self.env.main_program = clone_program(
                self.env.main_program, for_test=True)
            hub.common.paddle_helper.set_op_attr(
                self.env.main_program, is_test=True)

        # With memory optimization enabled, the fetched variables are marked
        # persistable (presumably so the optimizer does not reuse them; confirm).
        if self.config.enable_memory_optim:
            for var_name in self.fetch_list:
                var = self.env.main_program.global_block().vars[var_name]
                var.persistable = True

        # Only the train phase gets the optimization strategy applied.
        if self.is_train_phase:
            with fluid.program_guard(self.env.main_program,
                                     self._base_startup_program):
                with fluid.unique_name.guard(self.env.UNG):
                    self.scheduled_lr, self.max_train_steps = self.config.strategy.execute(
                        self.loss, self._base_data_reader, self.config,
                        self.device_count)

        if self.is_train_phase:
            loss_name = self.env.loss.name
        else:
            loss_name = None

        # Compiled programs share variables with the compiled program cached
        # by main_program_to_be_run (None until that has happened).
        share_vars_from = self._base_compiled_program

        if not self.config.use_data_parallel:
            self.env.main_program_compiled = None
        else:
            self.env.main_program_compiled = fluid.CompiledProgram(
                self.env.main_program).with_data_parallel(
                    loss_name=loss_name,
                    share_vars_from=share_vars_from,
                    build_strategy=self.build_strategy,
                    places=self.places)

        self.exe.run(self.env.startup_program)
        self._build_env_end_event()

446 447 448
    @property
    def places(self):
        """
        Devices to run on. With data parallelism this is every CUDA place
        (or CPU place when use_cuda is off); otherwise it is only the first one.
        """
        if self.config.use_cuda:
            all_places = fluid.framework.cuda_places()
        else:
            all_places = fluid.framework.cpu_places()
        return all_places if self.config.use_data_parallel else [all_places[0]]
456

S
Steffy-zxf 已提交
457 458 459 460
    @property
    def return_numpy(self):
        return True

W
wuzewu 已提交
461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480
    @property
    def is_train_phase(self):
        return self.phase in ["train"]

    @property
    def is_test_phase(self):
        return self.phase in ["val", "dev", "test"]

    @property
    def is_predict_phase(self):
        return self.phase in ["predict", "inference"]

    @property
    def phase(self):
        return self._phases[-1]

    @property
    def env(self):
        phase = self.phase
        if phase in ["val", "dev", "test"]:
K
kinghuin 已提交
481
            phase = "dev"
W
wuzewu 已提交
482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504
        if not phase in self._envs:
            self._envs[phase] = RunEnv()
        return self._envs[phase]

    @property
    def py_reader(self):
        if not self.env.is_inititalized:
            self._build_env()
        return self.env.py_reader

    @property
    def current_step(self):
        if not self.env.is_inititalized:
            self._build_env()
        return self.env.current_step

    @property
    def current_epoch(self):
        if not self.env.is_inititalized:
            self._build_env()
        return self.env.current_epoch

    @property
Z
Zeyu Chen 已提交
505
    def main_program(self):
W
wuzewu 已提交
506 507 508
        if not self.env.is_inititalized:
            self._build_env()
        return self.env.main_program
Z
Zeyu Chen 已提交
509

W
wuzewu 已提交
510
    @property
Z
Zeyu Chen 已提交
511
    def startup_program(self):
W
wuzewu 已提交
512 513 514 515 516 517 518 519 520 521
        if not self.env.is_inititalized:
            self._build_env()
        return self.env.startup_program

    @property
    def main_program_compiled(self):
        if not self.env.is_inititalized:
            self._build_env()
        return self.env.main_program_compiled

W
wuzewu 已提交
522 523 524
    @property
    def main_program_to_be_run(self):
        if self.config.use_data_parallel:
K
kinghuin 已提交
525 526
            if self._base_compiled_program is None:
                self._base_compiled_program = self.env.main_program_compiled
W
wuzewu 已提交
527 528 529
            return self.main_program_compiled
        return self.main_program

W
wuzewu 已提交
530 531
    @property
    def reader(self):
W
wuzewu 已提交
532 533 534 535
        if self.is_predict_phase:
            data = self._predict_data
        else:
            data = None
W
wuzewu 已提交
536
        self.env.reader = self._base_data_reader.data_generator(
537 538 539 540
            batch_size=self.config.batch_size,
            phase=self.phase,
            data=data,
            return_list=not self.config.use_pyreader)
W
wuzewu 已提交
541 542 543 544 545 546 547 548 549 550 551 552
        return self.env.reader

    @property
    def loss(self):
        if self.is_predict_phase:
            raise RuntimeError()

        if not self.env.is_inititalized:
            self._build_env()
        return self.env.loss

    @property
W
wuzewu 已提交
553
    def labels(self):
W
wuzewu 已提交
554 555 556 557 558
        if self.is_predict_phase:
            raise RuntimeError()

        if not self.env.is_inititalized:
            self._build_env()
W
wuzewu 已提交
559
        return self.env.labels
W
wuzewu 已提交
560 561

    @property
562
    def outputs(self):
W
wuzewu 已提交
563 564
        if not self.env.is_inititalized:
            self._build_env()
565
        return self.env.outputs
W
wuzewu 已提交
566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583

    @property
    def metrics(self):
        if self.is_predict_phase:
            raise RuntimeError()

        if not self.env.is_inititalized:
            self._build_env()
        return self.env.metrics

    @property
    def unique_name_generator(self):
        return self.env.UNG

    @property
    def feed_list(self):
        feed_list = [varname for varname in self._base_feed_list]
        if self.is_train_phase or self.is_test_phase:
W
wuzewu 已提交
584
            feed_list += [label.name for label in self.labels]
W
wuzewu 已提交
585 586 587 588 589 590 591 592 593 594 595
        return feed_list

    @property
    def feed_var_list(self):
        vars = self.main_program.global_block().vars
        return [vars[varname] for varname in self.feed_list]

    @property
    def fetch_list(self):
        if self.is_train_phase or self.is_test_phase:
            return [metric.name for metric in self.metrics] + [self.loss.name]
596
        return [output.name for output in self.outputs]
W
wuzewu 已提交
597

W
wuzewu 已提交
598 599 600 601 602
    @property
    def fetch_var_list(self):
        vars = self.main_program.global_block().vars
        return [vars[varname] for varname in self.fetch_list]

603 604
    @property
    def tb_writer(self):
605 606 607
        """
        get tb_writer for visualization.
        """
608 609 610 611 612 613 614
        if not os.path.exists(self.config.checkpoint_dir):
            mkdir(self.config.checkpoint_dir)
        tb_log_dir = os.path.join(self.config.checkpoint_dir, "visualization")
        if not self._tb_writer:
            self._tb_writer = SummaryWriter(tb_log_dir)
        return self._tb_writer

K
kinghuin 已提交
615
    def create_event_function(self, hook_type):
616 617 618 619 620 621 622 623 624 625
        """
        create handlers for specific event.

        Args:
            hook_type (str): specific event name

        Returns:
            func: executable function, the class method will receive a parameter named self.
        """

K
kinghuin 已提交
626
        def hook_function(self, *args):
627
            # all the handler in self._hooks[hook_type] will be configured to executable
K
kinghuin 已提交
628 629 630 631 632 633 634 635 636 637 638 639
            for name, func in self._hooks[hook_type].items():
                if inspect.ismethod(func):
                    func(*args)
                else:
                    partial(func, self)(*args)

        return hook_function

    @property
    def hooks(self):
        return self._hooks

640 641 642 643 644 645 646 647 648 649 650
    def hooks_info(self, show_default=False):
        """
        get the hooks information, including the source code.

        Args:
            show_default (bool): show the information of Paddlehub default hooks or not, default False

        Returns:
            str: the formatted string of the hooks information
        """
        return self._hooks.info(show_default)
K
kinghuin 已提交
651 652

    def add_hook(self, hook_type, name=None, func=None):
        """
        add the handler function to specific event.

        Args:
            hook_type (str): the specific event name
            name (str): the handler function name, default None, in which
                case "hook_<id(func)>" is used
            func (func): the handler function, default None

        Raises:
            TypeError, ValueError: propagated from the registry when func,
                name or hook_type is invalid
        """
        # Resolve the default name here too, so the log line shows it.
        if name is None:
            name = "hook_%s" % id(func)
        self._hooks.add(hook_type, name=name, func=func)
        logger.info("Add hook %s:%s successfully" % (hook_type, name))
K
kinghuin 已提交
665 666

    def delete_hook(self, hook_type, name):
        """
        Remove the handler registered under `name` for event `hook_type`.

        Args:
            hook_type (str): the specific event name
            name (str): the handler function name
        """
        self._hooks.delete(hook_type, name)
        message = "Delete hook %s:%s successfully" % (hook_type, name)
        logger.info(message)
K
kinghuin 已提交
676 677

    def modify_hook(self, hook_type, name, func):
        """
        Replace the handler registered under `name` for event `hook_type`.

        Args:
            hook_type (str): the specific event name
            name (str): the handler function name
            func (func): the new handler function
        """
        self._hooks.modify(hook_type, name, func)
        message = "Modify hook %s:%s successfully" % (hook_type, name)
        logger.info(message)
K
kinghuin 已提交
688 689

    def _default_build_env_start_event(self):
W
wuzewu 已提交
690 691
        pass

K
kinghuin 已提交
692
    def _default_build_env_end_event(self):
K
kinghuin 已提交
693 694
        if not self.is_predict_phase:
            self.env.score_scalar = {}
W
wuzewu 已提交
695

K
kinghuin 已提交
696 697
    def _default_finetune_start_event(self):
        """Default finetune_start_event handler: log the start of finetuning."""
        start_message = "PaddleHub finetune start"
        logger.info(start_message)
W
wuzewu 已提交
698

K
kinghuin 已提交
699
    def _default_finetune_end_event(self, run_states):
        """
        Default finetune_end_event handler: log the end of finetuning.
        run_states is accepted for the hook signature but not used.
        """
        end_message = "PaddleHub finetune finished."
        logger.info(end_message)

K
kinghuin 已提交
702
    def _default_predict_start_event(self):
        """Default predict_start_event handler: log the start of prediction."""
        start_message = "PaddleHub predict start"
        logger.info(start_message)

K
kinghuin 已提交
705
    def _default_predict_end_event(self, run_states):
        """
        Default predict_end_event handler: log the end of prediction.
        run_states is accepted for the hook signature but not used.
        """
        end_message = "PaddleHub predict finished."
        logger.info(end_message)

K
kinghuin 已提交
708 709
    def _default_eval_start_event(self):
        """Default eval_start_event handler: log which dataset is being evaluated."""
        phase_name = self.phase
        logger.info("Evaluation on {} dataset start".format(phase_name))
W
wuzewu 已提交
710

K
kinghuin 已提交
711
    def _default_eval_end_event(self, run_states):
        """
        PaddleHub default handler for eval_end_event.

        Computes the metrics from run_states and logs them. When a train
        environment exists, it also writes them to the visualization writer.
        On the dev set it saves the best model whenever the main metric (the
        first returned metric, or negative loss if there are none) improves
        on self.best_score.

        Args:
            run_states (object): the results in eval phase
        """
        eval_scores, eval_loss, run_speed = self._calculate_metrics(run_states)
        has_train_env = 'train' in self._envs

        if has_train_env:
            self.tb_writer.add_scalar(
                tag="Loss_{}".format(self.phase),
                scalar_value=eval_loss,
                global_step=self._envs['train'].current_step)

        score_parts = []
        for metric_name, metric_value in eval_scores.items():
            if has_train_env:
                self.tb_writer.add_scalar(
                    tag="{}_{}".format(metric_name, self.phase),
                    scalar_value=metric_value,
                    global_step=self._envs['train'].current_step)
            score_parts.append("%s=%.5f " % (metric_name, metric_value))
        log_scores = "".join(score_parts)

        logger.eval(
            "[%s dataset evaluation result] loss=%.5f %s[step/sec: %.2f]" %
            (self.phase, eval_loss, log_scores, run_speed))

        if eval_scores:
            # The first metric is the one used to pick the best model.
            main_metric, main_value = next(iter(eval_scores.items()))
        else:
            logger.warning(
                "None of metrics has been implemented, loss will be used to evaluate."
            )
            # Larger is better, hence the negated loss.
            main_metric, main_value = "negative loss", -eval_loss

        improved_on_dev = self.phase in ["dev", "val"] and main_value > self.best_score
        if not improved_on_dev:
            return

        self.best_score = main_value
        model_saved_dir = os.path.join(self.config.checkpoint_dir,
                                       "best_model")
        logger.eval("best model saved to %s [best %s=%.5f]" %
                    (model_saved_dir, main_metric, main_value))
        self.save_inference_model(dirname=model_saved_dir)
W
wuzewu 已提交
754

K
kinghuin 已提交
755
    def _default_log_interval_event(self, run_states):
        """
        PaddleHub default handler for log_interval_event: writes the training
        loss and metrics to the visualization writer and logs them together
        with the step progress and speed.

        Args:
            run_states (object): the results in train phase
        """
        scores, avg_loss, run_speed = self._calculate_metrics(run_states)
        self.tb_writer.add_scalar(
            tag="Loss_{}".format(self.phase),
            scalar_value=avg_loss,
            global_step=self._envs['train'].current_step)

        score_parts = []
        for metric_name, metric_value in scores.items():
            self.tb_writer.add_scalar(
                tag="{}_{}".format(metric_name, self.phase),
                scalar_value=metric_value,
                global_step=self._envs['train'].current_step)
            score_parts.append("%s=%.5f " % (metric_name, metric_value))

        logger.train("step %d / %d: loss=%.5f %s[step/sec: %.2f]" %
                     (self.current_step, self.max_train_steps, avg_loss,
                      "".join(score_parts), run_speed))
W
wuzewu 已提交
777

K
kinghuin 已提交
778
    def _default_save_ckpt_interval_event(self):
        """Default handler for the save-checkpoint-interval event: save a checkpoint."""
        self.save_checkpoint()
W
wuzewu 已提交
780

K
kinghuin 已提交
781
    def _default_eval_interval_event(self):
        """Default handler for the eval-interval event: evaluate on the dev set."""
        self.eval(phase="dev")

K
kinghuin 已提交
784 785
    def _default_run_step_event(self, run_state):
        """Default handler for the per-step event; intentionally a no-op."""
W
wuzewu 已提交
786 787 788 789 790 791 792 793 794 795 796

    def _build_net(self):
        """Build the task network; subclasses must override this."""
        raise NotImplementedError

    def _add_loss(self):
        """Add the loss to the program; subclasses must override this."""
        raise NotImplementedError

    def _add_label(self):
        """Add the label input(s) to the program; subclasses must override this."""
        raise NotImplementedError

    def _add_metrics(self):
        """
        Add metric computations to the program; subclasses must override this.

        Some metrics (e.g. acc, auc) can be calculated by fluid.layers inside
        the program; the others can be calculated in `_calculate_metrics`.
        """
        raise NotImplementedError

W
wuzewu 已提交
801
    def _calculate_metrics(self, run_states):
        """
        Compute metrics from collected run states; subclasses must override this.

        NOTE: when customizing the metrics, make sure the first returned value
        is a dict mapping metric name to value. Its first key is used as the
        main metric when deciding whether to update the best model. The
        default log handler in this class unpacks the result as
        `(scores, avg_loss, run_speed)`.
        """
        raise NotImplementedError

W
wuzewu 已提交
807 808
    # NOTE: the current checkpoint mechanism is incomplete:
    # it cannot restore the dataset's training (iteration) status.
    def save_checkpoint(self):
        """
        Save the program state of the latest training step.

        The persistables of `self.main_program` are written to
        `<checkpoint_dir>/step_<current_step>`. Then the current epoch,
        global step and best score are passed to the module-level
        `save_checkpoint` helper.
        """
        step_dir = os.path.join(self.config.checkpoint_dir,
                                "step_%d" % self.current_step)
        logger.info("Saving model checkpoint to {}".format(step_dir))

        # Resuming training from this checkpoint requires the persistable
        # variables, hence save_persistables rather than save_params.
        fluid.io.save_persistables(
            self.exe, dirname=step_dir, main_program=self.main_program)

        save_checkpoint(
            checkpoint_dir=self.config.checkpoint_dir,
            current_epoch=self.current_epoch,
            global_step=self.current_step,
            best_score=self.best_score,
            exe=self.exe,
            main_program=self.main_program)

W
wuzewu 已提交
828
    def load_checkpoint(self):
        """
        Restore training state from `self.config.checkpoint_dir`.

        The module-level `load_checkpoint` helper returns four values. They
        are assigned, in order, to a success flag, the current environment's
        epoch, the current environment's step, and `self.best_score`.

        Returns:
            the success flag returned by the helper
        """
        (loaded,
         self.env.current_epoch,
         self.env.current_step,
         self.best_score) = load_checkpoint(
             self.config.checkpoint_dir,
             self.exe,
             main_program=self.main_program)
        return loaded

    def load_parameters(self, dirname):
        """
        Load variables of `self.main_program` from `dirname`.

        Only variables for which a file named after the variable exists in
        `dirname` are selected for loading.

        Args:
            dirname (str): directory holding the saved variables
        """
        def _has_saved_file(var):
            return os.path.exists(os.path.join(dirname, var.name))

        fluid.io.load_vars(
            self.exe, dirname, self.main_program, predicate=_has_saved_file)

    def save_parameters(self, dirname):
        """
        Save the parameters of `self.main_program` into `dirname`.

        Args:
            dirname (str): target directory
        """
        program = self.main_program
        fluid.io.save_params(self.exe, dirname=dirname, main_program=program)
S
Steffy-zxf 已提交
847

W
wuzewu 已提交
848 849 850 851 852 853 854 855 856 857 858 859 860 861
    def save_inference_model(self,
                             dirname,
                             model_filename=None,
                             params_filename=None):
        """
        Export the program as an inference model under the "predict" phase.

        Args:
            dirname (str): directory to write the inference model to
            model_filename (str|None): passed through to
                fluid.io.save_inference_model
            params_filename (str|None): passed through to
                fluid.io.save_inference_model
        """
        with self.phase_guard("predict"):
            # Feed/fetch variables and the program are read inside the guard;
            # presumably they depend on the active phase.
            executor = self.exe
            feeds = self.feed_list
            targets = self.fetch_var_list
            program = self.main_program
            fluid.io.save_inference_model(
                dirname=dirname,
                executor=executor,
                feeded_var_names=feeds,
                target_vars=targets,
                main_program=program,
                model_filename=model_filename,
                params_filename=params_filename)

W
wuzewu 已提交
862
    def finetune_and_eval(self):
        """Run `finetune` with evaluation enabled during training and return its result."""
        return self.finetune(do_eval=True)
W
wuzewu 已提交
864 865

    def finetune(self, do_eval=False):
        """
        Train (finetune) the module parameters.

        Runs one `_run` pass per epoch while `self.current_epoch` does not
        exceed `self.config.num_epoch`. If at least one epoch ran, the model
        is then evaluated on the dev set with the current weights, and a
        checkpoint is saved. Finally, the test set is evaluated with the best
        model.

        Args:
            do_eval (bool): evaluate at every eval interval during training

        Returns:
            list: the RunState objects of the last training epoch (empty if
            no epoch was run)
        """
        with self.phase_guard(phase="train"):
            self.init_if_necessary()
            self._finetune_start_event()
            run_states = []
            # The outer check skips the final evaluation/saving entirely when
            # no training epoch remains to be run.
            if self.current_epoch <= self.config.num_epoch:
                while self.current_epoch <= self.config.num_epoch:
                    self.config.strategy.step()
                    run_states = self._run(do_eval=do_eval)
                    self.env.current_epoch += 1

                # Final evaluation on the dev set uses the CURRENT weights.
                # Warning: DO NOT use self.eval(phase="dev", load_best_model=True)
                # during training; the trainer could then not continue from its
                # checkpoint, and the model should be evaluated on its current
                # performance anyway.
                if self._base_data_reader.get_dev_examples() != []:
                    self.eval(phase="dev")

                # Save the checkpoint BEFORE the best-model test evaluation:
                # eval(load_best_model=True) goes through
                # init_if_load_best_model, which presumably replaces the
                # in-memory parameters with the best model. Saving afterwards
                # would store those weights under the final training step
                # instead of the weights training actually reached.
                self.save_checkpoint()

                if self._base_data_reader.get_test_examples() != []:
                    self.eval(phase="test", load_best_model=True)

            self._finetune_end_event(run_states)
            return run_states
W
wuzewu 已提交
900

K
kinghuin 已提交
901
    def eval(self, phase="dev", load_best_model=False):
        """
        Evaluate the module on the data of the given phase.

        Args:
            phase (str): the phase to run, e.g. "dev" or "test"
            load_best_model (bool): initialize via `init_if_load_best_model`
                instead of `init_if_necessary`

        Returns:
            list: the RunState objects produced by `_run`

        Warning: do not use eval(load_best_model=True) inside
        finetune_and_eval. The trainer would then be unable to continue
        training from its checkpoint, and evaluation during training should
        measure the current weights.
        """
        with self.phase_guard(phase=phase):
            initialize = (self.init_if_load_best_model
                          if load_best_model else self.init_if_necessary)
            initialize()
            self._eval_start_event()
            states = self._run()
            self._eval_end_event(states)
            return states
W
wuzewu 已提交
924

925 926 927 928 929 930 931 932 933 934
    def _create_predictor(self):
        """
        Create a high-performance predictor for the predict phase.

        The inference model is exported into a temporary directory and an
        AnalysisConfig is built from it. When `self.config.use_cuda` is set,
        the GPU is used (enable_use_gpu(100, 0)) with IR optimization
        switched on; otherwise the GPU is disabled. Memory optimization is
        enabled in both cases.

        Returns:
            PaddlePredictor: the high-performance predictor
        """
        with tmp_dir() as model_dir:
            self.save_inference_model(dirname=model_dir)
            analysis_config = fluid.core.AnalysisConfig(model_dir)
            analysis_config.disable_glog_info()

            if not self.config.use_cuda:
                analysis_config.disable_gpu()
            else:
                analysis_config.enable_use_gpu(100, 0)
                analysis_config.switch_ir_optim(True)
            analysis_config.enable_memory_optim()

            # Created while still inside tmp_dir, i.e. while the exported
            # model files are still available.
            return fluid.core.create_paddle_predictor(analysis_config)

    def _run_with_predictor(self):
        """
        Run prediction with the high-performance predictor `self._predictor`.

        Raises:
            Exception: if the base data reader is a LACClassifyReader, which
                does not support the predictor

        Returns:
            list: one RunState per batch, in reader order
        """
        if isinstance(self._base_data_reader, hub.reader.LACClassifyReader):
            raise Exception(
                "LACClassifyReader does not support predictor, please close accelerate_mode"
            )

        run_states = []
        for batch in self.reader():
            state = RunState(len(self.fetch_list))
            state.run_step = 1
            # NOTE(review): taken before the unwrap below, i.e. the length of
            # the raw reader item; confirm it actually counts examples.
            num_batch_examples = len(batch)

            # NOTE(review): the original comment said "if use pyreader, the
            # nlp_reader return [batch]", but the unwrap happens when
            # use_pyreader is False; confirm which mode wraps the batch.
            if not self.config.use_pyreader:
                batch = batch[0]

            tensors = [fluid.core.PaddleTensor(field) for field in batch]
            outputs = self._predictor.run(tensors)
            for idx, output in enumerate(outputs):
                state.run_results[idx] = output.as_ndarray()
            state.run_examples += num_batch_examples
            state.update()
            run_states.append(state)
            self._run_step_event(state)

        return run_states

    def predict(self,
                data,
                load_best_model=True,
                return_result=False,
                accelerate_mode=False):
        """
        Make predictions for the input data.

        Args:
            data (list): the data to predict; kept in `self._predict_data`
                while the predict phase runs and reset to None afterwards
            load_best_model (bool): initialize via `init_if_load_best_model`
                instead of `init_if_necessary`
            return_result (bool): return the `_postprocessing` output instead
                of the raw run states
            accelerate_mode (bool): run through the high-performance predictor,
                which is created on first use and cached in `self._predictor`

        Returns:
            list: the readable result if `return_result`, otherwise the
            RunState objects of the predict phase
        """
        self.accelerate_mode = accelerate_mode

        with self.phase_guard(phase="predict"):
            self._predict_data = data
            self._predict_start_event()

            (self.init_if_load_best_model
             if load_best_model else self.init_if_necessary)()

            if self.accelerate_mode:
                if not self._predictor:
                    self._predictor = self._create_predictor()
                run_states = self._run_with_predictor()
            else:
                run_states = self._run()

            self._predict_end_event(run_states)
            self._predict_data = None
            if return_result:
                return self._postprocessing(run_states)
        return run_states
W
wuzewu 已提交
1021

K
kinghuin 已提交
1022
    def _postprocessing(self, run_states):
        """
        Flatten raw predict-phase run states into a readable result list.

        For every batch, the first fetched output (`run_results[0]`) is
        iterated, and the first element of each of its rows is collected.

        Args:
            run_states (list): the raw RunState objects to process

        Returns:
            list: one entry per row, across all batches in order
        """
        return [row[0]
                for state in run_states
                for row in state.run_results[0]]

W
wuzewu 已提交
1038
    def _run(self, do_eval=False):
        """
        Load data for the current phase and run `self.main_program_to_be_run`.

        Batches come from `self.reader`, either through a fluid DataLoader
        (`config.use_pyreader`) or a DataFeeder. The fetched outputs of each
        step are stored in a new RunState. In the train phase, three events
        can fire: the log event every `log_interval` steps, the
        save-checkpoint event every `save_ckpt_interval` steps, and (if
        `do_eval`) the eval event every `eval_interval` steps. An interval
        that is unset (None/0) disables its event instead of raising
        ZeroDivisionError.

        Args:
            do_eval (bool): fire the eval interval event during the train phase

        Returns:
            list: one RunState per step, in order
        """
        with fluid.program_guard(self.main_program, self.startup_program):
            if self.config.use_pyreader:
                data_loader = fluid.io.DataLoader.from_generator(
                    feed_list=self.feed_var_list,
                    capacity=64,
                    use_double_buffer=True,
                    iterable=True)
                data_reader = data_loader.set_batch_generator(
                    self.reader, places=self.places)
            else:
                data_feeder = fluid.DataFeeder(
                    feed_list=self.feed_list, place=self.place)
                data_reader = data_feeder.decorate_reader(
                    self.reader,
                    multi_devices=self.config.use_data_parallel,
                    drop_last=True)

            global_run_states = []
            period_run_states = []

            for batch in data_reader():
                step_run_state = RunState(len(self.fetch_list))
                step_run_state.run_step = 1
                # NOTE(review): len(batch) is the size of the feed structure
                # yielded by the reader; confirm it equals the example count.
                num_batch_examples = len(batch)

                fetch_result = self.exe.run(
                    self.main_program_to_be_run,
                    feed=batch,
                    fetch_list=self.fetch_list,
                    return_numpy=self.return_numpy)
                if not self.return_numpy:
                    fetch_result = [np.array(x) for x in fetch_result]

                for index, result in enumerate(fetch_result):
                    step_run_state.run_results[index] = result
                step_run_state.run_examples += num_batch_examples
                step_run_state.update()
                period_run_states.append(step_run_state)
                self.env.current_step += 1

                if self.is_train_phase:
                    config = self.config
                    step = self.current_step
                    # Each interval is checked for truthiness first (as was
                    # already done for save_ckpt_interval) so that an unset
                    # interval disables the event rather than dividing by 0.
                    if config.log_interval and step % config.log_interval == 0:
                        self._log_interval_event(period_run_states)
                        global_run_states += period_run_states
                        period_run_states = []

                    if (config.save_ckpt_interval
                            and step % config.save_ckpt_interval == 0):
                        self._save_ckpt_interval_event()

                    if (do_eval and config.eval_interval
                            and step % config.eval_interval == 0):
                        self._eval_interval_event()

                self._run_step_event(step_run_state)

            global_run_states += period_run_states
            return global_run_states
1103 1104 1105 1106 1107

    def __repr__(self):
        """Describe the task class, its metric choices, reader class and run config."""
        task_name = self.__class__.__name__
        reader_name = self._base_data_reader.__class__.__name__
        return "Task: %s with metrics_choices: %s, reader: %s, %s" % (
            task_name, self.metrics_choices, reader_name, self.config)