未验证 提交 94710ae3 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #7266 from WenmuZhou/tipc

add PP-OCRv2 det amp custom_black_list
...@@ -14,6 +14,9 @@ Global: ...@@ -14,6 +14,9 @@ Global:
use_visualdl: False use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt save_res_path: ./output/det_db/predicts_db.txt
use_amp: False
amp_level: O2
amp_custom_black_list: ['exp']
Architecture: Architecture:
name: DistillationModel name: DistillationModel
......
...@@ -278,7 +278,8 @@ def train(config, ...@@ -278,7 +278,8 @@ def train(config,
model_average = True model_average = True
# use amp # use amp
if scaler: if scaler:
with paddle.amp.auto_cast(level=amp_level): custom_black_list = config['Global'].get('amp_custom_black_list',[])
with paddle.amp.auto_cast(level=amp_level, custom_black_list=custom_black_list):
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']: elif model_type in ["kie", 'vqa']:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册