diff --git a/tools/program.py b/tools/program.py index 012a2c615629c58e40e54190e8f5a098305224f5..50f8455e3c74d30599d13fe8adfe4eee3129714b 100755 --- a/tools/program.py +++ b/tools/program.py @@ -277,7 +277,8 @@ def train(config, model_average = True # use amp if scaler: - with paddle.amp.auto_cast(level=amp_level): + custom_black_list = config['Global'].get('amp_custom_black_list',[]) + with paddle.amp.auto_cast(level=amp_level, custom_black_list=custom_black_list): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) elif model_type in ["kie", 'vqa']: