diff --git a/tools/static/train.py b/tools/static/train.py index cd6aaefaa0eb11f0983057831e54f9eb95b02fa5..12894b67bce3c0dabf5802a5b5ceba9bad063b23 100644 --- a/tools/static/train.py +++ b/tools/static/train.py @@ -72,7 +72,7 @@ def main(args): if use_amp or use_pure_fp16: AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_exhaustive_search': 1, - 'FLAGS_conv_workspace_size_limit': 4000, + 'FLAGS_conv_workspace_size_limit': 1500, 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, 'FLAGS_max_inplace_grad_add': 8, }