未验证 提交 dc8dfba3 编写于 作者: L lilong12 提交者: GitHub

align the default value of some configuration for fleet to that of single cards (#30740)

* update, test=develop
上级 a373aa76
...@@ -141,9 +141,9 @@ message DistributedStrategy { ...@@ -141,9 +141,9 @@ message DistributedStrategy {
optional bool fuse_all_reduce_ops = 18 [ default = true ]; optional bool fuse_all_reduce_ops = 18 [ default = true ];
optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ]; optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ]; optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
optional bool cudnn_exhaustive_search = 21 [ default = true ]; optional bool cudnn_exhaustive_search = 21 [ default = false ];
optional int32 conv_workspace_size_limit = 22 [ default = 512 ]; optional int32 conv_workspace_size_limit = 22 [ default = 512 ];
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ]; optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = false ];
optional bool adaptive_localsgd = 24 [ default = false ]; optional bool adaptive_localsgd = 24 [ default = false ];
optional bool fp16_allreduce = 25 [ default = false ]; optional bool fp16_allreduce = 25 [ default = false ];
optional bool sharding = 26 [ default = false ]; optional bool sharding = 26 [ default = false ];
......
...@@ -118,6 +118,22 @@ class DistributedStrategy(object): ...@@ -118,6 +118,22 @@ class DistributedStrategy(object):
""" """
self.strategy = distributed_strategy_pb2.DistributedStrategy() self.strategy = distributed_strategy_pb2.DistributedStrategy()
# Set the default values of the following flags to the ones set by users
key = 'FLAGS_cudnn_batchnorm_spatial_persistent'
if core.globals().is_public(key):
self.strategy.cudnn_batchnorm_spatial_persistent = bool(
core.globals()[key])
key = 'FLAGS_conv_workspace_size_limit'
if core.globals().is_public(key):
self.strategy.conv_workspace_size_limit = int(core.globals()[key])
key = 'FLAGS_cudnn_exhaustive_search'
if core.globals().is_public(key):
self.strategy.cudnn_exhaustive_search = bool(core.globals()[key])
key = 'FLAGS_sync_nccl_allreduce'
if core.globals().is_public(key):
self.strategy.sync_nccl_allreduce = bool(core.globals()[key])
self.__lock_attr = True self.__lock_attr = True
def __setattr__(self, key, value): def __setattr__(self, key, value):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册