distributed_strategy.py 45.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import paddle
16
from paddle.distributed.fleet.proto import distributed_strategy_pb2
17
from paddle.fluid.framework import Variable, set_flags, core
18
from paddle.fluid.wrapped_decorator import wrap_decorator
19
import google.protobuf.text_format
20
import google.protobuf
21

22 23
__all__ = ["DistributedStrategy"]

24 25 26 27 28 29 30 31 32 33 34 35 36 37
# Module-level flag. It starts as True and is flipped to False the first time
# a function wrapped by ``__non_auto_func_called__`` is invoked.
non_auto_func_called = True


def __non_auto_func_called__(func):
    """Return a wrapper around *func* that clears ``non_auto_func_called``.

    The flag is cleared *before* delegating to *func*; positional and keyword
    arguments are forwarded unchanged and *func*'s return value is returned.
    """

    def _clear_flag_then_call(*args, **kwargs):
        global non_auto_func_called
        non_auto_func_called = False
        result = func(*args, **kwargs)
        return result

    return _clear_flag_then_call


# Decorator stacked on the DistributedStrategy property setters in this file
# (``@is_strict_auto``). It is built from __non_auto_func_called__, so each
# decorated call clears the module-level ``non_auto_func_called`` flag.
# NOTE(review): ``wrap_decorator`` comes from paddle.fluid.wrapped_decorator;
# it is assumed to adapt a plain wrapper factory into a decorator - confirm.
is_strict_auto = wrap_decorator(__non_auto_func_called__)

38

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
def get_msg_dict(msg):
    """Return a new dict mapping each field name of *msg* to its value.

    Fields are enumerated in ``msg.DESCRIPTOR.fields`` order and values are
    whatever ``getattr`` returns for them (no copying is done).
    """
    return {
        field.name: getattr(msg, field.name)
        for field in msg.DESCRIPTOR.fields
    }


def assign_configs_value(msg, config):
    """Copy the entries of the mapping *config* into protobuf message *msg*.

    Keys that do not name a field of ``msg`` are ignored. Repeated fields
    (label 3) are *extended* with the supplied values, so entries already in
    the message are kept. Optional/required fields (labels 1 and 2) are
    overwritten.
    """
    repeated_label = 3
    singular_labels = (1, 2)  # optional, required
    fields_by_name = {field.name: field for field in msg.DESCRIPTOR.fields}
    for key in config:
        field = fields_by_name.get(key)
        if field is None:
            continue
        if field.label == repeated_label:
            getattr(msg, key).extend(config[key])
        elif field.label in singular_labels:
            setattr(msg, key, config[key])


def check_configs_key(msg, config, field_name):
    """Validate that every key of *config* names a field of protobuf *msg*.

    Args:
        msg: protobuf message whose ``DESCRIPTOR.fields_by_name`` defines the
            accepted keys.
        config: the user-supplied configuration; its keys are checked.
        field_name: name of the strategy attribute being configured. It is
            used only in the error message.

    Raises:
        AssertionError: if a key is not a field of ``msg``. This is an
            explicit ``raise`` rather than an ``assert`` statement so that
            the validation still runs under ``python -O``.
    """
    valid_keys = msg.DESCRIPTOR.fields_by_name.keys()
    for key in config:
        if key not in valid_keys:
            raise AssertionError("key:{} not in {}".format(key, field_name))


64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
class DistributedJobInfo(object):
    """
    Records the information of one distributed training job in a
    ``distributed_strategy_pb2.DistributedJobInfo`` protobuf message.

    Internal use only: 1) debugging, 2) replicating experiments.
    """

    def __init__(self):
        self.job_info = distributed_strategy_pb2.DistributedJobInfo()

    def _set_worker_num(self, worker_num):
        """Store the number of workers."""
        info = self.job_info
        info.worker_num = worker_num

    def _set_server_num(self, server_num):
        """Store the number of servers."""
        info = self.job_info
        info.server_num = server_num

    def _set_worker_ips(self, worker_ips):
        """Append ``worker_ips`` to the stored worker IP list."""
        info = self.job_info
        info.worker_ips.extend(worker_ips)

    def _set_server_endpoints(self, server_endpoints):
        """Append ``server_endpoints`` to the stored server endpoint list."""
        info = self.job_info
        info.server_endpoints.extend(server_endpoints)

    def _set_origin_startup(self, origin_startup_prog):
        """Store the ``str()`` form of the original startup program."""
        info = self.job_info
        info.origin_startup = str(origin_startup_prog)

    def _set_origin_main(self, origin_main_prog):
        """Store the ``str()`` form of the original main program."""
        info = self.job_info
        info.origin_main = str(origin_main_prog)

    def _distributed_main(self, distributed_main_prog):
        """Store the ``str()`` form of the distributed main program."""
        info = self.job_info
        info.distributed_main = str(distributed_main_prog)

    def _optimizer_name(self, optimizer_name):
        """Store the optimizer name."""
        info = self.job_info
        info.optimizer_name = optimizer_name

    def _set_distributed_strategy(self, dist_strategy):
        """Store ``dist_strategy`` in the ``strategy`` field."""
        # NOTE(review): if ``strategy`` is a message-typed field in the proto,
        # the protobuf Python API rejects direct assignment and
        # ``CopyFrom`` would be required - confirm against the .proto.
        info = self.job_info
        info.strategy = dist_strategy


class DistributedStrategy(object):
102 103
    __lock_attr = False

    def __init__(self):
        """
        DistributedStrategy is the main configuration entry for distributed
        training in Paddle. All distributed-training options are configured
        through this object, e.g. automatic mixed precision (AMP), Layer-wise
        Adaptive Rate Scaling (LARS) and asynchronous parameter-server updates
        (ASGD).

        A DistributedStrategy can be serialized to, and deserialized from, a
        protobuf text file.

        Users of local training usually configure BuildStrategy and
        ExecutionStrategy; DistributedStrategy accepts configurations from
        both.

        Once construction finishes the instance is locked. Assigning an
        attribute name that does not already exist then raises TypeError.
        """
        self.strategy = distributed_strategy_pb2.DistributedStrategy()
        # From here on, __setattr__ rejects unknown attribute names, so a typo
        # such as ``strategy.amp_config = ...`` fails loudly.
        self.__lock_attr = True

    def __setattr__(self, key, value):
        # Before __init__ locks the instance, the class-level ``__lock_attr``
        # (False) is seen and any attribute may be created. Afterwards only
        # names that already exist may be assigned; properties defined on
        # the class count as existing.
        if not self.__lock_attr or hasattr(self, key):
            object.__setattr__(self, key, value)
            return
        raise TypeError("%s is not a attribute of %s" %
                        (key, self.__class__.__name__))
125

126
    def save_to_prototxt(self, output):
        """
        Write the current DistributedStrategy, rendered as text with
        ``str()`` on the underlying protobuf message, to the file ``output``.
        The file is opened in ``"w"`` mode, so existing content is replaced.

        Args:
            output(str): path of the file to write.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.dgc = True
            strategy.recompute = True
            strategy.recompute_configs = {"checkpoints": ["x"]}
            strategy.save_to_prototxt("dist_strategy.prototxt")
        """
        with open(output, "w") as prototxt_file:
            serialized = str(self.strategy)
            prototxt_file.write(serialized)

    def load_from_prototxt(self, pb_file):
        """
        Read a protobuf text file and merge it into this DistributedStrategy.

        The file content is parsed with ``google.protobuf.text_format.Merge``
        into the existing strategy message, rather than into a fresh one.

        Args:
            pb_file(str): path of the prototxt file to read.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.load_from_prototxt("dist_strategy.prototxt")
        """
        with open(pb_file, 'r') as prototxt_file:
            text = str(prototxt_file.read())
            merged = google.protobuf.text_format.Merge(text, self.strategy)
            self.strategy = merged

    @property
    def execution_strategy(self):
        """
        Configure ExecutionStrategy for DistributedStrategy.

        Reading this property builds a new ``paddle.fluid.ExecutionStrategy``
        filled with the values stored in this strategy. Modifying that object
        does not change the strategy; assign it back to apply changes.
        Assigning an ExecutionStrategy copies each of its fields into this
        strategy.

        Examples:
          .. code-block:: python

            import paddle
            exe_strategy = paddle.fluid.ExecutionStrategy()
            exe_strategy.num_threads = 10
            exe_strategy.num_iteration_per_drop_scope = 10
            exe_strategy.num_iteration_per_run = 10

            strategy = paddle.distributed.fleet.DistributedStrategy()
            strategy.execution_strategy = exe_strategy
        """
        snapshot = paddle.fluid.ExecutionStrategy()
        stored = self.strategy.execution_strategy
        for field in stored.DESCRIPTOR.fields:
            name = field.name
            setattr(snapshot, name, getattr(stored, name))
        return snapshot

    @execution_strategy.setter
    @is_strict_auto
    def execution_strategy(self, strategy):
        stored = self.strategy.execution_strategy
        for field in stored.DESCRIPTOR.fields:
            name = field.name
            setattr(stored, name, getattr(strategy, name))

    @property
    def build_strategy(self):
        """
        Configure BuildStrategy for DistributedStrategy.
        Note that a BuildStrategy property only takes effect in
        DistributedStrategy if it is a non-distributed setting.

        Reading this property builds a new ``paddle.fluid.BuildStrategy``
        from the stored values. When a BuildStrategy is assigned, singular
        fields are overwritten and repeated fields are appended to the
        entries already stored.

        Examples:
          .. code-block:: python

            import paddle
            build_strategy = paddle.fluid.BuildStrategy()
            build_strategy.enable_sequential_execution = True
            build_strategy.fuse_elewise_add_act_ops = True
            build_strategy.fuse_bn_act_ops = True
            build_strategy.enable_auto_fusion = True
            build_strategy.fuse_relu_depthwise_conv = True
            build_strategy.fuse_broadcast_ops = True
            build_strategy.fuse_all_optimizer_ops = True
            build_strategy.enable_inplace = True

            strategy = paddle.distributed.fleet.DistributedStrategy()
            strategy.build_strategy = build_strategy
        """

        snapshot = paddle.fluid.BuildStrategy()
        stored = self.strategy.build_strategy
        for field in stored.DESCRIPTOR.fields:
            name = field.name
            setattr(snapshot, name, getattr(stored, name))
        return snapshot

    @build_strategy.setter
    @is_strict_auto
    def build_strategy(self, strategy):
        stored = self.strategy.build_strategy
        for field in stored.DESCRIPTOR.fields:
            name = field.name
            if field.label in (1, 2):  # optional / required: overwrite
                setattr(stored, name, getattr(strategy, name))
            elif field.label == 3:  # repeated: append to existing entries
                getattr(stored, name).extend(getattr(strategy, name))

    @property
    def a_sync(self):
        """
        Indicating whether we are using asynchronous stochastic gradient descent
        updates for training. This property is valid when we are using parameter
        server training, which is implied by setting an appropriate RoleMaker.
        Default value: True

        Setting this property also resets ``a_sync_configs`` key ``k_steps``
        to 0. Non-bool values raise ValueError.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            role_maker = fleet.PaddleCloudRoleMaker()
            fleet.init(role_maker)

            strategy = fleet.DistributedStrategy()
            strategy.a_sync = True  # by default this is True

            # code block for defining loss and local optimizer
            # sgd = fleet.distributed_optimizer(optimizer, strategy)
        """
        return self.strategy.a_sync

    @a_sync.setter
    @is_strict_auto
    def a_sync(self, flag):
        if isinstance(flag, bool):
            self.strategy.a_sync = flag
            # Any explicit a_sync assignment resets k_steps to 0.
            self.a_sync_configs = {"k_steps": 0}
        else:
            # Use a "{}" placeholder: str.format() does not substitute "%s",
            # so the received type would never appear in the message.
            raise ValueError(
                "The type of `flag` is invalid, expected type is bool, "
                "but received {}".format(type(flag)))
267 268

    @property
    def a_sync_configs(self):
        """
        Set a_sync update configurations. Asynchronous parameter server
        training has several configurable settings, which are given as a dict.
        Reading the property returns a new dict built from the stored config.

        **Notes**:
            k_steps(int): number of local optimization updates before communication

            max_merge_var_num(int): maximum number of merged gradients before communication

            send_queue_size(int): a buffer size of worker communication

            independent_recv_thread(bool): if we are using independent recv thread for communication

            thread_pool_size(int): number of thread pool

            send_wait_times(int): waiting time for sending gradients

            runtime_split_send_recv(bool): if we are using Tensor split for send and recv during runtime

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            role_maker = fleet.PaddleCloudRoleMaker()
            fleet.init(role_maker)

            strategy = fleet.DistributedStrategy()
            strategy.a_sync = True  # by default this is True
            configs = {"k_steps": 1024, "send_queue_size": 32}
            strategy.a_sync_configs = configs

            # code block for defining loss and local optimizer
            # sgd = fleet.distributed_optimizer(optimizer, strategy)
        """
        stored = self.strategy.a_sync_configs
        return get_msg_dict(stored)

    @a_sync_configs.setter
    @is_strict_auto
    def a_sync_configs(self, configs):
        stored = self.strategy.a_sync_configs
        check_configs_key(stored, configs, "a_sync_configs")
        assign_configs_value(stored, configs)
314

315
    @property
    def amp(self):
        """
        Indicating whether we are using automatic mixed precision training.
        Default Value: False

        Non-bool values are ignored and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.amp = True # by default this is false
        """
        return self.strategy.amp

    @amp.setter
    @is_strict_auto
    def amp(self, flag):
        if not isinstance(flag, bool):
            print("WARNING: amp should have value of bool type")
            return
        self.strategy.amp = flag
338 339

    @property
    def amp_configs(self):
        """
        Set automatic mixed precision training configurations. amp has several
        configurable settings, which are given as a dict. Reading the property
        returns a new dict built from the stored config.

        **Notes**:
            init_loss_scaling(float): The initial loss scaling factor. Default 32768.

            use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. Default True.

            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients. Default 1000.

            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients. Default 2.

            incr_ratio(float): The multiplier to use when increasing the loss scaling. Default 2.0.

            decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling. Default 0.5.

            custom_white_list(list[str]): Users' custom white list of ops that always execute in fp16.

            custom_black_list(list[str]): Users' custom black list of ops that are forbidden from executing in fp16.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.amp = True
            strategy.amp_configs = {
                "init_loss_scaling": 32768,
                "custom_white_list": ['conv2d']}
        """
        stored = self.strategy.amp_configs
        return get_msg_dict(stored)

    @amp_configs.setter
    @is_strict_auto
    def amp_configs(self, configs):
        stored = self.strategy.amp_configs
        check_configs_key(stored, configs, "amp_configs")
        assign_configs_value(stored, configs)
379 380

    @property
    def recompute(self):
        """
        Indicating whether forward recomputation is used to reduce memory
        usage during training.
        Default value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.recompute = True
            # suppose x and y are names of checkpoint tensors for recomputation
            strategy.recompute_configs = {"checkpoints": ["x", "y"]}
        """
        enabled = self.strategy.recompute
        return enabled
396

397 398
    @property
    def sync_nccl_allreduce(self):
        """
        Indicating whether synchronized allreduce is used in each communication
        thread. System overhead is usually lower when sync_nccl_allreduce = True.

        Non-bool values are ignored and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sync_nccl_allreduce = True
        """
        return self.strategy.sync_nccl_allreduce

    @sync_nccl_allreduce.setter
    @is_strict_auto
    def sync_nccl_allreduce(self, flag):
        if not isinstance(flag, bool):
            print("WARNING: sync_nccl_allreduce should have value of bool type")
            return
        self.strategy.sync_nccl_allreduce = flag
419

420
    @property
    def use_hierarchical_allreduce(self):
        """
        Indicating whether hierarchical allreduce is used in collective
        communication. Hierarchical allreduce usually runs allreduce within
        each node group first, and then among the leaders of the groups.

        Non-bool values are ignored and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.use_hierarchical_allreduce = True
        """
        return self.strategy.use_hierarchical_allreduce

    @use_hierarchical_allreduce.setter
    @is_strict_auto
    def use_hierarchical_allreduce(self, flag):
        if not isinstance(flag, bool):
            print(
                "WARNING: use_hierarchical_allreduce should have value of bool type"
            )
            return
        self.strategy.use_hierarchical_allreduce = flag

    @property
    def hierarchical_allreduce_inter_nranks(self):
        """
        Number of ranks in each low-level node group of hierarchical allreduce.
        Default value: number of GPU cards on each single GPU machine

        Non-int values are ignored and a warning is printed.

        Example:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.hierarchical_allreduce_inter_nranks = 8
        """
        return self.strategy.hierarchical_allreduce_inter_nranks

    @hierarchical_allreduce_inter_nranks.setter
    @is_strict_auto
    def hierarchical_allreduce_inter_nranks(self, value):
        if not isinstance(value, int):
            print(
                "WARNING: hierarchical_allreduce_inter_nranks should have value of int type"
            )
            return
        self.strategy.hierarchical_allreduce_inter_nranks = value

471
    @property
    def sync_batch_norm(self):
        """
        Indicating whether sync_batch_norm is used, i.e. batch normalization
        statistics are synchronized among all training nodes.

        Default value: False

        Non-bool values are ignored and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sync_batch_norm = True
        """
        return self.strategy.sync_batch_norm

    @sync_batch_norm.setter
    @is_strict_auto
    def sync_batch_norm(self, flag):
        if not isinstance(flag, bool):
            print("WARNING: sync_batch_norm should have value of bool type")
            return
        self.strategy.sync_batch_norm = flag
495 496 497

    @property
    def fuse_all_reduce_ops(self):
        """
        Indicating whether fuse_all_reduce_ops is used to fuse gradients
        during the backward phase of training.
        Default value: True

        Non-bool values are ignored and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.fuse_all_reduce_ops = False
        """
        return self.strategy.fuse_all_reduce_ops

    @fuse_all_reduce_ops.setter
    @is_strict_auto
    def fuse_all_reduce_ops(self, flag):
        if not isinstance(flag, bool):
            print("WARNING: fuse_all_reduce_ops should have value of bool type")
            return
        self.strategy.fuse_all_reduce_ops = flag

519 520
    @property
    def fuse_grad_size_in_MB(self):
        """
        Size, in megabytes, of the gradients fused together.

        Default value: 32

        Non-int values are ignored and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.fuse_grad_size_in_MB = 50
        """
        return self.strategy.fuse_grad_size_in_MB

    @fuse_grad_size_in_MB.setter
    @is_strict_auto
    def fuse_grad_size_in_MB(self, value):
        if not isinstance(value, int):
            print("WARNING: fuse_grad_size_in_MB should have value of int type")
            return
        self.strategy.fuse_grad_size_in_MB = value

    @property
    def _fuse_grad_size_in_TFLOPS(self):
        """Internal setting backed by the ``fuse_grad_size_in_TFLOPS`` field.

        Only float values are stored; any other type is ignored and a warning
        is printed.
        """
        return self.strategy.fuse_grad_size_in_TFLOPS

    @_fuse_grad_size_in_TFLOPS.setter
    @is_strict_auto
    def _fuse_grad_size_in_TFLOPS(self, value):
        if not isinstance(value, float):
            print(
                "WARNING: fuse_grad_size_in_TFLOPS should have value of float type"
            )
            return
        self.strategy.fuse_grad_size_in_TFLOPS = value

557
    @property
    def nccl_comm_num(self):
        """
        Number of NCCL communicators.

        Default value: 1

        Non-int values are ignored and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.nccl_comm_num = 2
        """
        return self.strategy.nccl_comm_num

    @nccl_comm_num.setter
    @is_strict_auto
    def nccl_comm_num(self, value):
        if not isinstance(value, int):
            print("WARNING: nccl_comm_num should have value of int type")
            return
        self.strategy.nccl_comm_num = value
581

582
    @recompute.setter
    @is_strict_auto
    def recompute(self, flag):
        # Only bool values are accepted; anything else is ignored with a
        # printed warning.
        if not isinstance(flag, bool):
            print("WARNING: recompute should have value of bool type")
            return
        self.strategy.recompute = flag
589 590

    @property
    def recompute_configs(self):
        """
        Set recompute configurations. The current recompute implementation
        requires the user to specify checkpoint tensors manually through the
        ``checkpoints`` key. Reading the property returns a new dict built
        from the stored config.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.recompute = True
            strategy.recompute_configs = {"checkpoints": ["x", "y"]}

        """
        return get_msg_dict(self.strategy.recompute_configs)

    @recompute_configs.setter
    @is_strict_auto
    def recompute_configs(self, configs):
        # Report invalid keys against the public property name, consistent
        # with the other *_configs setters (previously "checkpoint_configs",
        # which names no attribute of DistributedStrategy).
        check_configs_key(self.strategy.recompute_configs, configs,
                          "recompute_configs")
        assign_configs_value(self.strategy.recompute_configs, configs)
613 614

    @property
    def pipeline(self):
        """
        Indicating whether pipeline parallelism is used for distributed
        training. The current implementation mainly targets pipeline
        parallelism within a single GPU machine combined with data parallelism
        across GPU machines. The pipeline stages are specified through
        device_guard information in the user-defined program.

        Non-bool values are ignored and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.pipeline = True

        """
        return self.strategy.pipeline

    @pipeline.setter
    @is_strict_auto
    def pipeline(self, flag):
        if not isinstance(flag, bool):
            print("WARNING: pipeline should have value of bool type")
            return
        self.strategy.pipeline = flag
639 640

    @property
    def pipeline_configs(self):
        """
        Set pipeline parallelism configurations. In pipeline parallelism,
        different parts of a neural network run on different GPUs. Tensor
        queue buffers between each pair of neighboring GPUs synchronize hidden
        Tensor results between them. Pipeline parallelism consists of several
        producer-consumer style hardware pairs, such as GPU-GPU, CPU-GPU and
        GPU-XPU. The best way to speed it up is to keep the Tensors in the
        queues small, so that producers feed downstream consumers faster.

        Reading the property returns a new dict built from the stored config.

        **Notes**:
            **Detailed arguments for pipeline_configs**

            **micro_batch**: the number of small batches in each user defined batch

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.pipeline = True
            strategy.pipeline_configs = {"micro_batch": 12}

        """
        stored = self.strategy.pipeline_configs
        return get_msg_dict(stored)

    @pipeline_configs.setter
    @is_strict_auto
    def pipeline_configs(self, configs):
        stored = self.strategy.pipeline_configs
        check_configs_key(stored, configs, "pipeline_configs")
        assign_configs_value(stored, configs)
675 676

    @property
    def localsgd(self):
        """
        Indicating whether Local SGD training is used. Default Value: False
        For more details, please refer to
        `Don't Use Large Mini-Batches, Use Local SGD <https://arxiv.org/pdf/1808.07217.pdf>`_.

        Non-bool values are ignored and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.localsgd = True # by default this is false

        """
        return self.strategy.localsgd

    @localsgd.setter
    @is_strict_auto
    def localsgd(self, flag):
        if not isinstance(flag, bool):
            print("WARNING: localsgd should have value of bool type")
            return
        self.strategy.localsgd = flag
701 702

    @property
    def localsgd_configs(self):
        """
        Set LocalSGD training configurations, given as a dict. Reading the
        property returns a new dict built from the stored config.

        **Notes**:
            k_steps(int) The local steps for training before parameter synchronization. Default 1.

            begin_step(int) The step at which training with localsgd begins. Default 1.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.localsgd = True
            strategy.localsgd_configs = {"k_steps": 4,
                                         "begin_step": 30}
        """
        stored = self.strategy.localsgd_configs
        return get_msg_dict(stored)

    @localsgd_configs.setter
    @is_strict_auto
    def localsgd_configs(self, configs):
        stored = self.strategy.localsgd_configs
        check_configs_key(stored, configs, "localsgd_configs")
        assign_configs_value(stored, configs)
730

731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787
    @property
    def adaptive_localsgd(self):
        """
        Indicating whether we are using Adaptive Local SGD training. Default Value: False
        For more details, please refer to `Adaptive Communication Strategies to Achieve
        the Best Error-Runtime Trade-off in Local-Update SGD <https://arxiv.org/pdf/1810.08313.pdf>`_.

        This flag is independent of ``localsgd``. Non-bool values are ignored
        and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.adaptive_localsgd = True # by default this is false

        """
        # Read the dedicated field; this previously returned
        # ``self.strategy.localsgd`` by mistake.
        return self.strategy.adaptive_localsgd

    @adaptive_localsgd.setter
    @is_strict_auto
    def adaptive_localsgd(self, flag):
        if isinstance(flag, bool):
            # Write the dedicated field; previously this toggled plain
            # ``localsgd`` instead of adaptive LocalSGD.
            self.strategy.adaptive_localsgd = flag
        else:
            print("WARNING: adaptive_localsgd should have value of bool type")

    @property
    def adaptive_localsgd_configs(self):
        """
        Set AdaptiveLocalSGD training configurations, given as a dict. Reading
        the property returns a new dict built from the stored config.

        **Notes**:
            init_k_steps(int) The initial steps for training before adaptive localsgd.
                              Then, the adaptive localsgd method will modify init_k_steps automatically.
                              Default 1.
            begin_step(int) The step at which training with adaptive localsgd begins. Default 1.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.adaptive_localsgd = True
            strategy.adaptive_localsgd_configs = {"init_k_steps": 1,
                                                  "begin_step": 30}
        """
        stored = self.strategy.adaptive_localsgd_configs
        return get_msg_dict(stored)

    @adaptive_localsgd_configs.setter
    @is_strict_auto
    def adaptive_localsgd_configs(self, configs):
        stored = self.strategy.adaptive_localsgd_configs
        check_configs_key(stored, configs, "adaptive_localsgd_configs")
        assign_configs_value(stored, configs)

788
    @property
    def dgc(self):
        """
        Indicating whether Deep Gradient Compression training is used. For more
        details, please refer to
        `Deep Gradient Compression <https://arxiv.org/abs/1712.01887>`_.

        Default Value: False

        Non-bool values are ignored and a warning is printed.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.dgc = True # by default this is false

        """
        return self.strategy.dgc

    @dgc.setter
    @is_strict_auto
    def dgc(self, flag):
        if not isinstance(flag, bool):
            print("WARNING: dgc should have value of bool type")
            return
        self.strategy.dgc = flag
813 814

    @property
    def dgc_configs(self):
        """
        Set Deep Gradient Compression training configurations. dgc has several
        configurable settings, which are given as a dict. Reading the property
        returns a new dict built from the stored config.

        **Notes**:
            rampup_begin_step(int): The beginning step from which gradient compression is implemented. Default 0.

            rampup_step(int): Time steps used in sparsity warm-up periods. Default is 1.
                For example, if the sparsity is [0.75, 0.9375, 0.984375, 0.996, 0.999] and
                rampup_step is 100, it uses 0.75 at steps 0~19, 0.9375 at steps 20~39, and so on.
                Once the end of the sparsity list is reached, 0.999 is used from then on.

            sparsity(list[float]): Keep the top important elements of the gradient tensor; the
                kept ratio is (1 - sparsity). Default is [0.999]. For example, if the sparsity is
                [0.99, 0.999], the top [1%, 0.1%] important elements will be transmitted.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.dgc = True
            strategy.dgc_configs = {"rampup_begin_step": 1252}
        """
        stored = self.strategy.dgc_configs
        return get_msg_dict(stored)

    @dgc_configs.setter
    @is_strict_auto
    def dgc_configs(self, configs):
        stored = self.strategy.dgc_configs
        check_configs_key(stored, configs, "dgc_configs")
        assign_configs_value(stored, configs)
847

848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870
    @property
    def fp16_allreduce(self):
        """
        Whether gradients are all-reduced in fp16 during training.

        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.fp16_allreduce = True # by default this is false

        """
        return self.strategy.fp16_allreduce

    @fp16_allreduce.setter
    @is_strict_auto
    def fp16_allreduce(self, flag):
        # A non-bool value is a hard error for this switch (TypeError),
        # not a printed warning.
        if isinstance(flag, bool):
            self.strategy.fp16_allreduce = flag
        else:
            raise TypeError('fp16_allreduce must be value of bool type')

    @property
    def gradient_merge(self):
        """
        Gradient Merge (also known as Gradient Accumulation) is a strategy
        for large-batch training. With it, model parameters are only
        updated once every user-defined number of steps.
        In every step the forward and backward networks run to compute the
        gradients of the model parameters; every k steps the optimization
        network runs and applies the chosen optimizer (such as SGD or Adam)
        to the model parameters.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.gradient_merge = True
            strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
        """
        return self.strategy.gradient_merge

    @gradient_merge.setter
    @is_strict_auto
    def gradient_merge(self, flag):
        # Non-bool values are ignored with a printed warning.
        if not isinstance(flag, bool):
            print("WARNING: gradient_merge should have value of bool type")
            return
        self.strategy.gradient_merge = flag

    @property
    def gradient_merge_configs(self):
        """
        Gradient merge configurations, given as a dict of key-value pairs.

        **Note**:
            k_steps(int): the update period of the parameters.

            avg(bool): whether to average the gradients of each mini-batch, the default value is `True`

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.gradient_merge = True
            strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
        """
        return get_msg_dict(self.strategy.gradient_merge_configs)

    @gradient_merge_configs.setter
    @is_strict_auto
    def gradient_merge_configs(self, configs):
        # Pass this property's own name to check_configs_key, matching every
        # other *_configs setter, so key-check diagnostics name the option the
        # user actually set.
        check_configs_key(self.strategy.gradient_merge_configs, configs,
                          "gradient_merge_configs")
        assign_configs_value(self.strategy.gradient_merge_configs, configs)

    @property
    def lars(self):
        """
        Whether the LARS optimizer is used. LARS addresses convergence problems
        when the global batch size is larger than 8k. For more details, please
        refer to
        [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888).

        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lars = True # by default this is false
        """
        return self.strategy.lars

    @lars.setter
    @is_strict_auto
    def lars(self, flag):
        # Non-bool values are ignored with a printed warning.
        if not isinstance(flag, bool):
            print("WARNING: lars should have value of bool type")
            return
        self.strategy.lars = flag

    @property
    def lars_configs(self):
        """
        Set Lars training configurations.

        **Notes**:
        **lars_coeff (float)**: trust ratio in lars formula.
        **lars_weight_decay** (float): weight decay coefficient in lars formula.
        **epsilon (float)**: used to avoid a potential division-by-zero
        when computing the local lr;
        **exclude_from_weight_decay ([string])**: a list of name strings of layers which
        will be excluded from weight decay in lars formula.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lars = True
            strategy.lars_configs = {
                        "lars_coeff": 0.01,
                        "lars_weight_decay": 0.0005,
                        "epsilon": 0,
                        "exclude_from_weight_decay": ['batch_norm', '.b_0']
                    }
        """
        return get_msg_dict(self.strategy.lars_configs)

    @lars_configs.setter
    @is_strict_auto
    def lars_configs(self, configs):
        # Check keys first, then copy values into the lars_configs message.
        target = self.strategy.lars_configs
        check_configs_key(target, configs, "lars_configs")
        assign_configs_value(target, configs)

    @property
    def lamb(self):
        """
        Whether the LAMB optimizer is used. LAMB addresses convergence problems
        in large-batch training, especially for attention-based models such as
        BERT. For more details, please refer to
        [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962).

        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lamb = True # by default this is false
        """
        return self.strategy.lamb

    @lamb.setter
    @is_strict_auto
    def lamb(self, flag):
        # Non-bool values are ignored with a printed warning.
        if not isinstance(flag, bool):
            print("WARNING: lamb should have value of bool type")
            return
        self.strategy.lamb = flag

    @property
    def lamb_configs(self):
        """
        Set Lamb training configurations.

        **Notes**:
        **lamb_weight_decay** (float): weight decay coefficient in lamb formula.
        **exclude_from_weight_decay ([string])**: a list of name strings of layers which
        will be excluded from weight decay in lamb formula.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lamb = True
            strategy.lamb_configs = {
                    'lamb_weight_decay': 0.01,
                    'exclude_from_weight_decay': [],
                }
        """
        return get_msg_dict(self.strategy.lamb_configs)

    @lamb_configs.setter
    @is_strict_auto
    def lamb_configs(self, configs):
        # Check keys first, then copy values into the lamb_configs message.
        check_configs_key(self.strategy.lamb_configs, configs, "lamb_configs")
        assign_configs_value(self.strategy.lamb_configs, configs)

    @property
    def elastic(self):
        """
        Whether distributed training should run on clusters with elastic
        resources. Currently this configuration has no effect.
        """
        return self.strategy.elastic

    @elastic.setter
    @is_strict_auto
    def elastic(self, flag):
        # Non-bool values are ignored with a printed warning.
        if not isinstance(flag, bool):
            print("WARNING: elastic should have value of bool type")
            return
        self.strategy.elastic = flag

    @property
    def auto(self):
        """
        Whether the auto-parallel configuration is used.
        This feature is currently experimental: auto-parallelism only takes
        effect when the user sets no strategy config other than ``auto``.
        See the code example below.
        Default Value: False

        Examples:
          .. code-block:: python

            import paddle
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.auto = True

            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
            optimizer = fleet.distributed_optimizer(optimizer, strategy)
        """
        return self.strategy.auto

    # Intentionally NOT decorated with is_strict_auto: that decorator clears
    # the module-level non_auto_func_called flag, which _is_strict_auto needs
    # to still be True for "auto only" to be detected.
    @auto.setter
    def auto(self, flag):
        if not isinstance(flag, bool):
            print("WARNING: auto should have value of bool type")
            return
        self.strategy.auto = flag

    @property
    def cudnn_exhaustive_search(self):
        """
        Whether to pick convolution algorithms by exhaustive search.
        Exhaustive search tries every cuDNN algorithm and selects the fastest.
        The search is time-consuming; the chosen algorithm is cached for the given layer specifications,
        and the search runs again once those specifications (like batch size or feature map size) change.
        Default Value: True

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.cudnn_exhaustive_search = False

            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
            optimizer = fleet.distributed_optimizer(optimizer, strategy)
        """
        return self.strategy.cudnn_exhaustive_search

    @cudnn_exhaustive_search.setter
    @is_strict_auto
    def cudnn_exhaustive_search(self, flag):
        # Non-bool values are ignored with a printed warning.
        if not isinstance(flag, bool):
            print(
                "WARNING: cudnn_exhaustive_search should have value of bool type"
            )
            return
        self.strategy.cudnn_exhaustive_search = flag

    @property
    def conv_workspace_size_limit(self):
        """
        The workspace size limit, in MB, used when choosing cuDNN convolution algorithms.
        cuDNN internally picks the fastest suitable algorithm that fits within this memory limit.
        A larger workspace usually allows faster algorithms but uses
        significantly more memory, so users need to trade off memory against speed.
        Default Value: 4000

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.conv_workspace_size_limit = 1024

            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
            optimizer = fleet.distributed_optimizer(optimizer, strategy)

        """
        return self.strategy.conv_workspace_size_limit

    @conv_workspace_size_limit.setter
    @is_strict_auto
    def conv_workspace_size_limit(self, value):
        # bool is a subclass of int, so True/False would otherwise pass the
        # int check and be stored as a 1/0 MB workspace limit; reject them.
        if isinstance(value, int) and not isinstance(value, bool):
            self.strategy.conv_workspace_size_limit = value
        else:
            print(
                "WARNING: conv_workspace_size_limit should have value of int type"
            )

    @property
    def cudnn_batchnorm_spatial_persistent(self):
        """
        Whether batchnorm uses the CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
        This only has an effect with cuDNN.
        Default Value: True

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.cudnn_batchnorm_spatial_persistent = True

            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
            optimizer = fleet.distributed_optimizer(optimizer, strategy)

        """
        return self.strategy.cudnn_batchnorm_spatial_persistent

    @cudnn_batchnorm_spatial_persistent.setter
    @is_strict_auto
    def cudnn_batchnorm_spatial_persistent(self, flag):
        # Non-bool values are ignored with a printed warning.
        if not isinstance(flag, bool):
            print(
                "WARNING: cudnn_batchnorm_spatial_persistent should have value of bool type"
            )
            return
        self.strategy.cudnn_batchnorm_spatial_persistent = flag

    def _enable_env(self):
        """
        Copy strategy-controlled settings into Paddle's global FLAGS_* values.

        Every value is converted (bool/int) up front. A flag is only written
        when ``core.globals().is_public`` reports it as public; other flags
        are skipped silently.
        """
        strategy = self.strategy
        # (global flag name, converted strategy value) pairs; all conversions
        # happen here, before any flag is written.
        flag_settings = [
            ("FLAGS_cudnn_batchnorm_spatial_persistent",
             bool(strategy.cudnn_batchnorm_spatial_persistent)),
            ("FLAGS_conv_workspace_size_limit",
             int(strategy.conv_workspace_size_limit)),
            ("FLAGS_cudnn_exhaustive_search",
             bool(strategy.cudnn_exhaustive_search)),
            ("FLAGS_sync_nccl_allreduce", bool(strategy.sync_nccl_allreduce)),
            ("FLAGS_fuse_parameter_memory_size",
             int(strategy.fuse_grad_size_in_MB)),
            ("FLAGS_fuse_parameter_groups_size",
             int(strategy.fuse_grad_size_in_TFLOPS)),
        ]

        for flag_name, flag_value in flag_settings:
            if core.globals().is_public(flag_name):
                core.globals()[flag_name] = flag_value

    def _is_strict_auto(self):
        """
        Return True only if ``auto`` is enabled and no other option was set.

        ``non_auto_func_called`` is the module-level flag that the
        ``is_strict_auto`` decorator sets to False whenever a decorated
        setter runs; it is only read here.
        """
        return bool(self.strategy.auto and non_auto_func_called)

    def __repr__(self):
        """
        Render this strategy as a fixed-width text table.

        Sections, in order:
          * one sub-table for every enabled boolean field ``<name>`` that has
            a companion ``<name>_configs`` message, listing each config field
            (repeated scalar fields get one row per element, ``[]`` if empty);
          * "Environment Flags, Communication Flags": every other top-level
            field, except ``build_strategy``, ``execution_strategy`` and the
            ``*_configs`` messages themselves;
          * "Build Strategy" and "Execution Strategy": every field of those
            nested messages.
        """
        # Detect repeated scalar fields through descriptor constants rather
        # than google.protobuf.pyext._message.RepeatedScalarContainer: that
        # module is specific to the C++ protobuf backend and is missing under
        # the pure-Python / upb backends.
        from google.protobuf.descriptor import FieldDescriptor

        spacing = 2
        max_k = 38
        max_v = 38
        length = max_k + max_v + spacing

        # h1: one centered cell; h2: right-aligned key, gap, centered value.
        h1_format = "    " + "|{{:^{}s}}|\n".format(length)
        h2_format = "    " + "|{{:>{}s}}{}{{:^{}s}}|\n".format(
            max_k, " " * spacing, max_v)

        border = "    +" + "=" * length + "+"
        line = "    +" + "-" * length + "+"

        draws = border + "\n"
        draws += h1_format.format("")
        draws += h1_format.format("DistributedStrategy Overview")
        draws += h1_format.format("")

        env_draws = line + "\n"
        for f in self.strategy.DESCRIPTOR.fields:
            if "build_strategy" in f.name or "execution_strategy" in f.name:
                continue
            if "_configs" in f.name:
                continue
            value = getattr(self.strategy, f.name)
            if not (isinstance(value, bool) and
                    hasattr(self.strategy, f.name + "_configs")):
                # Fields without a companion *_configs message are listed in
                # the environment/communication section.
                env_draws += h2_format.format(f.name, str(value))
                continue
            if not value:
                # Disabled switches that own a *_configs message are omitted.
                continue

            draws += border + "\n"
            draws += h1_format.format(
                "{}=True <-> {}_configs".format(f.name, f.name))
            draws += line + "\n"
            my_configs = getattr(self.strategy, f.name + "_configs")
            for ff in my_configs.DESCRIPTOR.fields:
                cfg_value = getattr(my_configs, ff.name)
                is_repeated_scalar = (
                    ff.label == FieldDescriptor.LABEL_REPEATED and
                    ff.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE)
                if not is_repeated_scalar:
                    draws += h2_format.format(ff.name, str(cfg_value))
                elif len(cfg_value) == 0:
                    # Keep empty lists visible instead of dropping the key.
                    draws += h2_format.format(ff.name, "[]")
                else:
                    # One row per element; only the first row shows the key.
                    for i, v in enumerate(cfg_value):
                        draws += h2_format.format(ff.name if i == 0 else "",
                                                  str(v))

        result_res = draws + border + "\n" + h1_format.format(
            "Environment Flags, Communication Flags")
        result_res += env_draws

        build_strategy_str = border + "\n"
        build_strategy_str += h1_format.format("Build Strategy")
        build_strategy_str += line + "\n"
        for f in self.strategy.build_strategy.DESCRIPTOR.fields:
            build_strategy_str += h2_format.format(
                f.name, str(getattr(self.strategy.build_strategy, f.name)))
        build_strategy_str += border + "\n"

        execution_strategy_str = h1_format.format("Execution Strategy")
        execution_strategy_str += line + "\n"
        for f in self.strategy.execution_strategy.DESCRIPTOR.fields:
            execution_strategy_str += h2_format.format(
                f.name, str(getattr(self.strategy.execution_strategy, f.name)))
        execution_strategy_str += border + "\n"

        result_res += build_strategy_str + execution_strategy_str
        return result_res