未验证 提交 08d736ad 编写于 作者: D Dong Daxiang 提交者: GitHub

【paddle.fleet】add cudnn related strategies to DistributedStrategy (#26598)

* add cudnn related strategies to DistributedStrategy
上级 0a895bc0
...@@ -113,7 +113,9 @@ message DistributedStrategy { ...@@ -113,7 +113,9 @@ message DistributedStrategy {
optional bool fuse_all_reduce_ops = 18 [ default = true ]; optional bool fuse_all_reduce_ops = 18 [ default = true ];
optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ]; optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ]; optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
// optional bool enable_backward_optimizer_op_deps = 19 [ default = true ]; optional bool cudnn_exhaustive_search = 21 [ default = true ];
optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional RecomputeConfig recompute_configs = 101; optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102; optional AMPConfig amp_configs = 102;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import paddle import paddle
from paddle.distributed.fleet.proto import distributed_strategy_pb2 from paddle.distributed.fleet.proto import distributed_strategy_pb2
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable, set_flags, core
import google.protobuf.text_format import google.protobuf.text_format
...@@ -810,6 +810,68 @@ class DistributedStrategy(object): ...@@ -810,6 +810,68 @@ class DistributedStrategy(object):
else: else:
print("WARNING: auto should have value of bool type") print("WARNING: auto should have value of bool type")
@property
def cudnn_exhaustive_search(self):
return self.strategy.cudnn_exhaustive_search
@cudnn_exhaustive_search.setter
def cudnn_exhaustive_search(self, flag):
if isinstance(flag, bool):
self.strategy.cudnn_exhaustive_search = flag
else:
print(
"WARNING: cudnn_exhaustive_search should have value of bool type"
)
@property
def conv_workspace_size_limit(self):
return self.strategy.conv_workspace_size_limit
@conv_workspace_size_limit.setter
def conv_workspace_size_limit(self, value):
if isinstance(value, int):
self.strategy.conv_workspace_size_limit = value
else:
print(
"WARNING: conv_workspace_size_limit should have value of int type"
)
@property
def cudnn_batchnorm_spatial_persistent(self):
return self.strategy.cudnn_batchnorm_spatial_persistent
@cudnn_batchnorm_spatial_persistent.setter
def cudnn_batchnorm_spatial_persistent(self, flag):
if isinstance(flag, bool):
self.strategy.cudnn_batchnorm_spatial_persistent = flag
else:
print(
"WARNING: cudnn_batchnorm_spatial_persistent should have value of bool type"
)
def _enable_env(self):
strategy = self.strategy
keys = [
"FLAGS_cudnn_batchnorm_spatial_persistent",
"FLAGS_conv_workspace_size_limit",
"FLAGS_cudnn_exhaustive_search",
"FLAGS_sync_nccl_allreduce",
"FLAGS_fuse_parameter_memory_size",
"FLAGS_fuse_parameter_groups_size",
]
values = [
bool(strategy.cudnn_batchnorm_spatial_persistent),
int(strategy.conv_workspace_size_limit),
bool(strategy.cudnn_exhaustive_search),
bool(strategy.sync_nccl_allreduce),
int(strategy.fuse_grad_size_in_MB),
int(strategy.fuse_grad_size_in_TFLOPS),
]
for i, key in enumerate(keys):
if core.globals().is_public(key):
core.globals()[key] = values[i]
def __repr__(self): def __repr__(self):
fields = self.strategy.DESCRIPTOR.fields fields = self.strategy.DESCRIPTOR.fields
for f in fields: for f in fields:
......
...@@ -383,6 +383,7 @@ class Fleet(object): ...@@ -383,6 +383,7 @@ class Fleet(object):
context["valid_strategy"] = valid_strategy context["valid_strategy"] = valid_strategy
self.valid_strategy = valid_strategy self.valid_strategy = valid_strategy
self.valid_strategy._enable_env()
optimize_ops = [] optimize_ops = []
params_grads = [] params_grads = []
......
...@@ -294,6 +294,28 @@ class TestStrategyConfig(unittest.TestCase): ...@@ -294,6 +294,28 @@ class TestStrategyConfig(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
strategy.unknown_key = 'UNK' strategy.unknown_key = 'UNK'
def test_cudnn_exhaustive_search(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.cudnn_exhaustive_search = False
self.assertEqual(strategy.cudnn_exhaustive_search, False)
strategy.cudnn_exhaustive_search = "True"
self.assertEqual(strategy.cudnn_exhaustive_search, False)
def test_cudnn_batchnorm_spatial_persistent(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.cudnn_batchnorm_spatial_persistent = False
self.assertEqual(strategy.cudnn_batchnorm_spatial_persistent, False)
strategy.cudnn_batchnorm_spatial_persistent = "True"
self.assertEqual(strategy.cudnn_batchnorm_spatial_persistent, False)
def test_conv_workspace_size_limit(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.conv_workspace_size_limit = 1000
self.assertEqual(strategy.conv_workspace_size_limit, 1000)
strategy.conv_workspace_size_limit = "400"
self.assertEqual(strategy.conv_workspace_size_limit, 1000)
strategy._enable_env()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册