diff --git a/PaddleCV/gan/cycle_gan/train.py b/PaddleCV/gan/cycle_gan/train.py index 3f03c747e7e63d8442c4db7697d46c2a05ebd387..cd0fade81472d5b320cb3ebd59f6e4b85dda26fd 100644 --- a/PaddleCV/gan/cycle_gan/train.py +++ b/PaddleCV/gan/cycle_gan/train.py @@ -9,12 +9,18 @@ def set_paddle_flags(flags): if os.environ.get(key, None) is None: os.environ[key] = str(value) +use_cudnn_deterministic = os.environ.get('FLAGS_cudnn_deterministic', None) + +if use_cudnn_deterministic: + use_cudnn_exhaustive_search = 0 +else: + use_cudnn_exhaustive_search = 1 # NOTE(paddle-dev): All of these flags should be # set before `import paddle`. Otherwise, it would # not take any effect. set_paddle_flags({ - 'FLAGS_cudnn_exhaustive_search': 1, + 'FLAGS_cudnn_exhaustive_search': use_cudnn_exhaustive_search, 'FLAGS_conv_workspace_size_limit': 256, 'FLAGS_eager_delete_tensor_gb': 0, # enable gc # You can omit the following settings, because the default