diff --git a/tools/static/train.py b/tools/static/train.py index f63f6a9a29f1110316dc27192b49f669828a5542..973b29d26ad32065f3b44bba0afdf254d669e7c0 100644 --- a/tools/static/train.py +++ b/tools/static/train.py @@ -68,7 +68,7 @@ def main(args): if 'AMP' in config: AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_exhaustive_search': 1, - 'FLAGS_conv_workspace_size_limit': 4000, + 'FLAGS_conv_workspace_size_limit': 1500, 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, 'FLAGS_max_inplace_grad_add': 8, }