diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 8754c3a0c4312398da901d8f8cb39b12380db432..208ab9a93c005b73868de0e1221c2847586057c5 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -141,9 +141,9 @@ message DistributedStrategy {
   optional bool fuse_all_reduce_ops = 18 [ default = true ];
   optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
   optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
-  optional bool cudnn_exhaustive_search = 21 [ default = true ];
+  optional bool cudnn_exhaustive_search = 21 [ default = false ];
   optional int32 conv_workspace_size_limit = 22 [ default = 512 ];
-  optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
+  optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = false ];
   optional bool adaptive_localsgd = 24 [ default = false ];
   optional bool fp16_allreduce = 25 [ default = false ];
   optional bool sharding = 26 [ default = false ];
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index 186d9263dc57df822130e333be082d283e6bb845..f79013d7347c00efc36aef17dba5f6d3a1ae3165 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -118,6 +118,22 @@ class DistributedStrategy(object):

         """
         self.strategy = distributed_strategy_pb2.DistributedStrategy()
+
+        # Set the default values of the following flags to the ones set by users
+        key = 'FLAGS_cudnn_batchnorm_spatial_persistent'
+        if core.globals().is_public(key):
+            self.strategy.cudnn_batchnorm_spatial_persistent = bool(
+                core.globals()[key])
+        key = 'FLAGS_conv_workspace_size_limit'
+        if core.globals().is_public(key):
+            self.strategy.conv_workspace_size_limit = int(core.globals()[key])
+        key = 'FLAGS_cudnn_exhaustive_search'
+        if core.globals().is_public(key):
+            self.strategy.cudnn_exhaustive_search = bool(core.globals()[key])
+        key = 'FLAGS_sync_nccl_allreduce'
+        if core.globals().is_public(key):
+            self.strategy.sync_nccl_allreduce = bool(core.globals()[key])
+
         self.__lock_attr = True

     def __setattr__(self, key, value):
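
The intent of the patch is that a DistributedStrategy no longer silently resets cuDNN-related settings to hard-coded protobuf defaults, but instead inherits whatever the user has already configured through Paddle's global FLAGS. Below is a minimal sketch (not part of the patch) of the expected behavior, assuming the flags are set via paddle.set_flags before the strategy object is constructed; the printed values reflect the intended outcome, not output copied from the PR.

# Sketch: flags set up front should show up as the strategy's defaults
# once this patch is applied. Flag names come from the diff above.
import paddle
import paddle.distributed.fleet as fleet

paddle.set_flags({
    'FLAGS_cudnn_exhaustive_search': True,
    'FLAGS_conv_workspace_size_limit': 1024,
    'FLAGS_cudnn_batchnorm_spatial_persistent': True,
})

strategy = fleet.DistributedStrategy()

# With the patched __init__, these should echo the user-set flag values
# instead of the protobuf defaults (false / 512 / false).
print(strategy.cudnn_exhaustive_search)             # expected: True
print(strategy.conv_workspace_size_limit)           # expected: 1024
print(strategy.cudnn_batchnorm_spatial_persistent)  # expected: True

This also explains the proto change: with the strategy now reading user-set flags at construction time, the protobuf defaults for cudnn_exhaustive_search and cudnn_batchnorm_spatial_persistent are flipped to false so that they match the framework's own flag defaults rather than overriding them.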