#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import paddle.fluid as fluid
import paddle.fluid.io as io
import paddle.fluid.transpiler.distribute_transpiler as dist_transpiler
19 20 21 22
from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.framework import Program
23

T
tangwei12 已提交
24 25 26
from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
from paddle.fluid.incubate.fleet.base.fleet_base import Mode
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
27

28
from paddle.fluid import compiler
29
from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel, CheckpointSaver
30

31 32
import os
import sys
33
import six
34 35 36
import json
import re
import shutil
37 38 39


class LambConfig(object):
    """
    Configuration holder named for the LAMB optimizer.

    It currently defines no attributes; constructing it has no effect.
    """

    def __init__(self):
        super(LambConfig, self).__init__()


class DistFCConfig(object):
    """
    Configuration holder for the distributed FC feature.

    No attributes are defined yet. ``DistributedStrategy.dist_fc_config`` is
    expected to hold an instance of this class when ``use_dist_fc`` is set.
    """

    def __init__(self):
        super(DistFCConfig, self).__init__()


class Collective(Fleet):
    """
    Fleet implementation for ``Mode.COLLECTIVE`` distributed training.

    The parameter-server style entry points (``init_worker``, ``run_worker``,
    ``init_server``, ``run_server``, ``stop_worker``) only log a warning in
    this mode. The programs used as defaults by the save/checkpoint helpers
    are assigned onto this object by ``CollectiveOptimizer.minimize``.
    """

    def __init__(self):
        super(Collective, self).__init__(Mode.COLLECTIVE)
        self._local_ip = 0

        self.startup_program = None
        # Clone of the main program taken before it is transpiled; default
        # program for save_inference_model / save_persistables.
        self._origin_program = None
        # The main program that gets transpiled in place; default program
        # for save_checkpoint / load_checkpoint.
        self._transpiled_program = None
        # Program (or compiled program) produced for training.
        self.main_program = None

        self._checkpoint_prefix = "__paddle_fleet_checkpoint__"
        self._param_file_name = "_paddle_fleet_param__"

    @staticmethod
    def _warn_not_supported(method_name):
        """Log that ``method_name`` should not be called in collective mode."""
        logging.warning("You should not call '%s' method for collective mode.",
                        method_name)

    def init_worker(self):
        """Not used in collective mode; only logs a warning."""
        self._warn_not_supported("init_worker")

    def run_worker(self, main_programs=None, scopes=None):
        """Not used in collective mode; only logs a warning."""
        self._warn_not_supported("run_worker")

    def init_server(self, model_dir=None):
        """Not used in collective mode; only logs a warning."""
        self._warn_not_supported("init_server")

    def run_server(self):
        """Not used in collective mode; only logs a warning."""
        self._warn_not_supported("run_server")

    def stop_worker(self):
        """Not used in collective mode; only logs a warning."""
        self._warn_not_supported("stop_worker")

    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Wrap ``optimizer`` in a CollectiveOptimizer, keep it on this fleet
        and return it.

        Args:
            optimizer: a ``paddle.fluid.optimizer`` instance to wrap.
            strategy (DistributedStrategy|None): forwarded to
                CollectiveOptimizer.
        """
        self._optimizer = CollectiveOptimizer(optimizer, strategy)
        return self._optimizer

    def _resolve_main_program(self, api_name, executor, main_program):
        """
        Validate ``executor`` and return the program a save API should use.

        ``main_program`` defaults to the pre-transpilation program. ``api_name``
        only appears in the assertion messages.

        Raises:
            AssertionError: if ``executor`` is not an Executor or the resolved
                program is not a Program.
        """
        assert isinstance(executor, Executor), (
            "In fleet.%s() function, executor must be as Executor type." %
            api_name)

        if main_program is None:
            main_program = self._origin_program

        assert isinstance(main_program, Program), (
            "In fleet.%s() function, main_program must be as Program type." %
            api_name)
        return main_program

    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names=None,
                             target_vars=None,
                             main_program=None,
                             export_for_deployment=True):
        """
        Prune the given `main_program` to build a new program especially for
        inference, and then save it and all related parameters to given
        `dirname` by the `executor`.

        `main_program` defaults to the program recorded before transpilation.
        """
        main_program = self._resolve_main_program("save_inference_model",
                                                  executor, main_program)

        io.save_inference_model(dirname, feeded_var_names, target_vars,
                                executor, main_program, None, None,
                                export_for_deployment)

    def save_persistables(self,
                          executor,
                          dirname,
                          main_program=None,
                          filename=None):
        """
        This function filters out all variables with `persistable==True` from
        the given `main_program` and then saves these variables to the folder
        `dirname` or file `filename`.

        The `dirname` is used to specify the folder where persistable variables
        are going to be saved. If you would like to save variables in separate
        files, set `filename` None; if you would like to save all variables in a
        single file, use `filename` to specify the file name.

        `main_program` defaults to the program recorded before transpilation.
        """
        main_program = self._resolve_main_program("save_persistables",
                                                  executor, main_program)

        io.save_persistables(executor, dirname, main_program, filename=filename)

    def save_checkpoint(self,
                        executor,
                        path,
                        trainer_id,
                        train_status,
                        fs,
                        main_program=None,
                        local_cache_path=".cache",
                        remain_all_checkpoint=True):
        """
        Save the persistables of `main_program` together with `train_status`
        as a new checkpoint under `path`, using the file system client `fs`.

        Args:
            main_program: defaults to the transpiled training program.
            local_cache_path (str): local staging directory passed to
                CheckpointSaver.save_checkpoint.
            remain_all_checkpoint (bool): when False,
                CheckpointSaver.clean_redundant_checkpoints(path) is called
                after saving.

        Returns:
            tuple: (real_path, checkpoint_no) as returned by
            CheckpointSaver.save_checkpoint.
        """
        if main_program is None:
            main_program = self._transpiled_program

        model = PaddleModel(executor, main_program)
        saver = CheckpointSaver(fs)
        real_path, checkpoint_no = saver.save_checkpoint(
            path=path,
            slists=[model, train_status],
            trainer_id=trainer_id,
            local_cache_path=local_cache_path)

        if not remain_all_checkpoint:
            saver.clean_redundant_checkpoints(path)

        return real_path, checkpoint_no

    def load_checkpoint(self,
                        executor,
                        path,
                        trainer_id,
                        train_status,
                        fs,
                        main_program=None,
                        local_cache_path=".cache",
                        ignore_empty=True):
        """
        Restore the persistables of `main_program` and `train_status` from
        the checkpoint stored under `path`, using the file system client `fs`.

        Args:
            main_program: defaults to the transpiled training program.
            ignore_empty (bool): forwarded to CheckpointSaver.load_checkpoint.

        Returns:
            Whatever CheckpointSaver.load_checkpoint returns.
        """
        if main_program is None:
            main_program = self._transpiled_program

        model = PaddleModel(executor, main_program)
        saver = CheckpointSaver(fs)
        return saver.load_checkpoint(
            path, [model, train_status],
            trainer_id=trainer_id,
            ignore_empty=ignore_empty,
            local_cache_path=local_cache_path)


# Module-level Collective instance. CollectiveOptimizer in this module reads
# worker information from it and assigns the programs it builds onto it.
fleet = Collective()


class DistributedStrategy(fluid.BuildStrategy):
    """
    ``fluid.BuildStrategy`` extended with the options used by collective
    fleet training.

    On top of the inherited build options it carries the local SGD and
    distributed FC switches, the transpiler mode settings, the number of
    NCCL communicators, recompute and mixed precision (AMP) settings, and
    the ``fluid.ExecutionStrategy`` used when compiling the program.
    """

    def __init__(self):
        super(DistributedStrategy, self).__init__()

        # Feature switches.
        self.use_local_sgd = False
        self.use_dist_fc = False
        self.dist_fc_config = None  # expected to be a DistFCConfig

        # Transpiler settings.
        self.mode = "nccl2"  # or collective
        self.collective_mode = None  # local_sgd or grad_allreduce
        self.nccl_comm_num = 1

        # Recompute: when enabled, the optimizer is wrapped in
        # RecomputeOptimizer with these checkpoints.
        self.forward_recompute = False
        self.recompute_checkpoints = []

        # Mixed precision optimizer settings.
        self.use_amp = False
        self.amp_loss_scaling = 2**15

        self.exec_strategy = fluid.ExecutionStrategy()

        # Only used by unit tests.
        self._ut4grad_allreduce = False


class CollectiveOpBasedOptimizer(DistributedOptimizer):
    """
    Base class for distributed optimizers built on collective operators.

    Not part of the user-facing API. ``backward`` and ``apply_gradients``
    are forwarded unchanged to the wrapped optimizer.
    """

    def __init__(self, optimizer, strategy=None):
        """
        Args:
            optimizer: the optimizer to wrap.
            strategy (DistributedStrategy|None): distributed settings; a
                fresh DistributedStrategy is created when None.

        Raises:
            AssertionError: if ``strategy`` is not a DistributedStrategy.
        """
        # The signature advertises ``strategy=None``; without this fallback
        # that default would always fail the isinstance check below.
        if strategy is None:
            strategy = DistributedStrategy()
        assert isinstance(
            strategy,
            DistributedStrategy), "strategy must be DistributedStrategy"
        super(CollectiveOpBasedOptimizer, self).__init__(optimizer, strategy)

    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        """Delegate to the wrapped optimizer's ``backward``."""
        return self._optimizer.backward(loss, startup_program, parameter_list,
                                        no_grad_set, callbacks)

    def apply_gradients(self, params_grads):
        """Delegate to the wrapped optimizer's ``apply_gradients``."""
        return self._optimizer.apply_gradients(params_grads)


class CollectiveOptimizer(DistributedOptimizer):
    """
    Collective-mode DistributedOptimizer wrapping a ``paddle.fluid.optimizer``.

    A user passes a ``paddle.fluid.optimizer`` instance (and optionally a
    DistributedStrategy). ``minimize()`` then does the following:

    - applies the optional recompute / mixed precision decorators;
    - runs the wrapped optimizer;
    - transpiles the program for distributed training;
    - records the resulting programs on the module-level ``fleet`` instance,
      which holds the global information about the current distributed
      training.
    """

    def __init__(self, optimizer, strategy=None):
        """
        Args:
            optimizer: the ``paddle.fluid.optimizer`` to wrap.
            strategy (DistributedStrategy|None): distributed settings; a
                fresh DistributedStrategy is created when None.

        Raises:
            ValueError: if ``strategy.recompute_checkpoints`` is not a list.
        """
        # Default to None instead of a ``DistributedStrategy()`` evaluated once
        # at class-definition time: the strategy is mutated later
        # (_check_collective_mode, _try_to_compile), so a shared default
        # instance would leak settings between optimizers.
        if strategy is None:
            strategy = DistributedStrategy()
        super(CollectiveOptimizer, self).__init__(optimizer, strategy)
        self._forward_recompute = strategy.forward_recompute
        if not isinstance(strategy.recompute_checkpoints, list):
            raise ValueError("DistStrategy.recompute_checkpoints should "
                             "be a List")
        self._recompute_checkpoints = strategy.recompute_checkpoints
        self._use_amp = strategy.use_amp
        self._amp_loss_scaling = strategy.amp_loss_scaling
        self.print_config = False

    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        """Delegate to the wrapped optimizer's ``backward``."""
        return self._optimizer.backward(loss, startup_program, parameter_list,
                                        no_grad_set, callbacks)

    def apply_gradients(self, params_grads):
        """Delegate to the wrapped optimizer's ``apply_gradients``."""
        return self._optimizer.apply_gradients(params_grads)

    def _check_condition(self, name, **kwargs):
        """
        Raise AssertionError if any keyword flag is exactly ``True``.

        ``name`` is the feature being enabled; each keyword names a feature
        that must not be combined with it.
        """
        for k, v in six.iteritems(kwargs):
            if v is True:
                # Explicit raise instead of ``assert False`` so the check is
                # not stripped under ``python -O``; exception type unchanged.
                raise AssertionError("you can't use %s and %s together" %
                                     (name, k))

    def _check_collective_mode(self, main_program, optimizer, strategy):
        """
        Check the conflict conditions.

        Sets ``strategy.mode`` / ``strategy.collective_mode`` from the
        ``use_local_sgd`` and ``_ut4grad_allreduce`` switches, and asserts
        that mutually exclusive features are not enabled together.
        ``optimizer`` is currently unused.
        """
        if strategy.use_local_sgd:
            strategy.mode = "collective"
            strategy.collective_mode = "local_sgd"
            self._check_condition(
                "use_local_sgd",
                use_dgc=main_program._enable_dgc,
                use_dist_fc=strategy.use_dist_fc,
                use_lamb=main_program._use_lamb)

        if strategy.use_dist_fc:
            self._check_condition(
                "use_dist_fc",
                use_dgc=main_program._enable_dgc,
                use_local_sgd=strategy.use_local_sgd,
                use_lamb=main_program._use_lamb)
            assert strategy.dist_fc_config is not None, \
                "DistributedStrategy.dist_fc_config should be set"

        if strategy._ut4grad_allreduce:
            strategy.mode = "collective"
            strategy.collective_mode = "grad_allreduce"
            self._check_condition(
                "_ut4grad_allreduce",
                use_dgc=main_program._enable_dgc,
                use_lamb=main_program._use_lamb)

        if strategy.collective_mode in ("local_sgd", "grad_allreduce"):
            assert strategy.mode == "collective", \
                "local_sgd and grad_allreduce can be used under collective mode"

    def _transpile(self, startup_program, main_program):
        """
        Transpile the programs to distributed programs. And add the variables.

        Both programs are modified in place by DistributeTranspiler.
        """
        worker_endpoints = fleet.worker_endpoints()
        trainer_id = fleet.worker_index()
        current_endpoint = worker_endpoints[trainer_id]
        worker_endpoints_env = ','.join(worker_endpoints)
        trainers_num = fleet.worker_num()

        if self.print_config:
            print("worker_endpoints:{} trainers_num:{} current_endpoint:{} "
                  "trainer_id:{}".format(worker_endpoints, trainers_num,
                                         current_endpoint, trainer_id))

        # call transpiler
        config = dist_transpiler.DistributeTranspilerConfig()
        config.mode = self._strategy.mode
        config.collective_mode = self._strategy.collective_mode

        config.nccl_comm_num = self._strategy.nccl_comm_num
        config.use_hierarchical_allreduce = self._strategy.use_hierarchical_allreduce
        config.hierarchical_allreduce_inter_nranks = self._strategy.hierarchical_allreduce_inter_nranks

        t = dist_transpiler.DistributeTranspiler(config=config)
        t.transpile(
            trainer_id=trainer_id,
            trainers=worker_endpoints_env,
            startup_program=startup_program,
            program=main_program,
            current_endpoint=current_endpoint)

    def _get_node_ips_from_endpoints(self, endpoints):
        """
        Return the distinct IPs of ``"ip:port"`` endpoints, in first-seen order.
        """
        seen = set()
        ips = []
        for ep in endpoints:
            ip = ep.split(":")[0].strip()
            if ip not in seen:
                seen.add(ip)
                ips.append(ip)
        return ips

    def _node_num(self):
        """Return the number of distinct worker IPs known to ``fleet``."""
        return len(self._get_node_ips_from_endpoints(fleet.worker_endpoints()))

    def _try_to_compile(self, startup_program, main_program):
        """
        Adjust the strategy for the cluster layout, transpile, and build the
        program to run.

        Returns ``main_program`` itself in "collective" mode, otherwise a
        data-parallel ``CompiledProgram``.
        """
        node_num = self._node_num()
        assert node_num >= 1, "nccl2 node_num must >= 1, now:{}".format(
            node_num)

        exec_strategy = self._strategy.exec_strategy

        if node_num <= 1:
            if self._strategy.nccl_comm_num > 1:
                logging.warning(
                    "set nccl_comm_num=1 since you only have 1 node.")
            self._strategy.nccl_comm_num = 1

            if self._strategy.use_hierarchical_allreduce:
                logging.warning(
                    "set use_hierarchical_allreduce=False since you only have 1 node."
                )
            self._strategy.use_hierarchical_allreduce = False

        sync_allreduce = os.getenv("FLAGS_sync_nccl_allreduce")
        if sync_allreduce is None or sync_allreduce == "1":
            exec_strategy.num_threads = self._strategy.nccl_comm_num + 1
            if self._strategy.use_hierarchical_allreduce:
                exec_strategy.num_threads = 2 * self._strategy.nccl_comm_num + 1
            if exec_strategy.num_threads > 4:
                logging.warning(
                    "if you use use_hierarchical_allreduce or "
                    "with multi nccl comm, please export FLAGS_sync_nccl_allreduce = 0"
                )

        # NOTE. open sync_batch_norm will hang when use multi num_threads
        if self._strategy.sync_batch_norm is True:
            self._strategy.nccl_comm_num = 1
            self._strategy.use_hierarchical_allreduce = False
            exec_strategy.num_threads = 1
            logging.warning(
                "use sync_batch_norm will hang when set num_threads > 1, so "
                "set num_threads=1, nccl_comm_num=1, use_hierarchical_allreduce=False."
            )

        if self.print_config:
            print("node_num:", node_num, "num_threads:",
                  exec_strategy.num_threads, "use_hierarchical_allreduce:",
                  self._strategy.use_hierarchical_allreduce, "nccl_comm_num:",
                  self._strategy.nccl_comm_num, "FLAGS_sync_nccl_allreduce:",
                  sync_allreduce)

        self._transpile(startup_program, main_program)

        if self._strategy.mode == "collective":
            return main_program

        self._strategy.num_trainers = fleet.worker_num()
        self._strategy.trainer_id = fleet.worker_index()
        self._strategy.trainers_endpoints = fleet.worker_endpoints()
        self._strategy.enable_backward_optimizer_op_deps = True

        self._compiled_program = compiler.CompiledProgram(main_program)

        self._compiled_program.with_data_parallel(
            loss_name=self._loss.name,
            build_strategy=self._strategy,
            exec_strategy=self._strategy.exec_strategy,
            share_vars_from=None)

        return self._compiled_program

    def raiseOptimizeError(self, strategy_name, optimize_name):
        """Raise ValueError: ``optimize_name`` conflicts with ``strategy_name``."""
        raise ValueError("can not use {0} when you set DistStrategy.{1} "
                         "as True".format(optimize_name, strategy_name))

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        minimize a program through loss
        Args:
            loss (Variable|Variable List): loss variable or loss variable list to run optimization.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            parameter_list (list): list of Variables to update.
            no_grad_set (set|None): set of Variables should be ignored.
        Returns:
            tuple: (optimize_ops, params_grads) which are, list of operators appended;
            and list of (param, grad) Variables pair for optimization.
        Raises:
            ValueError: if forward_recompute is set without checkpoints, or
                the wrapped optimizer is already of a conflicting type.
        Note that in parameter server mode, a worker will not get anything about optimize_ops
        Because optimizer algorithms run on pserver side. We will make this usable in pserver
        process, but currently the optimization part is written into Fleet(). A user does not
        need to care about how to startup a pserver node.
        """

        # check optimizer conflicts
        if self._forward_recompute:
            # __init__ guarantees a list, so emptiness is a plain truth test.
            if not self._recompute_checkpoints:
                raise ValueError("please set strategy.recompute_checkpoints "
                                 "when set strategy.forward_recompute as True")
            if self._optimizer.__class__.__name__ in [
                    "RecomputeOptimizer", "OptimizerWithMixedPrecision"
            ]:
                self.raiseOptimizeError("forward_recompute",
                                        self._optimizer.__class__.__name__)

            self._optimizer = \
                fluid.optimizer.RecomputeOptimizer(self._optimizer)
            self._optimizer._set_checkpoints(self._recompute_checkpoints)

        if self._use_amp:
            if self._optimizer.__class__.__name__ in [
                    "OptimizerWithMixedPrecision", "DGCMomentumOptimizer"
            ]:
                self.raiseOptimizeError("mixed_precision",
                                        self._optimizer.__class__.__name__)
            self._optimizer = fluid.contrib.mixed_precision.decorate(
                self._optimizer,
                init_loss_scaling=self._amp_loss_scaling,
                use_dynamic_loss_scaling=True)

        main_program = loss.block.program
        if startup_program is None:
            startup_program = fluid.default_startup_program()
        fleet.startup_program = startup_program

        # Needed by _try_to_compile for the data-parallel loss name.
        self._loss = loss

        self._check_collective_mode(main_program, self._optimizer,
                                    self._strategy)

        optimize_ops, param_grads = self._optimizer.minimize(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set)

        # Keep an untranspiled copy; main_program itself is transpiled in place.
        fleet._origin_program = main_program.clone(for_test=False)
        fleet._transpiled_program = main_program
        fleet.main_program = self._try_to_compile(startup_program, main_program)

        return optimize_ops, param_grads