distributed_strategy.py 39.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import paddle
16
from paddle.distributed.fleet.proto import distributed_strategy_pb2
17
from paddle.fluid.framework import Variable, set_flags, core
18
from paddle.fluid.wrapped_decorator import wrap_decorator
19
import google.protobuf.text_format
20

# Public API of this module.
__all__ = ["DistributedStrategy"]

# Module-level flag. It starts as True and is set to False the first time any
# function wrapped with ``is_strict_auto`` is called (i.e. the first time a user
# explicitly assigns one of the DistributedStrategy options below).
# NOTE(review): its readers live outside this file — presumably fleet's
# "auto" mode uses it to detect user-configured strategies; confirm there.
non_auto_func_called = True


def __non_auto_func_called__(func):
    """Wrap ``func`` so that calling it clears ``non_auto_func_called``.

    Args:
        func: the callable to wrap (here: DistributedStrategy property setters).

    Returns:
        A wrapper that sets the module-level ``non_auto_func_called`` to False
        and then forwards all arguments to ``func``, returning its result.
    """

    def __impl__(*args, **kwargs):
        global non_auto_func_called
        non_auto_func_called = False
        return func(*args, **kwargs)

    return __impl__


# Decorator applied to every option setter of DistributedStrategy.
# ``wrap_decorator`` comes from paddle.fluid.wrapped_decorator (not shown here);
# presumably it turns the wrapper factory into a signature-preserving decorator.
is_strict_auto = wrap_decorator(__non_auto_func_called__)


def get_msg_dict(msg):
    """Build a ``{field_name: value}`` mapping for a protobuf message.

    Every field listed in ``msg.DESCRIPTOR.fields`` becomes one entry, in
    declaration order. Values are fetched with ``getattr`` and are not copied,
    so a repeated field is returned as the message's own container object.

    Args:
        msg: a protobuf message (anything exposing ``DESCRIPTOR.fields``).

    Returns:
        dict: field name -> current value.
    """
    return {
        descriptor.name: getattr(msg, descriptor.name)
        for descriptor in msg.DESCRIPTOR.fields
    }


def assign_configs_value(msg, config):
    """Copy the entries of ``config`` into the matching fields of ``msg``.

    Args:
        msg: a protobuf message whose ``DESCRIPTOR.fields`` lists its fields.
        config (dict): ``{field_name: value}`` pairs. Keys that are not fields
            of ``msg`` are silently ignored here; call ``check_configs_key``
            first to reject them.

    Optional/required (scalar) fields are assigned with ``setattr``. Repeated
    fields are *replaced* by the given sequence: the existing contents are
    cleared before extending. As a result, assigning the same config twice
    does not duplicate list entries, which matches how scalar fields behave.
    """
    # protobuf FieldDescriptor labels: LABEL_OPTIONAL = 1,
    # LABEL_REQUIRED = 2, LABEL_REPEATED = 3.
    fields_by_name = {f.name: f for f in msg.DESCRIPTOR.fields}
    for key in config:
        field = fields_by_name.get(key)
        if field is None:
            continue
        if field.label == 3:
            repeated = getattr(msg, field.name)
            del repeated[:]
            repeated.extend(config[field.name])
        elif field.label == 1 or field.label == 2:
            setattr(msg, field.name, config[field.name])


def check_configs_key(msg, config, field_name):
    """Validate that every key of ``config`` names a field of ``msg``.

    Args:
        msg: a protobuf message whose ``DESCRIPTOR.fields_by_name`` holds the
            accepted keys.
        config (dict): user-supplied configuration to validate.
        field_name (str): name of the strategy option being configured; used
            only in the error message.

    Raises:
        AssertionError: if a key of ``config`` is not a field of ``msg``.
            An explicit ``raise`` is used instead of ``assert`` so the check
            still runs under ``python -O``. The exception type is unchanged
            for callers that already catch ``AssertionError``.
    """
    key_list = msg.DESCRIPTOR.fields_by_name.keys()
    for key in config:
        if key not in key_list:
            raise AssertionError("key:{} not in {}".format(key, field_name))


class DistributedJobInfo(object):
    """
    Serializable record of one distributed training job.

    Collects the worker/server layout, the original and distributed programs
    (as strings), the optimizer name and the DistributedStrategy into a
    ``distributed_strategy_pb2.DistributedJobInfo`` message.
    For internal use only: 1) debugging, 2) replicating experiments.
    """

    def __init__(self):
        self.job_info = distributed_strategy_pb2.DistributedJobInfo()

    def _set_worker_num(self, worker_num):
        self.job_info.worker_num = worker_num

    def _set_server_num(self, server_num):
        self.job_info.server_num = server_num

    def _set_worker_ips(self, worker_ips):
        # Replace (rather than append to) any previously recorded IPs.
        del self.job_info.worker_ips[:]
        self.job_info.worker_ips.extend(worker_ips)

    def _set_server_endpoints(self, server_endpoints):
        # Replace (rather than append to) any previously recorded endpoints.
        del self.job_info.server_endpoints[:]
        self.job_info.server_endpoints.extend(server_endpoints)

    def _set_origin_startup(self, origin_startup_prog):
        # Programs are stored by their string representation.
        self.job_info.origin_startup = str(origin_startup_prog)

    def _set_origin_main(self, origin_main_prog):
        self.job_info.origin_main = str(origin_main_prog)

    def _distributed_main(self, distributed_main_prog):
        self.job_info.distributed_main = str(distributed_main_prog)

    def _optimizer_name(self, optimizer_name):
        self.job_info.optimizer_name = optimizer_name

    def _set_distributed_strategy(self, dist_strategy):
        """Record ``dist_strategy`` in the job info.

        Args:
            dist_strategy: either the raw ``DistributedStrategy`` protobuf
                message or a ``DistributedStrategy`` wrapper, which exposes
                that message as its ``.strategy`` attribute.
        """
        # NOTE(review): ``strategy`` is expected to be a message-typed field of
        # the DistributedJobInfo proto. Protobuf forbids assigning such fields
        # directly (it raises AttributeError), so the value is copied in with
        # CopyFrom instead.
        pb_strategy = getattr(dist_strategy, "strategy", dist_strategy)
        self.job_info.strategy.CopyFrom(pb_strategy)


class DistributedStrategy(object):
101 102
    __lock_attr = False

103
    def __init__(self):
104 105 106 107 108 109 110 111 112 113 114 115
        """
        DistributedStrategy is the main configuration entry for distributed training of Paddle.
        All of the distributed training configurations can be configured in DistributedStrategy,
        such as automatic mixed precision (AMP), Layer-wise Adaptive Rate Scaling (LARS), 
        asynchronous update parameter server(ASGD), etc.
        
        DistributedStrategy can be serialized into protobuf file or deserialized from protobuf file

        Users who run local training usually configure BuildStrategy and ExecutionStrategy, and 
        DistributedStrategy supports configurations from BuildStrategy and ExecutionStrategy

        """
116
        self.strategy = distributed_strategy_pb2.DistributedStrategy()
117 118 119 120 121 122 123
        self.__lock_attr = True

    def __setattr__(self, key, value):
        """Set ``key`` to ``value``, rejecting unknown names once locked.

        Before ``__init__`` finishes, any attribute may be set. After that,
        only names for which ``hasattr(self, key)`` is true are accepted, so a
        typo such as ``strategy.ampp = True`` fails loudly instead of silently
        creating a new attribute. Note that ``hasattr`` evaluates property
        getters.

        Raises:
            TypeError: if the instance is locked and ``key`` does not exist.
        """
        if self.__lock_attr and not hasattr(self, key):
            raise TypeError("%s is not a attribute of %s" %
                            (key, self.__class__.__name__))
        object.__setattr__(self, key, value)

    def save_to_prototxt(self, output):
        """
        Serialize current DistributedStrategy to string and save to output file

        Args:
            output (str): path of the file to write; an existing file is overwritten.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.dgc = True
            strategy.recompute = True
            strategy.recompute_configs = {"checkpoints": ["x"]}
            strategy.save_to_prototxt("dist_strategy.prototxt")
        """
        # str() of a protobuf message is its text format, which
        # load_from_prototxt parses back.
        with open(output, "w") as fout:
            fout.write(str(self.strategy))

    def load_from_prototxt(self, pb_file):
        """
        Load from prototxt file for DistributedStrategy initialization

        The file is parsed with ``google.protobuf.text_format.Merge`` into the
        current strategy message. Non-repeated fields present in the file
        replace the current values, repeated fields are appended to, and
        fields absent from the file keep their current values.

        Args:
            pb_file (str): path of a text-format protobuf file, e.g. one
                written by ``save_to_prototxt``.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.load_from_prototxt("dist_strategy.prototxt")
        """
        with open(pb_file, 'r') as f:
            self.strategy = google.protobuf.text_format.Merge(
                f.read(), self.strategy)

    @property
    def execution_strategy(self):
        """
        Configure ExecutionStrategy for DistributedStrategy

        Reading this property builds a new ``paddle.fluid.ExecutionStrategy``
        filled from the stored fields. Modifying the returned object does not
        change this DistributedStrategy; assign it back to apply changes.

        Examples:
          .. code-block:: python

            import paddle
            exe_strategy = paddle.fluid.ExecutionStrategy()
            exe_strategy.num_threads = 10
            exe_strategy.num_iteration_per_drop_scope = 10
            exe_strategy.num_iteration_per_run = 10

            strategy = paddle.distributed.fleet.DistributedStrategy()
            strategy.execution_strategy = exe_strategy
        """
        execution_strategy = paddle.fluid.ExecutionStrategy()
        fields = self.strategy.execution_strategy.DESCRIPTOR.fields
        for f in fields:
            setattr(execution_strategy, f.name,
                    getattr(self.strategy.execution_strategy, f.name))
        return execution_strategy

    @execution_strategy.setter
    @is_strict_auto
    def execution_strategy(self, strategy):
        # Copy every field declared in the protobuf ExecutionStrategy message
        # from the given paddle.fluid.ExecutionStrategy object.
        fields = self.strategy.execution_strategy.DESCRIPTOR.fields
        for f in fields:
            setattr(self.strategy.execution_strategy, f.name,
                    getattr(strategy, f.name))

    @property
    def build_strategy(self):
        """
        Configure BuildStrategy for DistributedStrategy
        Note that the properties of BuildStrategy are valid in DistributedStrategy
        only if the property is a non-distributed strategy.

        Reading this property builds a new ``paddle.fluid.BuildStrategy``
        filled from the stored fields. Modifying the returned object does not
        change this DistributedStrategy; assign it back to apply changes.

        Examples:
          .. code-block:: python

            import paddle
            build_strategy = paddle.fluid.BuildStrategy()
            build_strategy.enable_sequential_execution = True
            build_strategy.fuse_elewise_add_act_ops = True
            build_strategy.fuse_bn_act_ops = True
            build_strategy.enable_auto_fusion = True
            build_strategy.fuse_relu_depthwise_conv = True
            build_strategy.fuse_broadcast_ops = True
            build_strategy.fuse_all_optimizer_ops = True
            build_strategy.enable_inplace = True

            strategy = paddle.distributed.fleet.DistributedStrategy()
            strategy.build_strategy = build_strategy
        """

        build_strategy = paddle.fluid.BuildStrategy()
        fields = self.strategy.build_strategy.DESCRIPTOR.fields
        for f in fields:
            setattr(build_strategy, f.name,
                    getattr(self.strategy.build_strategy, f.name))
        return build_strategy

    @build_strategy.setter
    @is_strict_auto
    def build_strategy(self, strategy):
        fields = self.strategy.build_strategy.DESCRIPTOR.fields
        for f in fields:
            if f.label == 1 or f.label == 2:  # optional and required field
                setattr(self.strategy.build_strategy, f.name,
                        getattr(strategy, f.name))
            elif f.label == 3:  # repeated field
                # Replace rather than append, so assigning build_strategy
                # twice does not duplicate entries.
                repeated = getattr(self.strategy.build_strategy, f.name)
                del repeated[:]
                repeated.extend(getattr(strategy, f.name))

    @property
    def a_sync(self):
        """
        Indicating whether we are using asynchronous stochastic gradient descent updates
        for training. This property is valid when we are using parameter server training,
        which is implied by setting an appropriate RoleMaker
        Default value: True

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            role_maker = fleet.PaddleCloudRoleMaker()
            fleet.init(role_maker)

            strategy = fleet.DistributedStrategy()
            strategy.a_sync = True  # by default this is True

            # code block for defining loss and local optimizer
            # sgd = fleet.distributed_optimizer(optimizer, strategy)
        """
        return self.strategy.a_sync

    @a_sync.setter
    @is_strict_auto
    def a_sync(self, flag):
        # Unlike most boolean options in this class, a non-bool value is an
        # error here rather than a printed warning.
        if isinstance(flag, bool):
            self.strategy.a_sync = flag
            # Setting a_sync also resets a_sync_configs.k_steps to 0.
            self.a_sync_configs = {"k_steps": 0}
        else:
            # "%" (not str.format) so the received type actually appears in
            # the message; str.format ignores "%s" placeholders.
            raise ValueError(
                "The type of `flag` is invalid, expected type is bool, but received %s"
                % type(flag))

    @property
    def a_sync_configs(self):
        """
        Set a_sync update configurations. In general, asynchronous parameter server
        training has several configurable settings that can be configured through
        a dict.

        **Notes**:
            k_steps(int): number of local optimization updates before communication

            max_merge_var_num(int): maximum number of merged gradients before communication

            send_queue_size(int): a buffer size of worker communication

            independent_recv_thread(bool): if we are using independent recv thread for communication

            thread_pool_size(int): number of thread pool

            send_wait_times(int): waiting time for sending gradients

            runtime_split_send_recv(bool): if we are using Tensor split for send and recv during runtime

        Reading returns a new dict of the current values. Assigning a dict
        updates only the keys it contains; unknown keys raise AssertionError.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            role_maker = fleet.PaddleCloudRoleMaker()
            fleet.init(role_maker)

            strategy = fleet.DistributedStrategy()
            strategy.a_sync = True  # by default this is True
            configs = {"k_steps": 1024, "send_queue_size": 32}
            strategy.a_sync_configs = configs

            # code block for defining loss and local optimizer
            # sgd = fleet.distributed_optimizer(optimizer, strategy)

        """
        return get_msg_dict(self.strategy.a_sync_configs)

    @a_sync_configs.setter
    @is_strict_auto
    def a_sync_configs(self, configs):
        # Reject unknown keys first, then copy the values into the message.
        check_configs_key(self.strategy.a_sync_configs, configs,
                          "a_sync_configs")
        assign_configs_value(self.strategy.a_sync_configs, configs)

    @property
    def amp(self):
        """
        Indicating whether we are using automatic mixed precision training
        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.amp = True # by default this is false

        """
        return self.strategy.amp

    @amp.setter
    @is_strict_auto
    def amp(self, flag):
        # Non-bool values are ignored with a printed warning.
        if isinstance(flag, bool):
            self.strategy.amp = flag
        else:
            print("WARNING: amp should have value of bool type")

    @property
    def amp_configs(self):
        """
        Set automatic mixed precision training configurations. In general, amp has several configurable
        settings that can be configured through a dict.

        **Notes**:
            init_loss_scaling(float): The initial loss scaling factor. Default 32768.

            use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. Default True.

            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients. Default 1000.

            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients. Default 2.

            incr_ratio(float): The multiplier to use when increasing the loss scaling. Default 2.0.

            decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling. Default 0.5.

            custom_white_list(list[str]): Users' custom white list of ops that are always executed in fp16.

            custom_black_list(list[str]): Users' custom black list of ops that are never executed in fp16.

        Reading returns a new dict of the current values. Assigning a dict
        updates only the keys it contains; unknown keys raise AssertionError.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.amp = True
            strategy.amp_configs = {
                "init_loss_scaling": 32768,
                "custom_white_list": ['conv2d']}
        """
        return get_msg_dict(self.strategy.amp_configs)

    @amp_configs.setter
    @is_strict_auto
    def amp_configs(self, configs):
        # Reject unknown keys first, then copy the values into the message.
        check_configs_key(self.strategy.amp_configs, configs, "amp_configs")
        assign_configs_value(self.strategy.amp_configs, configs)

    @property
    def recompute(self):
        """
        Indicating whether we are using forward recomputation for memory optimization
        Default value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.recompute = True
            # suppose x and y are names of checkpoint tensors for recomputation
            strategy.recompute_configs = {"checkpoints": ["x", "y"]}
        """
        return self.strategy.recompute

    @property
    def sync_nccl_allreduce(self):
        """
        Indicating whether we are using synchronized all reduce in each communication thread
        We note that system overhead is usually lower when sync_nccl_allreduce = True

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sync_nccl_allreduce = True
        """
        return self.strategy.sync_nccl_allreduce

    @sync_nccl_allreduce.setter
    @is_strict_auto
    def sync_nccl_allreduce(self, flag):
        # Non-bool values are ignored with a printed warning.
        if isinstance(flag, bool):
            self.strategy.sync_nccl_allreduce = flag
        else:
            print("WARNING: sync_nccl_allreduce should have value of bool type")

    @property
    def use_hierarchical_allreduce(self):
        """
        Indicating whether we are using hierarchical allreduce in collective communication
        Hierarchical allreduce often does allreduce within a certain node group and then does
        allreduce among the leaders of each group

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.use_hierarchical_allreduce = True
        """
        return self.strategy.use_hierarchical_allreduce

    @use_hierarchical_allreduce.setter
    @is_strict_auto
    def use_hierarchical_allreduce(self, flag):
        # Non-bool values are ignored with a printed warning.
        if isinstance(flag, bool):
            self.strategy.use_hierarchical_allreduce = flag
        else:
            print(
                "WARNING: use_hierarchical_allreduce should have value of bool type"
            )

    @property
    def hierarchical_allreduce_inter_nranks(self):
        """
        Number of ranks for low level node groups in hierarchical allreduce
        Default value: number of GPU cards on each single GPU machine

        Example:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.hierarchical_allreduce_inter_nranks = 8
        """
        return self.strategy.hierarchical_allreduce_inter_nranks

    @hierarchical_allreduce_inter_nranks.setter
    @is_strict_auto
    def hierarchical_allreduce_inter_nranks(self, value):
        # Non-int values are ignored with a printed warning.
        if isinstance(value, int):
            self.strategy.hierarchical_allreduce_inter_nranks = value
        else:
            print(
                "WARNING: hierarchical_allreduce_inter_nranks should have value of int type"
            )

    @property
    def sync_batch_norm(self):
        """
        Indicating whether we are using sync_batch_norm to do synchronous batch normalization among all training nodes.

        Default value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sync_batch_norm = True
        """

        return self.strategy.sync_batch_norm

    @sync_batch_norm.setter
    @is_strict_auto
    def sync_batch_norm(self, flag):
        # Non-bool values are ignored with a printed warning.
        if isinstance(flag, bool):
            self.strategy.sync_batch_norm = flag
        else:
            print("WARNING: sync_batch_norm should have value of bool type")

    @property
    def fuse_all_reduce_ops(self):
        """
        Indicating whether we are using fuse_all_reduce_ops for gradient fusion during backward phase of training
        Default value: True

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.fuse_all_reduce_ops = False
        """
        return self.strategy.fuse_all_reduce_ops

    @fuse_all_reduce_ops.setter
    @is_strict_auto
    def fuse_all_reduce_ops(self, flag):
        # Non-bool values are ignored with a printed warning.
        if isinstance(flag, bool):
            self.strategy.fuse_all_reduce_ops = flag
        else:
            print("WARNING: fuse_all_reduce_ops should have value of bool type")

    @property
    def fuse_grad_size_in_MB(self):
        """
        Specifying the size of gradient to fuse in megabytes

        Default value: 32

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.fuse_grad_size_in_MB = 50
        """
        return self.strategy.fuse_grad_size_in_MB

    @fuse_grad_size_in_MB.setter
    @is_strict_auto
    def fuse_grad_size_in_MB(self, value):
        # Non-int values are ignored with a printed warning.
        if isinstance(value, int):
            self.strategy.fuse_grad_size_in_MB = value
        else:
            print("WARNING: fuse_grad_size_in_MB should have value of int type")

    @property
    def _fuse_grad_size_in_TFLOPS(self):
        # Private option backed by the ``fuse_grad_size_in_TFLOPS`` proto field.
        return self.strategy.fuse_grad_size_in_TFLOPS

    @_fuse_grad_size_in_TFLOPS.setter
    @is_strict_auto
    def _fuse_grad_size_in_TFLOPS(self, value):
        # Only float values are accepted; anything else (including int) is
        # ignored with a printed warning.
        if isinstance(value, float):
            self.strategy.fuse_grad_size_in_TFLOPS = value
        else:
            print(
                "WARNING: fuse_grad_size_in_TFLOPS should have value of float type"
            )

    @property
    def nccl_comm_num(self):
        """
        Specifying the number of NCCL communicators

        Default value: 1

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.nccl_comm_num = 2
        """

        return self.strategy.nccl_comm_num

    @nccl_comm_num.setter
    @is_strict_auto
    def nccl_comm_num(self, value):
        # Non-int values are ignored with a printed warning.
        if isinstance(value, int):
            self.strategy.nccl_comm_num = value
        else:
            print("WARNING: nccl_comm_num should have value of int type")

    @recompute.setter
    @is_strict_auto
    def recompute(self, flag):
        # Setter for the ``recompute`` property defined earlier in the class.
        # Non-bool values are ignored with a printed warning.
        if isinstance(flag, bool):
            self.strategy.recompute = flag
        else:
            print("WARNING: recompute should have value of bool type")

    @property
    def recompute_configs(self):
        """
        Set recompute configurations. In general, the recompute strategy of the current
        implementation requires manually assigned checkpoints.

        Reading returns a new dict of the current values. Assigning a dict
        updates only the keys it contains; unknown keys raise AssertionError.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.recompute = True
            strategy.recompute_configs = {"checkpoints": ["x", "y"]}

        """
        return get_msg_dict(self.strategy.recompute_configs)

    @recompute_configs.setter
    @is_strict_auto
    def recompute_configs(self, configs):
        # Reject unknown keys first (the error names this option), then copy
        # the values into the message.
        check_configs_key(self.strategy.recompute_configs, configs,
                          "recompute_configs")
        assign_configs_value(self.strategy.recompute_configs, configs)

    @property
    def pipeline(self):
        """
        Indicating whether we are using pipeline parallelism for distributed training.
        The current implementation mainly focuses on single GPU machine pipeline parallelism and
        data parallelism across GPU machines. The pipeline information is indicated through
        device_guard information in the user-defined program.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.pipeline = True

        """
        return self.strategy.pipeline

    @pipeline.setter
    @is_strict_auto
    def pipeline(self, flag):
        # Non-bool values are ignored with a printed warning.
        if isinstance(flag, bool):
            self.strategy.pipeline = flag
        else:
            print("WARNING: pipeline should have value of bool type")

    @property
    def pipeline_configs(self):
        """
        Set pipeline parallelism configurations. In pipeline parallelism,
        different parts of neural networks are running on different GPUs.
        There are Tensor queue buffers between each pair of neighboring GPUs
        that are responsible for synchronizing hidden Tensor results between
        GPUs. Pipeline parallelism consists of several producer-consumer style
        hardware pairs, such as GPU-GPU, CPU-GPU, GPU-XPU. The best way to speed up
        pipeline parallelism is to make the size of Tensor in Tensor queue smaller,
        so that we will have a faster producer for downstream consumers.

        **Notes**:
            **Detailed arguments for pipeline_configs**

            **micro_batch**: the number of small batches in each user defined batch

        Reading returns a new dict of the current values. Assigning a dict
        updates only the keys it contains; unknown keys raise AssertionError.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.pipeline = True
            strategy.pipeline_configs = {"micro_batch": 12}

        """

        return get_msg_dict(self.strategy.pipeline_configs)

    @pipeline_configs.setter
    @is_strict_auto
    def pipeline_configs(self, configs):
        # Reject unknown keys first, then copy the values into the message.
        check_configs_key(self.strategy.pipeline_configs, configs,
                          "pipeline_configs")
        assign_configs_value(self.strategy.pipeline_configs, configs)

    @property
    def localsgd(self):
        """
        Indicating whether we are using Local SGD training. Default Value: False
        For more details, please refer to
        `Don't Use Large Mini-Batches, Use Local SGD <https://arxiv.org/pdf/1808.07217.pdf>`_.


        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.localsgd = True # by default this is false

        """
        return self.strategy.localsgd

    @localsgd.setter
    @is_strict_auto
    def localsgd(self, flag):
        # Non-bool values are ignored with a printed warning.
        if isinstance(flag, bool):
            self.strategy.localsgd = flag
        else:
            print("WARNING: localsgd should have value of bool type")

    @property
    def localsgd_configs(self):
        """
        Set LocalSGD training configurations. LocalSGD has a configurable
        setting that can be configured through a dict.

        **Notes**:
            k_steps(int) The local steps for training before parameter synchronization. Default 1.

            If strategy.auto is set True, the local steps will be calculated automatically during training.
            The algorithm is referenced in this paper:
            `Adaptive Communication Strategies to Achieve the Best Error-Runtime Trade-off in Local-Update SGD <https://arxiv.org/pdf/1810.08313.pdf>`_.
            In this case, k_steps indicates the first local steps which is suggested setting to 1.

        Reading returns a new dict of the current values. Assigning a dict
        updates only the keys it contains; unknown keys raise AssertionError.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.localsgd = True
            strategy.localsgd_configs = {"k_steps": 4}
        """

        return get_msg_dict(self.strategy.localsgd_configs)

    @localsgd_configs.setter
    @is_strict_auto
    def localsgd_configs(self, configs):
        # Reject unknown keys first, then copy the values into the message.
        check_configs_key(self.strategy.localsgd_configs, configs,
                          "localsgd_configs")
        assign_configs_value(self.strategy.localsgd_configs, configs)

    @property
    def dgc(self):
        """
        Indicating whether we are using Deep Gradient Compression training. For more details, please refer to
        `Deep Gradient Compression <https://arxiv.org/abs/1712.01887>`_.

        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.dgc = True # by default this is false

        """
        return self.strategy.dgc

    @dgc.setter
    @is_strict_auto
    def dgc(self, flag):
        # Non-bool values are ignored with a printed warning.
        if isinstance(flag, bool):
            self.strategy.dgc = flag
        else:
            print("WARNING: dgc should have value of bool type")

    @property
    def dgc_configs(self):
        """
        Set Deep Gradient Compression training configurations. In general, dgc has several configurable
        settings that can be configured through a dict.

        **Notes**:
            rampup_begin_step(int): The beginning step from which gradient compression is implemented. Default 0.

            rampup_step(int): Time steps used in sparsity warm-up periods. Default is 1.
                For example, if the sparsity is [0.75, 0.9375, 0.984375, 0.996, 0.999], and the rampup_step is 100,
                it will use 0.75 at 0~19 steps, and 0.9375 at 20~39 steps, and so on. When the end of the sparsity
                array is reached, it will use 0.999 from then on.

            sparsity(list[float]): Get top important element from gradient tensor, the ratio is (1 - sparsity).
                Default is [0.999]. For example, if the sparsity is [0.99, 0.999], the top [1%, 0.1%] important
                elements will be transmitted.

        Reading returns a new dict of the current values. Assigning a dict
        updates only the keys it contains; unknown keys raise AssertionError.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.dgc = True
            strategy.dgc_configs = {"rampup_begin_step": 1252}
        """
        return get_msg_dict(self.strategy.dgc_configs)

    @dgc_configs.setter
    @is_strict_auto
    def dgc_configs(self, configs):
        # Reject unknown keys first, then copy the values into the message.
        check_configs_key(self.strategy.dgc_configs, configs, "dgc_configs")
        assign_configs_value(self.strategy.dgc_configs, configs)

    @property
    def gradient_merge(self):
        """
        Gradient Merge, also called Gradient Accumulation,
        is a strategy for large batch training. With this strategy,
        model parameters will not be updated until user-defined steps.
        For each step, the forward network and the backward network
        will run to calculate the gradient of model parameters.
        For every k steps, the optimization network will run,
        applying a specific optimization method (such as SGD, Adam)
        to model parameters.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.gradient_merge = True
            strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
        """
        return self.strategy.gradient_merge

    @gradient_merge.setter
    @is_strict_auto
    def gradient_merge(self, flag):
        # Non-bool values are ignored with a printed warning.
        if isinstance(flag, bool):
            self.strategy.gradient_merge = flag
        else:
            print("WARNING: gradient_merge should have value of bool type")

    @property
    def gradient_merge_configs(self):
        """
        The key-value configs of the gradient merge strategy.

        Returns a dict mapping every field of the underlying
        gradient_merge_configs message to its current value.

        **Note**:
            k_steps(int): the update period of the parameters.

            avg(bool): whether to average the gradients of each mini-batch, the default value is `True`

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.gradient_merge = True
            strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
        """
        return get_msg_dict(self.strategy.gradient_merge_configs)

    @gradient_merge_configs.setter
    @is_strict_auto
    def gradient_merge_configs(self, configs):
        # The label handed to check_configs_key names the config being
        # checked. It must be the real property name, matching the other
        # *_configs setters (it was the non-existent "gradient_configs").
        check_configs_key(self.strategy.gradient_merge_configs, configs,
                          "gradient_merge_configs")
        assign_configs_value(self.strategy.gradient_merge_configs, configs)

    @property
    def lars(self):
        """
        Enable lars. lars addresses the convergence problems that appear
        when the global batch size is larger than 8k. For details, see
        [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888).

        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lars = True # by default this is false
        """
        return self.strategy.lars

    @lars.setter
    @is_strict_auto
    def lars(self, flag):
        # Only bool values are accepted; anything else triggers a warning.
        if not isinstance(flag, bool):
            print("WARNING: lars should have value of bool type")
            return
        self.strategy.lars = flag

    @property
    def lars_configs(self):
        """
        Set lars training configurations.

        **Notes**:
        **lars_coeff (float)**: trust ratio in the lars formula.
        **lars_weight_decay (float)**: weight decay coefficient in the lars formula.
        **epsilon (float)**: used to avoid a potential division-by-zero
        when computing the local lr;
        **exclude_from_weight_decay ([string])**: a list of name strings of layers
        that will be excluded from weight decay in the lars formula.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lars = True
            strategy.lars_configs = {
                        "lars_coeff": 0.01,
                        "lars_weight_decay": 0.0005,
                        "epsilon": 0,
                        "exclude_from_weight_decay": ['batch_norm', '.b_0']
                    }
        """
        return get_msg_dict(self.strategy.lars_configs)

    @lars_configs.setter
    @is_strict_auto
    def lars_configs(self, configs):
        # Keys are checked against the message's declared field names, then
        # copied in. Repeated fields are extended by assign_configs_value,
        # not overwritten.
        lars_msg = self.strategy.lars_configs
        check_configs_key(lars_msg, configs, "lars_configs")
        assign_configs_value(lars_msg, configs)

    @property
    def lamb(self):
        """
        Enable lamb. lamb addresses the convergence problems of large batch
        size training, especially for attention-related models such as BERT.
        For details, see
        [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962).

        Default Value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lamb = True # by default this is false
        """
        return self.strategy.lamb

    @lamb.setter
    @is_strict_auto
    def lamb(self, flag):
        # Only bool values are accepted; anything else triggers a warning.
        if not isinstance(flag, bool):
            print("WARNING: lamb should have value of bool type")
            return
        self.strategy.lamb = flag

    @property
    def lamb_configs(self):
        """
        Set Lamb training configurations.

        **Notes**:
        **lamb_weight_decay (float)**: weight decay coefficient in the lamb formula.
        **exclude_from_weight_decay ([string])**: a list of name strings of layers
        that will be excluded from weight decay in the lamb formula.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.lamb = True
            strategy.lamb_configs = {
                    'lamb_weight_decay': 0.01,
                    'exclude_from_weight_decay': [],
                }
        """
        return get_msg_dict(self.strategy.lamb_configs)

    @lamb_configs.setter
    @is_strict_auto
    def lamb_configs(self, configs):
        # Keys are checked against the message's declared field names, then
        # copied in. Repeated fields are extended by assign_configs_value,
        # not overwritten.
        check_configs_key(self.strategy.lamb_configs, configs, "lamb_configs")
        assign_configs_value(self.strategy.lamb_configs, configs)

    @property
    def elastic(self):
        """
        Whether the current distributed training should run on clusters
        with elastic resources.
        Currently, this configuration is not valid.
        """
        return self.strategy.elastic

    @elastic.setter
    @is_strict_auto
    def elastic(self, flag):
        # Only bool values are accepted; anything else triggers a warning.
        if not isinstance(flag, bool):
            print("WARNING: elastic should have value of bool type")
            return
        self.strategy.elastic = flag

    @property
    def auto(self):
        """
        Whether to use the auto-parallel configuration.
        This is currently an experimental feature: auto-parallelism can only
        be used when the user sets no strategy config other than ``auto``.
        See the code example below.

        Default Value: False

        Examples:
          .. code-block:: python

            import paddle
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.auto = True

            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
            optimizer = fleet.distributed_optimizer(optimizer, strategy)
        """
        return self.strategy.auto

    @auto.setter
    def auto(self, flag):
        # Unlike the other setters, this one is not wrapped with
        # @is_strict_auto, so setting ``auto`` itself does not clear the
        # module-level non_auto_func_called flag.
        if not isinstance(flag, bool):
            print("WARNING: auto should have value of bool type")
            return
        self.strategy.auto = flag

    @property
    def cudnn_exhaustive_search(self):
        """
        Whether to use the exhaustive search method to choose convolution algorithms.
        Exhaustive search tries every cuDNN algorithm and picks the fastest one.
        This is time-consuming. The chosen algorithm is cached for the given
        layer specifications, and the search runs again once those
        specifications (such as batch size or feature map size) change.

        Default Value: True

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.cudnn_exhaustive_search = False

            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
            optimizer = fleet.distributed_optimizer(optimizer, strategy)
        """
        return self.strategy.cudnn_exhaustive_search

    @cudnn_exhaustive_search.setter
    @is_strict_auto
    def cudnn_exhaustive_search(self, flag):
        # Only bool values are accepted; anything else triggers a warning.
        if not isinstance(flag, bool):
            print(
                "WARNING: cudnn_exhaustive_search should have value of bool type"
            )
            return
        self.strategy.cudnn_exhaustive_search = flag

    @property
    def conv_workspace_size_limit(self):
        """
        The workspace size limit, in MB, used when choosing cuDNN convolution algorithms.
        cuDNN's internal function picks the fastest suitable algorithm that fits
        within this memory limit. A larger workspace usually allows faster
        algorithms but increases memory use significantly, so users need to
        trade off memory against speed.

        Default Value: 4000

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.conv_workspace_size_limit = 1024

            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
            optimizer = fleet.distributed_optimizer(optimizer, strategy)

        """
        return self.strategy.conv_workspace_size_limit

    @conv_workspace_size_limit.setter
    @is_strict_auto
    def conv_workspace_size_limit(self, value):
        # Only int values are accepted; anything else triggers a warning.
        if not isinstance(value, int):
            print(
                "WARNING: conv_workspace_size_limit should have value of int type"
            )
            return
        self.strategy.conv_workspace_size_limit = value

    @property
    def cudnn_batchnorm_spatial_persistent(self):
        """
        Whether batchnorm uses the CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
        This only takes effect with cudnn.

        Default Value: True

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.cudnn_batchnorm_spatial_persistent = True

            optimizer = paddle.optimizer.SGD(learning_rate=0.01)
            optimizer = fleet.distributed_optimizer(optimizer, strategy)

        """
        return self.strategy.cudnn_batchnorm_spatial_persistent

    @cudnn_batchnorm_spatial_persistent.setter
    @is_strict_auto
    def cudnn_batchnorm_spatial_persistent(self, flag):
        # Only bool values are accepted; anything else triggers a warning.
        if not isinstance(flag, bool):
            print(
                "WARNING: cudnn_batchnorm_spatial_persistent should have value of bool type"
            )
            return
        self.strategy.cudnn_batchnorm_spatial_persistent = flag

    def _enable_env(self):
        # Copy selected strategy fields onto framework global flags.
        # All values are converted up front (same order as before). Each flag
        # is then written only if core.globals() reports it as public.
        strategy = self.strategy
        flag_settings = [
            ("FLAGS_cudnn_batchnorm_spatial_persistent",
             bool(strategy.cudnn_batchnorm_spatial_persistent)),
            ("FLAGS_conv_workspace_size_limit",
             int(strategy.conv_workspace_size_limit)),
            ("FLAGS_cudnn_exhaustive_search",
             bool(strategy.cudnn_exhaustive_search)),
            ("FLAGS_sync_nccl_allreduce",
             bool(strategy.sync_nccl_allreduce)),
            # NOTE: both fusion flags are sourced from the fuse_grad_* fields.
            ("FLAGS_fuse_parameter_memory_size",
             int(strategy.fuse_grad_size_in_MB)),
            ("FLAGS_fuse_parameter_groups_size",
             int(strategy.fuse_grad_size_in_TFLOPS)),
        ]
        for flag_name, flag_value in flag_settings:
            if core.globals().is_public(flag_name):
                core.globals()[flag_name] = flag_value

    def _is_strict_auto(self):
        # True only when ``auto`` is enabled and no setter wrapped with
        # is_strict_auto has run yet (those wrappers set the module-level
        # non_auto_func_called to False). Reading the global needs no
        # ``global`` statement.
        return bool(self.strategy.auto and non_auto_func_called)

    def __repr__(self):
        """
        Return the protobuf text representation of the underlying strategy
        message, i.e. ``str(self.strategy)``.
        """
        # Do not print "<field>: <default_value>" for every field here, as an
        # earlier version did. repr() must not write to stdout, and default
        # values say nothing about the configured state; the returned text
        # already reflects the configured message.
        return str(self.strategy)