diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 4b984210ed18d9f51d9485616d1c28871d936237..551d1342edeb335d1cad4782f85ae9f94f8739bd 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -113,7 +113,9 @@ message DistributedStrategy {
   optional bool fuse_all_reduce_ops = 18 [ default = true ];
   optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ];
   optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ];
-  // optional bool enable_backward_optimizer_op_deps = 19 [ default = true ];
+  optional bool cudnn_exhaustive_search = 21 [ default = true ];
+  optional int32 conv_workspace_size_limit = 22 [ default = 4000 ];
+  optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
 
   optional RecomputeConfig recompute_configs = 101;
   optional AMPConfig amp_configs = 102;
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index 2971617aa705f55f193e512bf7ef75b609588c02..a337fc41f292521c5e90daffd71b5bd4ff4e0553 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -14,7 +14,7 @@
 
 import paddle
 from paddle.distributed.fleet.proto import distributed_strategy_pb2
-from paddle.fluid.framework import Variable
+from paddle.fluid.framework import Variable, set_flags, core
 import google.protobuf.text_format
 
@@ -810,6 +810,68 @@ class DistributedStrategy(object):
         else:
             print("WARNING: auto should have value of bool type")
 
+    @property
+    def cudnn_exhaustive_search(self):
+        return self.strategy.cudnn_exhaustive_search
+
+    @cudnn_exhaustive_search.setter
+    def cudnn_exhaustive_search(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.cudnn_exhaustive_search = flag
+        else:
+            print(
+                "WARNING: cudnn_exhaustive_search should have value of bool type"
+            )
+
+    @property
+    def conv_workspace_size_limit(self):
+        return self.strategy.conv_workspace_size_limit
+
+    @conv_workspace_size_limit.setter
+    def conv_workspace_size_limit(self, value):
+        if isinstance(value, int):
+            self.strategy.conv_workspace_size_limit = value
+        else:
+            print(
+                "WARNING: conv_workspace_size_limit should have value of int type"
+            )
+
+    @property
+    def cudnn_batchnorm_spatial_persistent(self):
+        return self.strategy.cudnn_batchnorm_spatial_persistent
+
+    @cudnn_batchnorm_spatial_persistent.setter
+    def cudnn_batchnorm_spatial_persistent(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.cudnn_batchnorm_spatial_persistent = flag
+        else:
+            print(
+                "WARNING: cudnn_batchnorm_spatial_persistent should have value of bool type"
+            )
+
+    def _enable_env(self):
+        strategy = self.strategy
+        keys = [
+            "FLAGS_cudnn_batchnorm_spatial_persistent",
+            "FLAGS_conv_workspace_size_limit",
+            "FLAGS_cudnn_exhaustive_search",
+            "FLAGS_sync_nccl_allreduce",
+            "FLAGS_fuse_parameter_memory_size",
+            "FLAGS_fuse_parameter_groups_size",
+        ]
+        values = [
+            bool(strategy.cudnn_batchnorm_spatial_persistent),
+            int(strategy.conv_workspace_size_limit),
+            bool(strategy.cudnn_exhaustive_search),
+            bool(strategy.sync_nccl_allreduce),
+            int(strategy.fuse_grad_size_in_MB),
+            int(strategy.fuse_grad_size_in_TFLOPS),
+        ]
+
+        for i, key in enumerate(keys):
+            if core.globals().is_public(key):
+                core.globals()[key] = values[i]
+
     def __repr__(self):
         fields = self.strategy.DESCRIPTOR.fields
         for f in fields:
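A quick usage sketch (hypothetical, not part of the patch): it shows how the three new DistributedStrategy options behave, with the asserted defaults taken from the proto fields added above and the flag propagation following _enable_env().

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()

# Defaults come from distributed_strategy.proto:
# exhaustive search on, 4000 MB conv workspace, persistent batch-norm on.
assert strategy.cudnn_exhaustive_search is True
assert strategy.conv_workspace_size_limit == 4000
assert strategy.cudnn_batchnorm_spatial_persistent is True

# Each setter type-checks its input and keeps the previous value
# (printing a WARNING) when the type does not match.
strategy.conv_workspace_size_limit = 1024    # accepted: int
strategy.conv_workspace_size_limit = "2048"  # rejected: str, value stays 1024
assert strategy.conv_workspace_size_limit == 1024

# _enable_env() copies the strategy fields into the matching FLAGS_*
# entries of core.globals(), skipping flags that are not public.
strategy._enable_env()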
diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index c8ae8df52066ef9498996c1094cfa01c3f27b615..a6286bcca87fad1afddbd8af1e56dad05dab2578 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -383,6 +383,7 @@ class Fleet(object):
         context["valid_strategy"] = valid_strategy
 
         self.valid_strategy = valid_strategy
+        self.valid_strategy._enable_env()
 
         optimize_ops = []
         params_grads = []
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
index 40e0168e1ac93dfd93a99c19eced05756a49471f..8d715674cc6c9ba4f8b5c1ff4fe0cbdbe7841643 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
@@ -294,6 +294,28 @@ class TestStrategyConfig(unittest.TestCase):
         with self.assertRaises(TypeError):
             strategy.unknown_key = 'UNK'
 
+    def test_cudnn_exhaustive_search(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.cudnn_exhaustive_search = False
+        self.assertEqual(strategy.cudnn_exhaustive_search, False)
+        strategy.cudnn_exhaustive_search = "True"
+        self.assertEqual(strategy.cudnn_exhaustive_search, False)
+
+    def test_cudnn_batchnorm_spatial_persistent(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.cudnn_batchnorm_spatial_persistent = False
+        self.assertEqual(strategy.cudnn_batchnorm_spatial_persistent, False)
+        strategy.cudnn_batchnorm_spatial_persistent = "True"
+        self.assertEqual(strategy.cudnn_batchnorm_spatial_persistent, False)
+
+    def test_conv_workspace_size_limit(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.conv_workspace_size_limit = 1000
+        self.assertEqual(strategy.conv_workspace_size_limit, 1000)
+        strategy.conv_workspace_size_limit = "400"
+        self.assertEqual(strategy.conv_workspace_size_limit, 1000)
+        strategy._enable_env()
+
 
 if __name__ == '__main__':
     unittest.main()
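End-to-end, the one-line change in fleet_base.py means a user no longer has to export the FLAGS_* environment variables by hand: minimize() validates the strategy and then calls _enable_env() on it. Below is a minimal collective-training sketch (hypothetical, assuming the Paddle 2.x static-graph fleet API; the network is a placeholder, and the script would typically be launched via paddle.distributed.launch).

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.cudnn_exhaustive_search = True
strategy.conv_workspace_size_limit = 1024

# Tiny placeholder network producing a scalar loss.
x = paddle.static.data(name="x", shape=[None, 10], dtype="float32")
y = paddle.static.data(name="y", shape=[None, 1], dtype="float32")
pred = paddle.static.nn.fc(x, size=1)
loss = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))

optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# minimize() applies the meta-optimizers, stores valid_strategy, and (with
# this patch) calls valid_strategy._enable_env() to push the cuDNN-related
# FLAGS into core.globals().
optimizer.minimize(loss)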