trainer.py 45.1 KB
Newer Older
H
Helin Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import contextlib
16
import os
T
tangwei12 已提交
17 18 19
import errno
import shutil
import time
20

Y
Yu Yang 已提交
21
import core
22

Y
Yu Yang 已提交
23
import data_feeder
24 25
import executor
import framework
J
Jeff Wang 已提交
26
import io
Y
Yu Yang 已提交
27 28
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import optimizer as opt_module
29
import parallel_executor
Y
Yancey 已提交
30
from transpiler import distribute_transpiler
Y
Yu Yang 已提交
31

H
Helin Wang 已提交
32
__all__ = [
    'Trainer',
    'BeginEpochEvent',
    'EndEpochEvent',
    'BeginStepEvent',
    'EndStepEvent',
    'CheckpointConfig',
]


Y
Yu Yang 已提交
38
class BeginEpochEvent(object):
    """
    Event emitted right before a training epoch starts.

    Args:
        epoch_id(int): The current epoch ID.

    Attributes:
        epoch(int): The epoch ID passed to the constructor.
    """

    def __init__(self, epoch_id):
        current_epoch = epoch_id
        self.epoch = current_epoch


class EndEpochEvent(object):
    """
    Event emitted right after a training epoch finishes.

    Args:
        epoch_id(int): The current epoch ID.

    Attributes:
        epoch(int): The epoch ID passed to the constructor.
    """

    def __init__(self, epoch_id):
        finished_epoch = epoch_id
        self.epoch = finished_epoch
H
Helin Wang 已提交
60

Y
Yu Yang 已提交
61 62

class BeginStepEvent(object):
    """
    The begin of a training step.

    Args:
        epoch_id(int): The current epoch ID.
        step_id(int): The current step ID.

    Attributes:
        epoch(int): The current epoch ID.
        step(int): The current step ID.
        fetch_metrics(bool): If True (the default), the metrics returned by
            :code:`train_func` are fetched for this step and delivered in the
            matching EndStepEvent. An event handler that receives this event
            may set it to False to skip fetching for this step.
    """

    def __init__(self, epoch_id, step_id):
        self.epoch = epoch_id
        self.step = step_id
        # Read by the trainer after the event handler has seen this event.
        self.fetch_metrics = True
Y
Yu Yang 已提交
79 80 81


class EndStepEvent(object):
    """
    Event emitted right after a training step finishes.

    Args:
        epoch_id(int): The current epoch ID.
        step_id(int): The current step ID.
        metrics(list): The fetched tensors, in the same order as the values
            returned by :code:`train_func`.

    Attributes:
        epoch(int): The current epoch ID.
        step(int): The current step ID.
        metrics(list): The fetched tensors passed to the constructor.
    """

    def __init__(self, epoch_id, step_id, metrics):
        self.metrics = metrics
        self.step = step_id
        self.epoch = epoch_id
H
Helin Wang 已提交
96 97


98
class CheckpointConfig(object):
    """
    Parameter object for :code:`save_checkpoint` and
    :code:`fluid.Trainer`. It configures how checkpoints are saved.

    Args:
        checkpoint_dir(str): Directory path to save checkpoints. Default is the
            current working directory.
        max_num_checkpoints(int): The max number of local checkpoints.
        epoch_interval(int): Save a checkpoint every this many epochs. Must be
            at least 1.
        step_interval(int): Save a checkpoint every this many steps. Must be
            at least 1.

    Attributes:
        epoch_id(int): Runtime state, initially 0; the epoch to resume from.
        step_id(int): Runtime state, initially 0; the step to resume from.
        load_serial(int|None): Runtime state, initially None; the serial of
            the checkpoint to load, if any.

    Examples:
        >>> config = fluid.CheckpointConfig("./checkpoints")
        >>> trainer = fluid.Trainer(train_func=train_program,
        >>>                         place=place,
        >>>                         optimizer_func=optimizer_func,
        >>>                         checkpoint_config=config)
        >>> trainer.train(...)
    """

    def __init__(self,
                 checkpoint_dir=None,
                 max_num_checkpoints=3,
                 epoch_interval=1,
                 step_interval=10):
        # Both intervals are used as modulo divisors, so they must be >= 1.
        assert epoch_interval >= 1
        assert step_interval >= 1

        if checkpoint_dir is None:
            checkpoint_dir = os.getcwd()
        self.checkpoint_dir = checkpoint_dir
        self.max_num_checkpoints = max_num_checkpoints
        self.epoch_interval = epoch_interval
        self.step_interval = step_interval

        # Resume position and checkpoint serial, filled in at runtime.
        self.epoch_id, self.step_id = 0, 0
        self.load_serial = None
T
tangwei12 已提交
137

138

Q
Qiao Longfei 已提交
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
def check_and_get_place(place):
    """
    Validate ``place``, or choose a default device when it is None.

    Args:
        place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.

    Raises:
        TypeError: if ``place`` is given but is neither a CUDAPlace nor a
            CPUPlace.

    Returns:
        ``place`` itself when it is not None. Otherwise CUDAPlace(0) if fluid
        was compiled with CUDA, else CPUPlace().
    """
    if place is not None:
        if isinstance(place, (core.CUDAPlace, core.CPUPlace)):
            return place
        raise TypeError("Place should be either CUDAPlace or CPUPlace")

    use_cuda = core.is_compiled_with_cuda()
    return core.CUDAPlace(0) if use_cuda else core.CPUPlace()


H
Helin Wang 已提交
165
class Trainer(object):
    """
    A trainer wraps MultiGPU/MultiNode training loops and can be used to train a
    simple neural network easily.

    This API takes a :code:`train_func`. A :code:`train_func` is a function that
    returns the loss as its first return value. The remaining return values can
    be fetched via EndStepEvent.metrics.

    This API also takes a :code:`optimizer_func` that will return an optimizer
    instance.

    For example, to train a MLP for MNIST dataset, the sample program is

    >>> import paddle.fluid as fluid
    >>>
    >>> def mlp(image, layer_sizes=[200, 100], activation="relu", num_classes=10):
    >>>     hidden = image
    >>>     for layer_size in layer_sizes:
    >>>         hidden = fluid.layers.fc(input=hidden, size=layer_size, act=activation)
    >>>     return fluid.layers.fc(input=hidden, size=num_classes, act="softmax")
    >>>
    >>> def train_mnist_mlp():
    >>>     img = fluid.layers.data(name='image', shape=[784])
    >>>     label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    >>>     prediction = mlp(img)
    >>>     return fluid.layers.mean(fluid.layers.cross_entropy(prediction, label))
    >>>
    >>> def optimizer():
    >>>     return fluid.optimizer.Adam()
    >>>
    >>> trainer = Trainer(train_func=train_mnist_mlp,
    >>>                   optimizer_func=optimizer,
    >>>                   place=fluid.CUDAPlace(0),
    >>>                   parallel=True)
    >>>
    >>> def train_callback(event):
    >>>     if isinstance(event, fluid.EndStepEvent):
    >>>         print "Epoch ID", event.epoch, "Step ID",\
    >>>             event.step, "AvgLoss", event.metrics[0]
    >>>     elif isinstance(event, fluid.EndEpochEvent):
    >>>         trainer.save_params("./model_{0}".format(event.epoch))
    >>>
    >>> trainer.train(num_epochs=100, event_handler=train_callback)

    For more examples, please see :ref:`api_guide_high_level_api`.

    Args:
        train_func(callable): A function which will return loss. The loss must be
            a scalar tensor.
        optimizer_func(callable): A function that returns an Optimizer object.
        param_path(str|None): If it names an existing directory, persistable
            variables are loaded from it after the startup program has run.
        place(CUDAPlace|CPUPlace): The device place of this trainer. If
            :code:`parallel=True,` all CUDA Places will be used if :code:`place`
            is a :code:`CUDAPlace`.
        parallel(bool): True if use multiple devices.
        checkpoint_config(CheckpointConfig): Configuration about how to save
            checkpoints.
    """

    def __init__(self,
                 train_func,
                 optimizer_func,
                 param_path=None,
                 place=None,
                 parallel=False,
                 checkpoint_config=None):
        self.__stop = False
        self.parallel = parallel

        # Distributed-training state. _dist_transpile_if_necessary() overwrites
        # these when the PADDLE_* environment variables request distributed
        # training. save_params(), _save_checkpoint() and _load_checkpoint()
        # read them unconditionally, so they must exist even for a
        # single-machine run. Only the chief worker (trainer_id == 0) saves.
        self.trainer_id = 0
        self.nccl_id_var = None
        self.pserver_id = None
        self.pserver_endpoints = None
        self.lookup_table_name = None

        # config for checkpoint
        self.checkpoint_cfg = checkpoint_config
        if self.checkpoint_cfg:
            assert isinstance(self.checkpoint_cfg, CheckpointConfig)
            serial = _get_latest_checkpoint_serial(
                self.checkpoint_cfg.checkpoint_dir)
            # A negative serial is treated as "nothing to resume from".
            self.checkpoint_cfg.load_serial = serial if serial >= 0 else None

        self.scope = core.Scope()

        # 1. Build the training program by calling train_func under a program
        # guard. Reference: fluid.program_guard in test_word2vec.py
        self.startup_program = framework.Program()
        self.train_program = framework.Program()

        with framework.program_guard(self.train_program, self.startup_program):
            program_func_outs = train_func()
            self.train_func_outputs = program_func_outs if isinstance(
                program_func_outs, list) else [program_func_outs]
            # Cloned before optimizer.minimize() is called below.
            self.test_program = self.train_program.clone(for_test=True)

            # The first element of program_func_outs is loss.
            loss = self.train_func_outputs[0]

            optimizer = optimizer_func()
            if not isinstance(optimizer, opt_module.Optimizer):
                raise TypeError(
                    "The optimizer should be an instance of Optimizer")
            optimize_ops, params_grads = optimizer.minimize(loss)

        self.place = check_and_get_place(place)

        self._dist_transpile_if_necessary(optimize_ops, params_grads)

        # 2. Run the startup program inside this trainer's own scope.
        with self._prog_and_scope_guard():
            # Use the resolved self.place: the `place` argument may be None.
            exe = executor.Executor(self.place)
            exe.run(self.startup_program)

        if self.checkpoint_cfg and self.checkpoint_cfg.load_serial is not None:
            self._load_checkpoint()

        if param_path and os.path.isdir(param_path):
            # Load params from param_path into self.scope (the scope training
            # and saving use), not into the global scope.
            with self._prog_and_scope_guard():
                io.load_persistables(
                    executor=exe,
                    dirname=param_path,
                    main_program=self.startup_program)

    def _transpile_nccl2_dist(self):
        """
        Prepare NCCL2-based multi-node training if PADDLE_TRAINER_IPS is set.

        When PADDLE_TRAINER_IPS is absent, self.nccl_id_var is set to None.
        Otherwise PADDLE_TRAINER_ID, PADDLE_PSERVER_PORT and PADDLE_CURRENT_IP
        are read, and a gen_nccl_id op producing self.nccl_id_var is appended to
        the startup program.
        """
        if "PADDLE_TRAINER_IPS" not in os.environ:
            self.nccl_id_var = None
            return

        self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
        port = os.getenv("PADDLE_PSERVER_PORT")
        worker_ips = os.getenv("PADDLE_TRAINER_IPS")
        worker_endpoints = [':'.join([ip, port]) for ip in worker_ips.split(",")]
        self.num_trainers = len(worker_endpoints)
        current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port
        # endpoint_list holds the *other* workers only.
        worker_endpoints.remove(current_endpoint)
        # TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
        # in ParallelExecutor to start
        # distributed training using NCCL2
        self.nccl_id_var = self.startup_program.global_block().create_var(
            name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
        self.startup_program.global_block().append_op(
            type="gen_nccl_id",
            inputs={},
            outputs={"NCCLID": self.nccl_id_var},
            attrs={
                "endpoint": current_endpoint,
                "endpoint_list": worker_endpoints,
                "trainer_id": self.trainer_id
            })

    def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
        """
        Rewrite the programs for distributed training when requested.

        NCCL2 mode (PADDLE_TRAINER_IPS) takes precedence. Otherwise, when
        PADDLE_TRAINING_ROLE is set, the programs are transpiled for
        parameter-server training and replaced by the trainer or pserver
        program for the current role.

        Raises:
            ValueError: if PADDLE_TRAINING_ROLE is neither TRAINER nor PSERVER.
        """
        self._transpile_nccl2_dist()
        if self.nccl_id_var is not None:
            return

        if "PADDLE_TRAINING_ROLE" not in os.environ:
            return

        # the port of all pservers, needed by both trainer and pserver
        port = os.getenv("PADDLE_PSERVER_PORT", "6174")
        # comma separated ips of all pservers, needed by trainer and
        # pserver
        pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
        eplist = [':'.join([ip, port]) for ip in pserver_ips.split(",")]
        pserver_endpoints = ",".join(eplist)
        # total number of workers/trainers in the job, needed by
        # trainer and pserver
        trainers = int(os.getenv("PADDLE_TRAINERS"))
        # the IP of the local machine, needed by pserver only
        current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
        # the unique trainer id, starting from 0, needed by trainer
        # only
        self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))

        # the role, should be either PSERVER or TRAINER
        training_role = os.getenv("PADDLE_TRAINING_ROLE")
        with self._prog_and_scope_guard():
            t = distribute_transpiler.DistributeTranspiler()
            t.transpile(
                self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
            # Needed by both roles: trainers pass them to
            # _save_pserver_vars_by_notify / save_checkpoint so the
            # distributed lookup table is saved on the pservers, and
            # pservers use lookup_table_name when loading checkpoints.
            self.pserver_endpoints = pserver_endpoints
            self.lookup_table_name = t.table_name if t.has_distributed_lookup_table else None
            if training_role == "PSERVER":
                self.pserver_id = eplist.index(current_endpoint)
                self.train_program = t.get_pserver_program(current_endpoint)
                self.startup_program = t.get_startup_program(current_endpoint,
                                                             self.train_program)
            elif training_role == "TRAINER":
                self.train_program = t.get_trainer_program()
            else:
                raise ValueError(
                    'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
                )

    def stop(self):
        """
        stop training
        """
        self.__stop = True

    def train(self, num_epochs, event_handler, reader=None, feed_order=None):
        """
        Start the train loop to train the model.

        On a parameter server (PADDLE_TRAINING_ROLE=PSERVER) this runs the
        pserver program and returns without touching the reader.

        Args:
            num_epochs(int): The number of epoch. An epoch will process all data in reader
            event_handler(callable): The event handler. A function with type (ev:Event)->void
            reader(callable): A reader creator object. See also
                :ref:`api_guide_python_reader` .
            feed_order(list|dict): Feeding order of reader, in a form accepted
                by build_feed_var_list.

        Returns:
            None
        """
        training_role = os.getenv("PADDLE_TRAINING_ROLE", "")
        if training_role == "PSERVER":
            with self._prog_and_scope_guard():
                exe = executor.Executor(self.place)
                exe.run()
                return
        if self.parallel:
            self._train_by_parallel_executor(num_epochs, event_handler, reader,
                                             feed_order)
        else:
            self._train_by_executor(num_epochs, event_handler, reader,
                                    feed_order)

    def test(self, reader, feed_order):
        """
        Test the model on given test data

        Args:
            reader(callable): The reader that yields test data.
            feed_order(list|dict): Feeding order of reader, in a form accepted
                by build_feed_var_list.

        Returns:
            list: The average over all test batches of each output of
            :code:`train_func`.
        """

        return self._test_by_executor(reader, feed_order,
                                      self.train_func_outputs)

    def save_params(self, param_path):
        """
        Save all parameters into :code:`param_path`.

        Only the No.0 trainer saves dense params, in both standalone and
        distributed PaddlePaddle. If there is a distributed lookup table, the
        No.0 trainer also notifies all Parameter Servers so that each of them
        saves its part of the table independently.

        Args:
            param_path(str): The path to save parameters.

        Returns:
            None
        """

        if self.trainer_id != 0:
            return

        with self._prog_and_scope_guard():
            # save params on trainer
            exe = executor.Executor(self.place)
            io.save_persistables(exe, dirname=param_path)
            # save params on pserver
            if self.lookup_table_name:
                _save_pserver_vars_by_notify(exe, param_path,
                                             self.lookup_table_name,
                                             self.pserver_endpoints)

    @contextlib.contextmanager
    def _prog_and_scope_guard(self):
        """Make the train/startup programs and self.scope the active ones."""
        with framework.program_guard(
                main_program=self.train_program,
                startup_program=self.startup_program):
            with executor.scope_guard(self.scope):
                yield

    def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
        """
        Train by Executor and single device.

        Args:
            num_epochs(int): The number of epochs to run.
            event_handler(callable): Receives the training events.
            reader(callable): A reader creator yielding training data.
            feed_order(list|dict): Feeding order of reader.

        Returns:
            None
        """
        with self._prog_and_scope_guard():
            feed_var_list = build_feed_var_list(self.train_program, feed_order)
            feeder = data_feeder.DataFeeder(
                feed_list=feed_var_list, place=self.place)
            exe = executor.Executor(self.place)
            reader = feeder.decorate_reader(reader, multi_devices=False)
            self._train_by_any_executor(event_handler, exe, num_epochs, reader)

    def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
        """
        Run the epoch/step loop with ``exe`` (an Executor or ParallelExecutor).

        With a checkpoint config, training resumes at the loaded epoch, steps
        already covered by the loaded checkpoint are skipped, a checkpoint is
        saved after each step (subject to the configured intervals), and
        checkpoints are cleaned when training is stopped or finishes.
        """
        if self.checkpoint_cfg:
            epochs = [
                epoch_id for epoch_id in range(num_epochs)
                if epoch_id >= self.checkpoint_cfg.epoch_id
            ]
        else:
            epochs = list(range(num_epochs))

        for epoch_id in epochs:
            event_handler(BeginEpochEvent(epoch_id))
            for step_id, data in enumerate(reader()):
                if self.__stop:
                    if self.checkpoint_cfg:
                        self._clean_checkpoint()
                    return

                # Skip the steps of the resumed epoch that the loaded
                # checkpoint already covers.
                if self.checkpoint_cfg and \
                    self.checkpoint_cfg.load_serial is not None and \
                    self.checkpoint_cfg.step_id >= step_id and \
                    self.checkpoint_cfg.epoch_id == epoch_id:
                    continue

                begin_event = BeginStepEvent(epoch_id, step_id)
                event_handler(begin_event)
                # The handler may turn off metric fetching for this step.
                if begin_event.fetch_metrics:
                    fetch_list = [var.name for var in self.train_func_outputs]
                else:
                    fetch_list = []
                metrics = exe.run(feed=data, fetch_list=fetch_list)

                if self.checkpoint_cfg:
                    self._save_checkpoint(epoch_id, step_id)
                event_handler(EndStepEvent(epoch_id, step_id, metrics))
            event_handler(EndEpochEvent(epoch_id))
        if self.checkpoint_cfg:
            self._clean_checkpoint()

    def _test_by_executor(self, reader, feed_order, fetch_list):
        """
        Run self.test_program over every batch of ``reader``.

        Returns:
            list: For each variable in ``fetch_list``, the mean over batches of
            the first element of its fetched value.

        Raises:
            ValueError: if ``reader`` yields no data (there is nothing to
                average).
        """
        with executor.scope_guard(self.scope):
            feed_var_list = build_feed_var_list(self.test_program, feed_order)
            feeder = data_feeder.DataFeeder(
                feed_list=feed_var_list, place=self.place)
            exe = executor.Executor(self.place)
            accumulated = len(fetch_list) * [0]
            count = 0
            for data in reader():
                outs = exe.run(program=self.test_program,
                               feed=feeder.feed(data),
                               fetch_list=fetch_list)
                accumulated = [x[0] + x[1][0] for x in zip(accumulated, outs)]
                count += 1

            if count == 0:
                raise ValueError("The test reader yielded no data.")
            return [x / count for x in accumulated]

    def _train_by_parallel_executor(self, num_epochs, event_handler, reader,
                                    feed_order):
        """Train with a ParallelExecutor over multiple devices."""
        with self._prog_and_scope_guard():
            pe = self._get_or_create_parallel_executor()
            feed_var_list = build_feed_var_list(self.train_program, feed_order)
            feeder = data_feeder.DataFeeder(
                feed_list=feed_var_list, place=self.place)
            reader = feeder.decorate_reader(reader, multi_devices=True)
            self._train_by_any_executor(event_handler, pe, num_epochs, reader)

    def _get_parallel_executor(self):
        """Return the cached ParallelExecutor, or None if none was created."""
        return getattr(self, 'parallel_executor', None)

    def _get_or_create_parallel_executor(self):
        """Create the ParallelExecutor on first use and cache it."""
        if self._get_parallel_executor() is None:
            self.parallel_executor = parallel_executor.ParallelExecutor(
                use_cuda=isinstance(self.place, core.CUDAPlace),
                loss_name=self.train_func_outputs[0].name)
        return self._get_parallel_executor()

    def _clean_checkpoint(self):
        """Remove the checkpoints under the configured checkpoint_dir."""
        assert self.checkpoint_cfg
        clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir)

    def _get_checkpoint_load_args(self):
        """
        epoch_id and step_id are runtime arguments, they are not variables, will load them independently.
        """
        return ["epoch_id", "step_id"]

    def _get_checkpoint_save_args(self, epoch_id, step_id):
        """
        epoch_id and step_id are runtime arguments, they are not variables, will save them independently.
        """
        return {"epoch_id": epoch_id, "step_id": step_id}

    def _save_checkpoint(self, epoch_id, step_id):
        """
        Save a checkpoint if both epoch_id and step_id fall on the configured
        epoch_interval and step_interval.
        """
        assert self.checkpoint_cfg

        if epoch_id % self.checkpoint_cfg.epoch_interval != 0 or \
                step_id % self.checkpoint_cfg.step_interval != 0:
            return

        exe = executor.Executor(self.place)
        save_checkpoint(
            executor=exe,
            checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
            main_program=self.train_program,
            trainer_id=self.trainer_id,
            save_trainer_args=self._get_checkpoint_save_args(epoch_id,
                                                             step_id),
            save_lookup_table=self.lookup_table_name,
            pserver_endpoints=self.pserver_endpoints,
            max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)

    def _load_checkpoint(self):
        """
        Load the checkpoint selected by checkpoint_cfg.load_serial.

        Trainers (pserver_id is None) load the model variables and then restore
        checkpoint_cfg.epoch_id / step_id. Parameter servers load the model
        and, if one exists, the distributed lookup table.

        Raises:
            ValueError: if the loaded trainer args do not match
                _get_checkpoint_load_args().
        """
        with self._prog_and_scope_guard():
            exe = executor.Executor(self.place)

            checkpoint_dir = _get_serial_dir(self.checkpoint_cfg.checkpoint_dir,
                                             self.checkpoint_cfg.load_serial)

            # Trainer Load
            if self.pserver_id is None:
                # load model
                load_checkpoint(
                    executor=exe,
                    checkpoint_dir=checkpoint_dir,
                    main_program=self.startup_program,
                    role_id=self.trainer_id,
                    is_trainer=True,
                    load_models=True)

                # load trainer_args
                trainer_args = self._get_checkpoint_load_args()
                trainer_args_ret = load_checkpoint(
                    executor=exe,
                    checkpoint_dir=checkpoint_dir,
                    main_program=self.startup_program,
                    role_id=self.trainer_id,
                    is_trainer=True,
                    load_trainer_args=trainer_args)

                if len(trainer_args_ret) != len(trainer_args):
                    raise ValueError(
                        "the return trainer_args length do not equal _get_checkpoint_load_args"
                    )
                self.checkpoint_cfg.epoch_id = int(trainer_args_ret[0])
                self.checkpoint_cfg.step_id = int(trainer_args_ret[1])

            # Pserver Load
            else:
                # load model
                load_checkpoint(
                    executor=exe,
                    checkpoint_dir=checkpoint_dir,
                    main_program=self.startup_program,
                    role_id=self.pserver_id,
                    is_trainer=False,
                    load_models=True,
                    load_lookup_table=self.lookup_table_name)

                # load lookup table
                # NOTE(review): the call above already passes
                # load_lookup_table; confirm whether this second call is needed.
                if self.lookup_table_name:
                    load_checkpoint(
                        executor=exe,
                        checkpoint_dir=checkpoint_dir,
                        main_program=self.startup_program,
                        role_id=self.pserver_id,
                        is_trainer=False,
                        load_lookup_table=self.lookup_table_name)
T
tangwei12 已提交
642

F
fengjiayi 已提交
643 644 645 646 647

def build_feed_var_list(program, feed_order):
    """
    Resolve the variables to feed, in feeding order.

    Args:
        program(Program): The program whose global block holds the variables.
        feed_order(list|tuple|dict): Either a sequence of variable names in
            feeding order, or a dict mapping each variable name to its
            position. The dict's values must be a permutation of
            ``range(len(feed_order))``.

    Returns:
        list: The variables from ``program.global_block()`` in feeding order.

    Raises:
        TypeError: if ``program`` is not a Program, or ``feed_order`` is not
            a list, tuple or dict.
        ValueError: if the dict values are not a permutation of
            ``[0, len(feed_order))``.
    """
    if not isinstance(program, framework.Program):
        raise TypeError("The 'program' should be an object of Program")

    if isinstance(feed_order, (list, tuple)):
        block = program.global_block()
        return [block.var(var_name) for var_name in feed_order]

    if not isinstance(feed_order, dict):
        raise TypeError(
            "The 'feed_order' should be either a list, tuple or dict.")
    # list(...) is required: on Python 3 a list never compares equal to a
    # range object, so comparing against range(...) directly always fails.
    if sorted(feed_order.values()) != list(range(len(feed_order))):
        raise ValueError(
            "The values of 'feed_order' should be a permutation of [0, len(feed_order))"
        )
    sorted_pair_list = sorted(feed_order.items(), key=lambda item: item[1])
    block = program.global_block()
    return [block.var(var_name) for var_name, _ in sorted_pair_list]
T
tangwei12 已提交
665 666 667 668 669 670 671 672 673 674 675 676 677


# Checkpoint APIs moved here from io.py. They are not listed in __all__, so
# they are not part of this module's exported API.
# NOTE(review): the code that consumes these constants lies outside this view;
# the per-constant descriptions below are inferred from the names and must be
# confirmed.
SUCCESS_MARK_FILENAME = "_SUCCESS"  # presumably marks a fully written checkpoint
CHECKPOINT_PREFIX = "checkpoint"  # presumably prefix of per-serial checkpoint dirs
MODEL_DIR = "__model__"  # presumably sub-directory for model variables
LOOKUP_TABLE_DIR = "__lookup_table__"  # presumably sub-directory for the lookup table
TRAINER_PREFIX = "trainer"  # presumably prefix of per-trainer sub-directories
CHECKPOINT_SEPARATOR = "_"  # presumably joins a prefix with its serial/id


def save_checkpoint(executor,
                    checkpoint_dir,
                    main_program=None,
                    trainer_id=0,
                    save_trainer_args=None,
                    save_lookup_table=None,
                    pserver_endpoints=None,
                    max_num_checkpoints=3):
    """
    Save a new checkpoint into a fresh sub folder of `checkpoint_dir`.

    Each call creates ``checkpoint_<serial>``, where `serial` is one more
    than the latest complete checkpoint. Inside it:

    * if `save_trainer_args` is given, they are written to
      ``trainer_<trainer_id>/``;
    * the chief trainer (``trainer_id == 0``) saves the checkpoint
      variables of `main_program` to ``__model__/``;
    * if `save_lookup_table` and `pserver_endpoints` are both given, the
      chief also notifies the parameter servers to save the distributed
      lookup table into ``__lookup_table__/``.

    Saving a checkpoint every iteration could fill the disk, so afterwards
    only the newest `max_num_checkpoints` checkpoints are kept and older
    ones are deleted.

    A variable is a checkpoint variable and will be saved if it meets
    all following conditions:
        1. It's persistable.
        2. Its type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
        3. Its name contains no "@GRAD" nor ".trainer_" nor ".block".

    Args:
        executor(Executor): The executor to run for save checkpoint.
        checkpoint_dir(str): The folder where all checkpoints are saved.
        main_program(Program|None): The program whose checkpoint variables
            will be saved. Required for the chief trainer.
        trainer_id(int): Current trainer id; the trainer with id 0 is the
            chief. Default: 0
        save_trainer_args(dict|None): Training arguments to save, such as
            'epoch_id' and 'step_id'. Default: None
        save_lookup_table(str|None): The distributed lookup table name,
            available as DistributeTranspiler.table_name. Default: None
        pserver_endpoints(str|None): Comma separated parameter server
            endpoints, e.g. "127.0.0.1:6000,127.0.0.1:6001". Default: None
        max_num_checkpoints(int): The max number of checkpoints to keep.
            Default: 3

    Returns:
        None

    Raises:
        ValueError: If `checkpoint_dir` is None, or if `main_program` is
            None for the chief trainer.
        AssertionError: If `save_trainer_args` is not a dict.

    Examples:
        .. code-block:: python

            exe = fluid.Executor(fluid.CPUPlace())
            path = "./checkpoints"
            prog = fluid.default_main_program()
            trainer_args = {"epoch_id": 200,
                            "step_id": 20} # just an example
            table_name = "share_w"
            ps_endpoints = "127.0.0.1:6000,127.0.0.1:6001"

            save_checkpoint(executor=exe,
                            checkpoint_dir=path,
                            trainer_id=0,
                            save_trainer_args=trainer_args,
                            main_program=prog,
                            max_num_checkpoints=3,
                            save_lookup_table=table_name,
                            pserver_endpoints=ps_endpoints)
    """
    if checkpoint_dir is None:
        raise ValueError("'checkpoint_dir' should not be None")

    is_chief = trainer_id == 0
    # Validate before touching the disk so that a bad call does not leave an
    # empty checkpoint folder (and a bumped serial) behind.
    if is_chief and main_program is None:
        raise ValueError('main_program should not be None.')

    _make_chekcpoint_dirs(checkpoint_dir)
    serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
    cur_dir = _get_serial_dir(checkpoint_dir, serial, True)

    if save_trainer_args is not None:
        _save_trainer_args(cur_dir, trainer_id, save_trainer_args)

    if is_chief:
        _save_persistable_vars(executor, cur_dir, main_program)
        if save_lookup_table and pserver_endpoints:
            _save_pserver_vars_by_notify(executor, cur_dir, save_lookup_table,
                                         pserver_endpoints)

    _scroll_delete(checkpoint_dir, max_num_checkpoints)


def load_checkpoint(executor,
                    checkpoint_dir,
                    main_program=None,
                    role_id=0,
                    is_trainer=True,
                    load_models=False,
                    load_trainer_args=None,
                    load_lookup_table=None):
    """
    Load parts of ONE checkpoint for a trainer or a parameter server.

    `checkpoint_dir` is the folder of a single checkpoint, e.g.
    "./checkpoints/checkpoint_9", not the parent folder holding all
    serials.

    For a trainer (`is_trainer=True`):
        * `load_models` loads the checkpoint variables of `main_program`
          from ``__model__/``;
        * `load_trainer_args` reads the listed arguments from
          ``trainer_<role_id>/`` and returns them.
    For a parameter server (`is_trainer=False`):
        * `load_models` loads the checkpoint variables of `main_program`
          from ``__model__/``, excluding `load_lookup_table` when it is set;
        * `load_lookup_table` loads this server's shard of that lookup table
          from ``__lookup_table__/``.

    A variable is a checkpoint variable and will be loaded if it meets
    all following conditions:
        1. It's persistable.
        2. Its type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
        3. Its name contains no "@GRAD" nor ".trainer_" nor ".block".

    Args:
        executor(Executor): The executor to run for loading checkpoint.
        checkpoint_dir(str): The folder of one checkpoint.
        main_program(Program|None): The program whose checkpoint variables
            will be loaded.
        role_id(int): The trainer id, or the parameter server id.
        is_trainer(bool): True for a trainer, False for a parameter server.
        load_models(bool): Whether to load the checkpoint variables.
            Default: False
        load_trainer_args(list|None): Names of trainer arguments to load
            (trainers only). Default: None
        load_lookup_table(str|None): The lookup table name (parameter
            servers only). Default: None

    Returns:
        list|None: For a trainer with `load_trainer_args`, the loaded values
        as stripped strings in the same order; otherwise None.

    Raises:
        ValueError: If `checkpoint_dir` is None.
        AssertionError: If `load_trainer_args` is given but not a list.

    Examples:
        .. code-block:: python

            exe = fluid.Executor(fluid.CPUPlace())
            path = "./checkpoints/checkpoint_9"
            prog = fluid.default_main_program()
            epoch_id, step_id = load_checkpoint(
                executor=exe, checkpoint_dir=path, main_program=prog,
                role_id=0, is_trainer=True, load_models=True,
                load_trainer_args=["epoch_id", "step_id"])

            # In this example, the checkpoint variables of the default main
            # program are loaded from "./checkpoints/checkpoint_9/__model__"
            # and the trainer arguments from
            # "./checkpoints/checkpoint_9/trainer_0".
    """
    if checkpoint_dir is None:
        raise ValueError("'checkpoint_dir' should not be None")

    if is_trainer:
        if load_models:
            _load_persistable_vars(executor, checkpoint_dir, main_program, True)
        if load_trainer_args:
            return _load_trainer_args(checkpoint_dir, role_id,
                                      load_trainer_args)
        return None

    # Parameter server: the lookup table shard is restored separately below,
    # so keep it out of the generic checkpoint-variable load.
    if load_models:
        except_vars = [load_lookup_table] if load_lookup_table else None
        _load_persistable_vars(executor, checkpoint_dir, main_program, True,
                               except_vars)

    if load_lookup_table:
        _load_lookup_table_vars(executor, checkpoint_dir, main_program,
                                role_id, load_lookup_table)
T
tangwei12 已提交
856 857 858 859


def clean_checkpoint(checkpoint_dir, delete_dir=False):
    """
    Delete every checkpoint saved under `checkpoint_dir`.

    The trainer calls this when training exits normally, to drop the
    checkpoints it saved earlier. With `delete_dir` set, `checkpoint_dir`
    itself is removed as well, but only when nothing else is left inside
    it; a non-empty folder is kept silently.

    Args:
        checkpoint_dir(str): The folder where all checkpoints are.
        delete_dir(bool): Whether to remove `checkpoint_dir` once it is
            empty. Default: False

    Raises:
        ValueError: If `checkpoint_dir` is None.
    """
    if checkpoint_dir is None:
        raise ValueError("'checkpoint_dir' should not be None")

    # Keeping zero checkpoints deletes all of them.
    _scroll_delete(checkpoint_dir, max_num_checkpoints=0)

    if not delete_dir:
        return
    if len(os.listdir(checkpoint_dir)) == 0:
        os.rmdir(checkpoint_dir)


876 877 878 879 880
def _load_persistable_vars(executor,
                           dirname,
                           program,
                           has_model_dir=False,
                           except_vars=None):
    """
    Load the checkpoint variables of `program` from a directory.

    A variable is a checkpoint variable if it meets all following
    conditions:
        1. It's persistable.
        2. Its type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
        3. Its name contains no "@GRAD" nor ".trainer_" nor ".block".
        4. It is not excluded through `except_vars`.

    Args:
        executor(Executor): The executor to run for loading variables.
        dirname(str): The directory path.
        program(Program): The program whose checkpoint variables will
                          be loaded.
        has_model_dir(bool): If True, load from the '__model__' sub
                             directory of `dirname` (created if missing)
                             instead of `dirname` itself.
                             Default: False
        except_vars(list|None): Variables to skip, forwarded to
                                `_is_checkpoint_var`. Default: None

    Returns:
        None

    Examples:
        .. code-block:: python

            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            prog = fluid.default_main_program()
            _load_persistable_vars(executor=exe,
                    dirname=param_path, program=prog, has_model_dir=True)

            # In this example, `_load_persistable_vars` first filters out
            # all checkpoint variables in the default main program and then
            # loads them from the folder "./my_paddle_model/__model__".
    """
    source_dir = _get_model_dir(dirname) if has_model_dir else dirname

    io.load_vars(
        executor,
        dirname=source_dir,
        main_program=program,
        predicate=_is_checkpoint_var(except_vars),
        filename=None)


def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
    """
T
bug fix  
tangwei12 已提交
931
    The parameter server will load lookup table's local file in
T
tangwei12 已提交
932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978
    selectedrows variable.

    Args:
        executor(Executor): The executor to run for loading persistable variables
        dirname(str): The directory path
        main_program(Program): Find the variable named table_name in main_program
        pserver_id(int): the serial number in pserver_endpoints list
        table_name(str): lookup table name

    Returns:
        None

    Examples:
        .. code-block:: python

            exe = fluid.Executor(fluid.CPUPlace())
            dirname = "./checkpoints/checkpoint_9/"
            prog = fluid.default_main_program()
            pserver_id = 1
            table_name = "share_w"
            _load_lookup_table_vars(executor=exe,
                    dirname=dirname, program=prog, pserver_id=pserver_id,
                    table_name=table_name)
    """

    for var in program.list_vars():
        if var.name == table_name:
            lookup_table_var = var
            break

    assert lookup_table_var is not None

    lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
    table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id)

    load_prog = framework.Program()
    load_block = load_prog.global_block()

    load_block.append_op(
        type='load',
        inputs={},
        outputs={'Out': [lookup_table_var]},
        attrs={'file_path': os.path.join(lookup_table_dir, table_file)})

    executor.run(load_prog)


T
bug fix  
tangwei12 已提交
979
def _save_persistable_vars(executor, dirname, program):
    """
    Save the checkpoint variables of `program` into ``<dirname>/__model__``.

    The '__model__' sub folder is created if needed. Once the variables
    are written, a _SUCCESS marker is written into the same folder.

    A variable is a checkpoint variable if it meets all following
    conditions:
        1. It's persistable.
        2. Its type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
        3. Its name contains no "@GRAD" nor ".trainer_" nor ".block".

    Args:
        executor(Executor): The executor to run for saving variables.
        dirname(str): The directory path.
        program(Program): The program whose checkpoint variables will
                          be saved.

    Returns:
        None

    Examples:
        .. code-block:: python

            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            prog = fluid.default_main_program()
            _save_persistable_vars(executor=exe,
                    dirname=param_path, program=prog)

            # In this example, `_save_persistable_vars` first filters out
            # all checkpoint variables in the default main program and then
            # saves them to the folder "./my_paddle_model/__model__".
    """
    model_dir = _get_model_dir(dirname)
    io.save_vars(
        executor,
        dirname=model_dir,
        main_program=program,
        vars=None,
        predicate=_is_checkpoint_var(),
        filename=None)
    # The marker goes in only after the save has finished;
    # _get_latest_checkpoint_serial looks for __model__/_SUCCESS.
    _write_success(model_dir)


def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
                                 pserver_endpoints):
    """
    Ask every parameter server to save its shard of a lookup table.

    Trainer 0 runs a one-op program holding a ``checkpoint_notify`` op. The
    op carries the lookup table name and the folder
    ``<dirname>/__lookup_table__`` (created here if missing) to every
    endpoint.

    Args:
        executor(Executor): The executor to run for send checkpoint notify.
        dirname(str): The folder of the checkpoint being saved.
        lookup_table(str): The lookup table name, available as
            DistributeTranspiler.table_name.
        pserver_endpoints(str|list|tuple): The parameter server endpoints,
            either a comma separated "ip:port,ip:port" string or a sequence
            of "ip:port" strings.

    Returns:
        None

    Examples:
        .. code-block:: python

            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            table_name = "share_w"
            ps_endpoints = "127.0.0.1:6000,127.0.0.1:6001"

            _save_pserver_vars_by_notify(executor=exe,
                    dirname=param_path, lookup_table=table_name,
                    pserver_endpoints=ps_endpoints)
    """
    cur_dir = _get_lookuptable_dir(dirname)

    # The established input is a comma separated string. A list/tuple is
    # also accepted; calling ``split`` on it used to crash.
    if isinstance(pserver_endpoints, (list, tuple)):
        epmap = list(pserver_endpoints)
    else:
        epmap = pserver_endpoints.split(",")

    checkpoint_notify_program = framework.Program()
    checkpoint_notify_program.global_block().append_op(
        type='checkpoint_notify',
        inputs={},
        outputs={},
        attrs={
            'epmap': epmap,
            'dir': cur_dir,
            'lookup_table': lookup_table,
        })
    executor.run(checkpoint_notify_program)


def _save_trainer_args(dirname, trainer_id, trainer_args):
    assert isinstance(trainer_args, dict)

    cur_dir = _get_trainer_dir(dirname, trainer_id)

    for name, value in trainer_args.iteritems():
        args_file = os.path.join(cur_dir, name)
        with open(args_file, 'w') as f:
            f.write(str(value))
    _write_success(cur_dir)


T
bug fix  
tangwei12 已提交
1085
def _load_trainer_args(checkpoint_dir, trainer_id, trainer_args):
T
tangwei12 已提交
1086
    """
T
bug fix  
tangwei12 已提交
1087
    trainer will load some args from it's independent directory,
T
tangwei12 已提交
1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110
    such as epoch_id and step_id.

    Args:
        checkpoint_dir(str): The folder where all checkpoints are.
        serial(int): The serial of checkpoint you would like to load.
        trainer_id(int): current trainer id.
        trainer_args(list): list about load trainer args
    Return:
        None

    Examples:
        .. code-block:: python

            param_path = "./checkpoint/"
            serial = 7
            trainer_id = 2
            trainer_args = ["epoch_id", "step_id"]

            _load_trainer_args(checkpoint_dir=param_path, serial=serial,
            trainer_id=trainer_id, trainer_args=trainer_args)
    """
    assert isinstance(trainer_args, list)

T
bug fix  
tangwei12 已提交
1111
    cur_dir = _get_trainer_dir(checkpoint_dir, trainer_id)
T
tangwei12 已提交
1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122

    ret_values = []

    for arg in trainer_args:
        cur_file = os.path.join(cur_dir, arg)
        with open(cur_file, 'r') as f:
            contents = f.read()
            ret_values.append(contents.strip())
    return ret_values


1123 1124
def _is_checkpoint_var(except_vars=None):
    except_vars = [] if except_vars is None else except_vars
T
tangwei12 已提交
1125

1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153
    def _except_vars(var):
        """
        the checkpoint will not save or load all the variables.
        var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.

        : param var(Variable)
        """
        if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
                var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
                var.desc.type() == core.VarDesc.VarType.RAW:
            return False
        # @GRAD are named for gradient variables, checkpoint will not save it.
        if "@GRAD" in var.name:
            return False
        # .trainer_ are named for distribute train variables, checkpoint will not save it.
        if ".trainer_" in var.name:
            return False

        # .block is named for distribute train variables, checkpoint will not save it.
        if ".block" in var.name:
            return False

        if var in except_vars:
            return False

        return var.persistable

    return _except_vars
T
tangwei12 已提交
1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174


def _make_chekcpoint_dirs(dirs):
    """
    _make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it.
    """
    assert dirs is not None

    if os.path.isfile(dirs):
        raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs)

    if not os.path.isdir(dirs):
        try:
            os.makedirs(dirs)
        except OSError as err:
            if err.errno != errno.EEXIST:
                raise err


def _get_dir_serial(dirname):
    """
    Return the serial number encoded in a checkpoint folder name.

    Args:
        dirname(str): A folder name (not a path), e.g. "checkpoint_7".

    Returns:
        int: The serial for names of the form ``checkpoint_<int>``;
        -1 for any other name.
    """
    parts = dirname.split(CHECKPOINT_SEPARATOR)
    # Require the checkpoint prefix. Without it, an unrelated entry such as
    # "log_100" parsed as serial 100, so _scroll_delete counted it as the
    # newest checkpoint and deleted a real one too early.
    if len(parts) != 2 or parts[0] != CHECKPOINT_PREFIX:
        return -1
    try:
        return int(parts[1])
    except ValueError:
        return -1


T
bug fix  
tangwei12 已提交
1182
def _get_serial_dir(dirname, serial, makedirs=False):
    """
    Return the path ``<dirname>/checkpoint_<serial>``.

    Args:
        dirname(str): The folder where all checkpoints are.
        serial(int): The checkpoint serial number.
        makedirs(bool): Create the folder if it does not exist yet.
            Default: False

    Returns:
        str: The checkpoint folder path.
    """
    folder_name = CHECKPOINT_SEPARATOR.join([CHECKPOINT_PREFIX, str(serial)])
    path = os.path.join(dirname, folder_name)
    if makedirs:
        _make_chekcpoint_dirs(path)
    return path


def _get_model_dir(dirname):
    """Return ``<dirname>/__model__``, creating the folder if missing."""
    path = os.path.join(dirname, MODEL_DIR)
    _make_chekcpoint_dirs(path)
    return path


def _get_lookuptable_dir(dirname):
    """Return ``<dirname>/__lookup_table__``, creating the folder if missing."""
    path = os.path.join(dirname, LOOKUP_TABLE_DIR)
    _make_chekcpoint_dirs(path)
    return path


def _get_trainer_dir(dirname, trainer_id):
    """Return ``<dirname>/trainer_<trainer_id>``, creating it if missing."""
    folder_name = CHECKPOINT_SEPARATOR.join([TRAINER_PREFIX, str(trainer_id)])
    path = os.path.join(dirname, folder_name)
    _make_chekcpoint_dirs(path)
    return path


def _scroll_delete(dirname, max_num_checkpoints=3):
    """
    Keep only the newest `max_num_checkpoints` checkpoints in `dirname`.

    Checkpoint folders are ranked by serial. Every folder beyond the
    newest `max_num_checkpoints` is removed; entries that do not parse as
    a checkpoint serial are ignored.

    Args:
        dirname(str): The folder where all checkpoints are.
        max_num_checkpoints(int): How many checkpoints to keep; 0 removes
            all of them. Default: 3
    """
    serials = set()
    for entry in os.listdir(dirname):
        serial = _get_dir_serial(entry)
        # -1 means "not a checkpoint folder": never count or delete it.
        if serial != -1:
            serials.add(serial)

    # ``sorted`` rather than ``keys().sort()``: dict views cannot be sorted
    # in place on Python 3.
    stale_serials = sorted(serials, reverse=True)[max_num_checkpoints:]
    for serial in stale_serials:
        try:
            shutil.rmtree(_get_serial_dir(dirname, serial))
        except OSError as err:
            # The folder may already be gone; anything else is a real error.
            if err.errno != errno.ENOENT:
                raise err


def _write_success(dirname):
    """
    Mark `dirname` as complete with a _SUCCESS file.

    The current time (``time.ctime()``) is appended to
    ``<dirname>/_SUCCESS``, which is created if it does not exist yet.

    Args:
        dirname(str): The folder to mark.
    """
    marker_path = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    with open(marker_path, 'a') as marker:
        marker.write(time.ctime())


def _get_latest_checkpoint_serial(checkpoint_dir):
    """
    get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory

    : param checkpoint_dir
    """

    def has_success(checkpoint_dir, cur_dir):
        """
        is _SUCCESS in this dir
        """

        serial = _get_dir_serial(cur_dir)
T
bug fix  
tangwei12 已提交
1256 1257
        if serial == -1 or \
            not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
T
tangwei12 已提交
1258 1259 1260 1261 1262 1263 1264 1265 1266
            return -1

        success_path = os.path.join(
            _get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
            SUCCESS_MARK_FILENAME)
        if os.path.isfile(success_path):
            return serial

    current_dir = -1
T
bug fix  
tangwei12 已提交
1267 1268 1269 1270

    if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
        return current_dir

T
tangwei12 已提交
1271 1272 1273 1274 1275 1276
    dirs = os.listdir(checkpoint_dir)
    for cur_dir in dirs:
        success_num = has_success(checkpoint_dir, cur_dir)
        if success_num > current_dir:
            current_dir = success_num
    return current_dir