#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed.fleet.proto import distributed_strategy_pb2
from paddle.fluid.framework import Variable, set_flags, core
import google.protobuf.text_format

__all__ = ["DistributedStrategy"]

def get_msg_dict(msg):
    res_dict = {}
    fields = msg.DESCRIPTOR.fields
    for f in fields:
        res_dict[f.name] = getattr(msg, f.name)
    return res_dict


def assign_configs_value(msg, config):
    fields = msg.DESCRIPTOR.fields
    for key in config:
        for f in fields:
            if key == f.name:
                # repeated fields (label == 3) are extended; optional/required
                # fields (label == 1 or 2) are set directly
                if f.label == 3:
                    getattr(msg, f.name).extend(config[f.name])
                elif f.label == 1 or f.label == 2:
                    setattr(msg, f.name, config[f.name])


def check_configs_key(msg, config, field_name):
    key_list = msg.DESCRIPTOR.fields_by_name.keys()
    for key in config:
        assert key in key_list, "key:{} not in {}".format(key, field_name)

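# A minimal sketch (illustrative only) of how the three helpers above cooperate to
# move settings between a plain Python dict and a protobuf config message:
#
#     configs_msg = distributed_strategy_pb2.DistributedStrategy().a_sync_configs
#     check_configs_key(configs_msg, {"k_steps": 4}, "a_sync_configs")  # validate keys
#     assign_configs_value(configs_msg, {"k_steps": 4})                 # write values
#     assert get_msg_dict(configs_msg)["k_steps"] == 4                  # read them back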

class DistributedJobInfo(object):
    """
    DistributedJobInfo serializes all distributed training information.
    For internal use only: 1) debugging  2) replicating experiments.
    """

    def __init__(self):
        self.job_info = distributed_strategy_pb2.DistributedJobInfo()

    def _set_worker_num(self, worker_num):
        self.job_info.worker_num = worker_num

    def _set_server_num(self, server_num):
        self.job_info.server_num = server_num

    def _set_worker_ips(self, worker_ips):
        self.job_info.worker_ips.extend(worker_ips)

    def _set_server_endpoints(self, server_endpoints):
        self.job_info.server_endpoints.extend(server_endpoints)

    def _set_origin_startup(self, origin_startup_prog):
        self.job_info.origin_startup = str(origin_startup_prog)

    def _set_origin_main(self, origin_main_prog):
        self.job_info.origin_main = str(origin_main_prog)

    def _distributed_main(self, distributed_main_prog):
        self.job_info.distributed_main = str(distributed_main_prog)

    def _optimizer_name(self, optimizer_name):
        self.job_info.optimizer_name = optimizer_name

    def _set_distributed_strategy(self, dist_strategy):
        self.job_info.strategy = dist_strategy

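# A minimal sketch (illustrative only) of populating DistributedJobInfo when
# replicating an experiment; the worker/server addresses below are hypothetical:
#
#     job_info = DistributedJobInfo()
#     job_info._set_worker_num(2)
#     job_info._set_server_num(1)
#     job_info._set_worker_ips(["10.0.0.1", "10.0.0.2"])
#     job_info._set_server_endpoints(["10.0.0.3:6170"])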

class DistributedStrategy(object):
    __lock_attr = False

    def __init__(self):
        """
        DistributedStrategy is the main configuration entry for distributed training of Paddle.
        All of the distributed training configurations can be configured in DistributedStrategy,
        such as automatic mixed precision (AMP), Layer-wise Adaptive Rate Scaling (LARS),
        asynchronous parameter server updates (ASGD), etc.

        DistributedStrategy can be serialized into a protobuf file or deserialized from a protobuf file.

        Users who run local training usually configure BuildStrategy and ExecutionStrategy, and
        DistributedStrategy supports configurations from BuildStrategy and ExecutionStrategy.
        """
        self.strategy = distributed_strategy_pb2.DistributedStrategy()
        self.__lock_attr = True

    def __setattr__(self, key, value):
        if self.__lock_attr and not hasattr(self, key):
            raise TypeError("%s is not an attribute of %s" %
                            (key, self.__class__.__name__))
        object.__setattr__(self, key, value)
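    # Note: attribute assignment is locked once __init__ finishes; assigning a name
    # that is not a known strategy field raises TypeError. Illustrative only:
    #
    #     strategy = DistributedStrategy()
    #     strategy.amp = True           # OK, known strategy field
    #     strategy.amp_typo = True      # raises TypeError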

    def save_to_prototxt(self, output):
        """
        Serialize current DistributedStrategy to string and save to output file

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.dgc = True
            strategy.recompute = True
            strategy.recompute_configs = {"checkpoints": ["x"]}
            strategy.save_to_prototxt("dist_strategy.prototxt")
        """
        with open(output, "w") as fout:
            fout.write(str(self.strategy))

    def load_from_prototxt(self, pb_file):
        """
        Load from prototxt file for DistributedStrategy initialization

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.load_from_prototxt("dist_strategy.prototxt")
        """
        with open(pb_file, 'r') as f:
            self.strategy = google.protobuf.text_format.Merge(
                str(f.read()), self.strategy)

    @property
    def execution_strategy(self):
        """
        Configure ExecutionStrategy for DistributedStrategy

        Examples:
          .. code-block:: python

            import paddle
            exe_strategy = paddle.fluid.ExecutionStrategy()
            exe_strategy.num_threads = 10
            exe_strategy.num_iteration_per_drop_scope = 10
            exe_strategy.num_iteration_per_run = 10

            strategy = paddle.distributed.fleet.DistributedStrategy()
            strategy.execution_strategy = exe_strategy
        """
        execution_strategy = paddle.fluid.ExecutionStrategy()
        fields = self.strategy.execution_strategy.DESCRIPTOR.fields
        for f in fields:
            setattr(execution_strategy, f.name,
                    getattr(self.strategy.execution_strategy, f.name))
        return execution_strategy

    @execution_strategy.setter
    def execution_strategy(self, strategy):
        fields = self.strategy.execution_strategy.DESCRIPTOR.fields
        for f in fields:
            setattr(self.strategy.execution_strategy, f.name,
                    getattr(strategy, f.name))

    @property
    def build_strategy(self):
        """
        Configure BuildStrategy for DistributedStrategy
        Note that the properties of BuildStrategy take effect in DistributedStrategy
        only if they are not distributed-training strategies themselves.

        Examples:
          .. code-block:: python

            import paddle
            build_strategy = paddle.fluid.BuildStrategy()
            build_strategy.enable_sequential_execution = True
            build_strategy.fuse_elewise_add_act_ops = True
            build_strategy.fuse_bn_act_ops = True
            build_strategy.enable_auto_fusion = True
            build_strategy.fuse_relu_depthwise_conv = True
            build_strategy.fuse_broadcast_ops = True
            build_strategy.fuse_all_optimizer_ops = True
            build_strategy.enable_inplace = True

            strategy = paddle.distributed.fleet.DistributedStrategy()
            strategy.build_strategy = build_strategy
        """

        build_strategy = paddle.fluid.BuildStrategy()
        fields = self.strategy.build_strategy.DESCRIPTOR.fields
        for f in fields:
            setattr(build_strategy, f.name,
                    getattr(self.strategy.build_strategy, f.name))
        return build_strategy

    @build_strategy.setter
    def build_strategy(self, strategy):
        fields = self.strategy.build_strategy.DESCRIPTOR.fields
        for f in fields:
            if f.label == 1 or f.label == 2:  # optional and required field
                setattr(self.strategy.build_strategy, f.name,
                        getattr(strategy, f.name))
            elif f.label == 3:  # repeated field
                getattr(self.strategy.build_strategy,
                        f.name).extend(getattr(strategy, f.name))

    @property
    def a_sync(self):
        """
        Indicating whether we are using asynchronous stochastic gradient descent updates
        for training. This property is valid when we are using parameter server training,
        which is implied by setting an appropriate RoleMaker
        Default value: True

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            role_maker = fleet.PaddleCloudRoleMaker()
            fleet.init(role_maker)

            strategy = fleet.DistributedStrategy()
            strategy.a_sync = True  # by default this is True

            # code block for defining loss and local optimizer
            # sgd = fleet.distributed_optimizer(optimizer, strategy)
        """
        return self.strategy.a_sync

    @a_sync.setter
    def a_sync(self, flag):
        if isinstance(flag, bool):
            self.strategy.a_sync = flag
            self.a_sync_configs = {"k_steps": 0}
        else:
            raise ValueError(
                "The type of `flag` is invalid, expected type is bool, but received {}".
                format(type(flag)))

    @property
    def a_sync_configs(self):
        """
        Set a_sync update configurations. In general, asynchronous parameter server
        training has several configurable settings that can be configured through
        a dict.

        **Notes**:
            k_steps(int): number of local optimization updates before communication

            max_merge_var_num(int): maximum number of merged gradients before communication

            send_queue_size(int): the buffer size for worker communication

            independent_recv_thread(bool): whether to use an independent recv thread for communication

            thread_pool_size(int): the size of the thread pool

            send_wait_times(int): waiting times for sending gradients

            runtime_split_send_recv(bool): whether to split Tensors for send and recv at runtime

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            role_maker = fleet.PaddleCloudRoleMaker()
            fleet.init(role_maker)

            strategy = fleet.DistributedStrategy()
            strategy.a_sync = True  # by default this is True
            configs = {"k_steps": 1024, "send_queue_size": 32}
            strategy.a_sync_configs = configs

            # code block for defining loss and local optimizer
            # sgd = fleet.distributed_optimizer(optimizer, strategy)

        """
        return get_msg_dict(self.strategy.a_sync_configs)

    @a_sync_configs.setter
    def a_sync_configs(self, configs):
        check_configs_key(self.strategy.a_sync_configs, configs,
                          "a_sync_configs")
        assign_configs_value(self.strategy.a_sync_configs, configs)

    @property
    def amp(self):
        """
        Indicating whether we are using automatic mixed precision training
        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.amp = True # by default this is false

        """
        return self.strategy.amp

    @amp.setter
    def amp(self, flag):
        if isinstance(flag, bool):
            self.strategy.amp = flag
        else:
            print("WARNING: amp should have value of bool type")

    @property
    def amp_configs(self):
        """
        Set automatic mixed precision training configurations. In general, amp has several configurable
        settings that can be configured through a dict.

        **Notes**:
            init_loss_scaling(float): The initial loss scaling factor. Default 32768.

            use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. Default True.

            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients. Default 1000.

            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients. Default 2.

            incr_ratio(float): The multiplier to use when increasing the loss scaling. Default 2.0.

            decr_ratio(float): The less-than-one multiplier to use when decreasing the loss scaling. Default 0.5.

            custom_white_list(list[str]): Users' custom white list of ops that are always executed in fp16.

            custom_black_list(list[str]): Users' custom black list of ops that are never executed in fp16.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.amp = True
            strategy.amp_configs = {
                "init_loss_scaling": 32768,
                "custom_white_list": ['conv2d']}
        """
        return get_msg_dict(self.strategy.amp_configs)

    @amp_configs.setter
    def amp_configs(self, configs):
        check_configs_key(self.strategy.amp_configs, configs, "amp_configs")
        assign_configs_value(self.strategy.amp_configs, configs)

    @property
    def recompute(self):
        """
        Indicating whether we are using forward recomputation for memory optimization
        Default value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.recompute = True
            # suppose x and y are names of checkpoint tensors for recomputation
            strategy.recompute_configs = {"checkpoints": ["x", "y"]}
        """
        return self.strategy.recompute

    @property
    def sync_nccl_allreduce(self):
        """
        Indicating whether we are using synchronized all reduce in each communication thread.
        We note that system overhead is usually lower when sync_nccl_allreduce = True.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sync_nccl_allreduce = True
        """
        return self.strategy.sync_nccl_allreduce

    @sync_nccl_allreduce.setter
    def sync_nccl_allreduce(self, flag):
        if isinstance(flag, bool):
            self.strategy.sync_nccl_allreduce = flag
        else:
            print("WARNING: sync_nccl_allreduce should have value of bool type")

    @property
    def use_hierarchical_allreduce(self):
        """
        Indicating whether we are using hierarchical allreduce in collective communication.
        Hierarchical allreduce often does allreduce within a certain node group and then does
        allreduce among the leaders of each group.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.use_hierarchical_allreduce = True
        """
        return self.strategy.use_hierarchical_allreduce

    @use_hierarchical_allreduce.setter
    def use_hierarchical_allreduce(self, flag):
        if isinstance(flag, bool):
            self.strategy.use_hierarchical_allreduce = flag
        else:
            print(
                "WARNING: use_hierarchical_allreduce should have value of bool type"
            )

    @property
    def hierarchical_allreduce_inter_nranks(self):
        """
        Number of ranks for low-level node groups in hierarchical allreduce.
        Default value: number of GPU cards on each single GPU machine.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.hierarchical_allreduce_inter_nranks = 8
        """
        return self.strategy.hierarchical_allreduce_inter_nranks

    @hierarchical_allreduce_inter_nranks.setter
    def hierarchical_allreduce_inter_nranks(self, value):
        if isinstance(value, int):
            self.strategy.hierarchical_allreduce_inter_nranks = value
        else:
            print(
                "WARNING: hierarchical_allreduce_inter_nranks should have value of int type"
            )

    @property
    def sync_batch_norm(self):
        """
        Indicating whether we are using sync_batch_norm to do synchronous batch normalization among all training nodes.

        Default value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sync_batch_norm = True
        """

        return self.strategy.sync_batch_norm

    @sync_batch_norm.setter
    def sync_batch_norm(self, flag):
        if isinstance(flag, bool):
            self.strategy.sync_batch_norm = flag
        else:
            print("WARNING: sync_batch_norm should have value of bool type")

    @property
    def fuse_all_reduce_ops(self):
        """
        Indicating whether we are using fuse_all_reduce_ops for gradient fusion during backward phase of training
        Default value: True

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.fuse_all_reduce_ops = False
        """
        return self.strategy.fuse_all_reduce_ops

    @fuse_all_reduce_ops.setter
    def fuse_all_reduce_ops(self, flag):
        if isinstance(flag, bool):
            self.strategy.fuse_all_reduce_ops = flag
        else:
            print("WARNING: fuse_all_reduce_ops should have value of bool type")

    @property
    def fuse_grad_size_in_MB(self):
        """
        Specifying the size of gradients to fuse, in megabytes (MB)

        Default value: 32

        Examples:
          .. code-block:: python
        
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.fuse_grad_size_in_MB = 50
        """
        return self.strategy.fuse_grad_size_in_MB

    @fuse_grad_size_in_MB.setter
    def fuse_grad_size_in_MB(self, value):
        if isinstance(value, int):
            self.strategy.fuse_grad_size_in_MB = value
        else:
            print("WARNING: fuse_grad_size_in_MB should have value of int type")

    @property
    def _fuse_grad_size_in_TFLOPS(self):
        return self.strategy.fuse_grad_size_in_TFLOPS

    @_fuse_grad_size_in_TFLOPS.setter
    def _fuse_grad_size_in_TFLOPS(self, value):
        if isinstance(value, float):
            self.strategy.fuse_grad_size_in_TFLOPS = value
        else:
            print(
                "WARNING: fuse_grad_size_in_TFLOPS should have value of float type"
            )

    @property
    def nccl_comm_num(self):
        """
        Specifying the number of NCCL communicators

        Default value: 1

        Examples:
          .. code-block:: python
        
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.nccl_comm_num = 2
        """

        return self.strategy.nccl_comm_num

    @nccl_comm_num.setter
    def nccl_comm_num(self, value):
        if isinstance(value, int):
            self.strategy.nccl_comm_num = value
        else:
            print("WARNING: nccl_comm_num should have value of int type")

    @recompute.setter
    def recompute(self, flag):
        if isinstance(flag, bool):
            self.strategy.recompute = flag
        else:
            print("WARNING: recompute should have value of bool type")

    @property
    def recompute_configs(self):
        """
        Set recompute configurations. In general, the recompute strategy of the current
        implementation requires manually assigned checkpoints.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.recompute = True
            strategy.recompute_configs = {"checkpoints": ["x", "y"]}

        """
        return get_msg_dict(self.strategy.recompute_configs)

    @recompute_configs.setter
    def recompute_configs(self, configs):
        check_configs_key(self.strategy.recompute_configs, configs,
                          "checkpoint_configs")
        assign_configs_value(self.strategy.recompute_configs, configs)

    @property
    def pipeline(self):
        """
        Indicating whether we are using pipeline parallelism for distributed training.
        The current implementation mainly focuses on pipeline parallelism within a single
        GPU machine and data parallelism across GPU machines. The pipeline information
        is indicated through device_guard information in the user-defined program.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.pipeline = True

        """
        return self.strategy.pipeline

    @pipeline.setter
    def pipeline(self, flag):
        if isinstance(flag, bool):
            self.strategy.pipeline = flag
        else:
            print("WARNING: pipeline should have value of bool type")

    @property
    def pipeline_configs(self):
        """
        Set pipeline parallelism configurations. In pipeline parallelism,
        different parts of neural networks are running on different GPUs.
        There is a Tensor queue buffer between each pair of neighboring GPUs
        that is responsible for synchronizing hidden Tensor results between
        GPUs. Pipeline parallelism consists of several producer-consumer style
        hardware pairs, such as GPU-GPU, CPU-GPU, GPU-XPU. The best way to speed up
        pipeline parallelism is to make the size of Tensors in the Tensor queue smaller,
        so that we will have a faster producer for downstream consumers.

        **Notes**:
            **Detailed arguments for pipeline_configs**

            **micro_batch**: the number of small batches in each user-defined batch

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.pipeline = True
            strategy.pipeline_configs = {"micro_batch": 12}

        """

        return get_msg_dict(self.strategy.pipeline_configs)

    @pipeline_configs.setter
    def pipeline_configs(self, configs):
        check_configs_key(self.strategy.pipeline_configs, configs,
                          "pipeline_configs")
        assign_configs_value(self.strategy.pipeline_configs, configs)

    @property
    def localsgd(self):
        """
        Indicating whether we are using Local SGD training. Default Value: False
        For more details, please refer to
        `Don't Use Large Mini-Batches, Use Local SGD <https://arxiv.org/pdf/1808.07217.pdf>`_.


        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.localsgd = True # by default this is false

        """
        return self.strategy.localsgd

    @localsgd.setter
    def localsgd(self, flag):
        if isinstance(flag, bool):
            self.strategy.localsgd = flag
        else:
            print("WARNING: localsgd should have value of bool type")

    @property
    def localsgd_configs(self):
        """
        Set LocalSGD training configurations. LocalSGD has a configurable
        setting that can be configured through a dict.

        **Notes**:
            k_steps(int): The local steps for training before parameter synchronization. Default 1.

            If strategy.auto is set True, the local steps will be calculated automatically during training.
            The algorithm is referenced in this paper:
            `Adaptive Communication Strategies to Achieve the Best Error-Runtime Trade-off in Local-Update SGD <https://arxiv.org/pdf/1810.08313.pdf>`_.
            In this case, k_steps indicates the initial local steps, which is suggested to be set to 1.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.localsgd = True
            strategy.localsgd_configs = {"k_steps": 4}
        """

        return get_msg_dict(self.strategy.localsgd_configs)

    @localsgd_configs.setter
    def localsgd_configs(self, configs):
        check_configs_key(self.strategy.localsgd_configs, configs,
                          "localsgd_configs")
        assign_configs_value(self.strategy.localsgd_configs, configs)

    @property
    def dgc(self):
        """
        Indicating whether we are using Deep Gradient Compression training. For more details, please refer to
        [Deep Gradient Compression](https://arxiv.org/abs/1712.01887).

        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.dgc = True # by default this is false

        """
        return self.strategy.dgc

    @dgc.setter
    def dgc(self, flag):
        if isinstance(flag, bool):
            self.strategy.dgc = flag
        else:
            print("WARNING: dgc should have value of bool type")

    @property
    def dgc_configs(self):
        """
        Set Deep Gradient Compression training configurations. In general, dgc has several configurable
        settings that can be configured through a dict.

        **Notes**:
            rampup_begin_step(int): The beginning step from which gradient compression is implemented. Default 0.

            rampup_step(int): Time steps used in sparsity warm-up periods. Default is 1. \
                    For example, if the sparsity is [0.75, 0.9375, 0.984375, 0.996, 0.999], and the rampup_step is 100, \
                    it will use 0.75 at 0~19 steps, and 0.9375 at 20~39 steps, and so on. When the sparsity array \
                    ends, it will use 0.999 thereafter.

            sparsity(list[float]): Keep the top important elements of the gradient tensor; the kept ratio is (1 - sparsity). \
                    Default is [0.999]. For example, if the sparsity is [0.99, 0.999], the top [1%, 0.1%] important \
                    elements will be transmitted.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.dgc = True
            strategy.dgc_configs = {"rampup_begin_step": 1252}
        """
        return get_msg_dict(self.strategy.dgc_configs)

    @dgc_configs.setter
    def dgc_configs(self, configs):
        check_configs_key(self.strategy.dgc_configs, configs, "dgc_configs")
        assign_configs_value(self.strategy.dgc_configs, configs)

    @property
    def gradient_merge(self):
        """
        Gradient Merge, also called Gradient Accumulation,
        is a strategy for large batch training. With this strategy,
        the model parameters will not be updated until the user-defined
        number of steps is reached.
        For each step, the forward network and the backward network
        will run to calculate the gradient of model parameters.
        For every k steps, the optimization network will run,
        applying a specific optimization method (such as SGD, Adam)
        to model parameters.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.gradient_merge = True
            strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
        """
        return self.strategy.gradient_merge

    @gradient_merge.setter
    def gradient_merge(self, flag):
        if isinstance(flag, bool):
            self.strategy.gradient_merge = flag
        else:
            print("WARNING: gradient_merge should have value of bool type")

    @property
    def gradient_merge_configs(self):
        """
        Set Gradient Merge configurations; the key-value configs are listed below.

        **Note**:
            k_steps(int): the update period of the parameters.

            avg(bool): whether to average the gradients of each mini-batch, the default value is `True`

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.gradient_merge = True
            strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
        """
        return get_msg_dict(self.strategy.gradient_merge_configs)

    @gradient_merge_configs.setter
    def gradient_merge_configs(self, configs):
        check_configs_key(self.strategy.gradient_merge_configs, configs,
                          "gradient_configs")
        assign_configs_value(self.strategy.gradient_merge_configs, configs)

    @property
    def lars(self):
        """
        Set lars configurations. lars is used to deal with the convergence problems when the global
        batch size is larger than 8k. For more details, please refer to
        [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888).

        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lars = True # by default this is false
        """
        return self.strategy.lars

    @lars.setter
    def lars(self, flag):
        if isinstance(flag, bool):
            self.strategy.lars = flag
        else:
            print("WARNING: lars should have value of bool type")

    @property
    def lars_configs(self):
        """
        Set Lars training configurations.

        **Notes**:
        **lars_coeff (float)**: trust ratio in lars formula.
        **lars_weight_decay** (float): weight decay coefficient in lars formula.
        **epsilon (float)**: argument is used to avoid potential division-by-zero
        when computing the local lr;
        **exclude_from_weight_decay ([string])**: is a list of name strings of layers which
        will be excluded from weight decay in the lars formula.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lars = True
            strategy.lars_configs = {
                        "lars_coeff": 0.01,
                        "lars_weight_decay": 0.0005,
                        "epsilon": 0,
                        "exclude_from_weight_decay": ['batch_norm', '.b_0']
                    }
        """
        return get_msg_dict(self.strategy.lars_configs)

    @lars_configs.setter
    def lars_configs(self, configs):
        check_configs_key(self.strategy.lars_configs, configs, "lars_configs")
        assign_configs_value(self.strategy.lars_configs, configs)

    @property
    def lamb(self):
        """
        Set lamb configurations. lamb is used to deal with the convergence problems for large
        batch size training, especially for attention-related models like BERT. For more details,
        please refer to
        [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962).

        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lamb = True # by default this is false
        """

        return self.strategy.lamb

    @lamb.setter
    def lamb(self, flag):
        if isinstance(flag, bool):
            self.strategy.lamb = flag
        else:
            print("WARNING: lamb should have value of bool type")

    @property
    def lamb_configs(self):
        """
        Set LAMB training configurations.

        **Notes**:
        **lamb_weight_decay** (float): weight decay coefficient in lamb formula.
        **exclude_from_weight_decay ([string])**: is a list of name strings of layers which
        will be excluded from weight decay in the lamb formula.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lamb = True
            strategy.lamb_configs = {
                    'lamb_weight_decay': 0.01,
                    'exclude_from_weight_decay': [],
                }
        """
        return get_msg_dict(self.strategy.lamb_configs)

    @lamb_configs.setter
    def lamb_configs(self, configs):
        check_configs_key(self.strategy.lamb_configs, configs, "lamb_configs")
        assign_configs_value(self.strategy.lamb_configs, configs)

    @property
    def elastic(self):
        return self.strategy.elastic

    @elastic.setter
    def elastic(self, flag):
        if isinstance(flag, bool):
            self.strategy.elastic = flag
        else:
            print("WARNING: elastic should have value of bool type")

    @property
    def auto(self):
        return self.strategy.auto

    @auto.setter
    def auto(self, flag):
        if isinstance(flag, bool):
            self.strategy.auto = flag
        else:
            print("WARNING: auto should have value of bool type")
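    # A minimal sketch (illustrative only) of combining `auto` with LocalSGD, based on
    # the localsgd_configs notes above (local steps are then tuned automatically):
    #
    #     strategy = DistributedStrategy()
    #     strategy.auto = True
    #     strategy.localsgd = True
    #     strategy.localsgd_configs = {"k_steps": 1}  # suggested initial local steps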

    @property
    def cudnn_exhaustive_search(self):
        return self.strategy.cudnn_exhaustive_search

    @cudnn_exhaustive_search.setter
    def cudnn_exhaustive_search(self, flag):
        if isinstance(flag, bool):
            self.strategy.cudnn_exhaustive_search = flag
        else:
            print(
                "WARNING: cudnn_exhaustive_search should have value of bool type"
            )

    @property
    def conv_workspace_size_limit(self):
        return self.strategy.conv_workspace_size_limit

    @conv_workspace_size_limit.setter
    def conv_workspace_size_limit(self, value):
        if isinstance(value, int):
            self.strategy.conv_workspace_size_limit = value
        else:
            print(
                "WARNING: conv_workspace_size_limit should have value of int type"
            )

    @property
    def cudnn_batchnorm_spatial_persistent(self):
        return self.strategy.cudnn_batchnorm_spatial_persistent

    @cudnn_batchnorm_spatial_persistent.setter
    def cudnn_batchnorm_spatial_persistent(self, flag):
        if isinstance(flag, bool):
            self.strategy.cudnn_batchnorm_spatial_persistent = flag
        else:
            print(
                "WARNING: cudnn_batchnorm_spatial_persistent should have value of bool type"
            )
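    # A minimal sketch (illustrative only) of the cuDNN-related knobs above; the
    # values are hypothetical and are mirrored into global FLAGS by _enable_env():
    #
    #     strategy = DistributedStrategy()
    #     strategy.cudnn_exhaustive_search = True
    #     strategy.conv_workspace_size_limit = 1024
    #     strategy.cudnn_batchnorm_spatial_persistent = True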

    def _enable_env(self):
        # Mirror the strategy values below into the corresponding global FLAGS,
        # but only for flags that core.globals() reports as public (writable).
        strategy = self.strategy
        keys = [
            "FLAGS_cudnn_batchnorm_spatial_persistent",
            "FLAGS_conv_workspace_size_limit",
            "FLAGS_cudnn_exhaustive_search",
            "FLAGS_sync_nccl_allreduce",
            "FLAGS_fuse_parameter_memory_size",
            "FLAGS_fuse_parameter_groups_size",
        ]
        values = [
            bool(strategy.cudnn_batchnorm_spatial_persistent),
            int(strategy.conv_workspace_size_limit),
            bool(strategy.cudnn_exhaustive_search),
            bool(strategy.sync_nccl_allreduce),
            int(strategy.fuse_grad_size_in_MB),
            int(strategy.fuse_grad_size_in_TFLOPS),
        ]

        for i, key in enumerate(keys):
            if core.globals().is_public(key):
                core.globals()[key] = values[i]

    def __repr__(self):
        # Print the default value of every strategy field, then return the full
        # textual form of the underlying protobuf message.
        fields = self.strategy.DESCRIPTOR.fields
        for f in fields:
            print("{}: {}".format(f.name, f.default_value))
        return str(self.strategy)
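
# End-to-end usage sketch (illustrative only), combining the pieces documented above:
#
#     import paddle.distributed.fleet as fleet
#
#     strategy = fleet.DistributedStrategy()
#     strategy.amp = True
#     strategy.recompute = True
#     strategy.recompute_configs = {"checkpoints": ["x", "y"]}
#     strategy.save_to_prototxt("dist_strategy.prototxt")  # serialize for later reuse
#     # optimizer = fleet.distributed_optimizer(optimizer, strategy)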