未验证 提交 1742d1c5 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #5902 from LDOUBLEV/dygraph

[tipc] fix rec pact hang and east hang
...@@ -127,6 +127,7 @@ def main(): ...@@ -127,6 +127,7 @@ def main():
arch_config = config["Architecture"] arch_config = config["Architecture"]
if arch_config["algorithm"] in ["Distillation", ]: # distillation model if arch_config["algorithm"] in ["Distillation", ]: # distillation model
for idx, name in enumerate(model.model_name_list): for idx, name in enumerate(model.model_name_list):
model.model_list[idx].eval()
sub_model_save_path = os.path.join(save_path, name, "inference") sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(quanter, model.model_list[idx], infer_shape, export_single_model(quanter, model.model_list[idx], infer_shape,
sub_model_save_path, logger) sub_model_save_path, logger)
......
...@@ -80,7 +80,6 @@ class CTCHead(nn.Layer): ...@@ -80,7 +80,6 @@ class CTCHead(nn.Layer):
result = (x, predicts) result = (x, predicts)
else: else:
result = predicts result = predicts
if not self.training: if not self.training:
predicts = F.softmax(predicts, axis=2) predicts = F.softmax(predicts, axis=2)
result = predicts result = predicts
......
...@@ -89,7 +89,7 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -89,7 +89,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
use_space_char) use_space_char)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple): if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1] preds = preds[-1]
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
......
===========================train_params=========================== ===========================train_params===========================
model_name:PPOCRv2_ocr_rec_pact model_name:ch_PPOCRv2_rec_PACT
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
Global.auto_cast:fp32 Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300 Global.epoch_num:lite_train_lite_infer=6|whole_train_whole_infer=300
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128 Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:null Global.pretrained_model:pretrain_models/ch_PP-OCRv2_rec_train/best_accuracy
train_model_name:latest train_model_name:latest
train_infer_img_dir:./inference/rec_inference train_infer_img_dir:./inference/rec_inference
null:null null:null
......
===========================train_params=========================== ===========================train_params===========================
model_name:det_mv3_east_v2.0 model_name:det_mv3_east_v2.0
python:python3.7 python:python3.7
gpu_list:0 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
Global.auto_cast:fp32 Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500 Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4 Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null Global.pretrained_model:./pretrain_models/det_mv3_east_v2.0_train/best_accuracy
train_model_name:latest train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null null:null
......
...@@ -64,6 +64,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then ...@@ -64,6 +64,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_ppocr_server_v2.0_det_train.tar && cd ../ cd ./pretrain_models/ && tar xf ch_ppocr_server_v2.0_det_train.tar && cd ../
fi fi
if [ ${model_name} == "ch_PPOCRv2_rec" ] || [ ${model_name} == "ch_PPOCRv2_rec_PACT" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_PP-OCRv2_rec_train.tar && cd ../
fi
if [ ${model_name} == "det_r18_db_v2_0" ]; then if [ ${model_name} == "det_r18_db_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate
fi fi
...@@ -91,6 +95,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then ...@@ -91,6 +95,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../ cd ./pretrain_models/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../
fi fi
if [ ${model_name} == "det_mv3_east_v2.0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_mv3_east_v2.0_train.tar && cd ../
fi
elif [ ${MODE} = "whole_train_whole_infer" ];then elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
......
...@@ -312,12 +312,22 @@ def create_predictor(args, mode, logger): ...@@ -312,12 +312,22 @@ def create_predictor(args, mode, logger):
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
for name in input_names: for name in input_names:
input_tensor = predictor.get_input_handle(name) input_tensor = predictor.get_input_handle(name)
output_names = predictor.get_output_names() output_tensors = get_output_tensors(args, mode, predictor)
output_tensors = [] return predictor, input_tensor, output_tensors, config
def get_output_tensors(args, mode, predictor):
output_names = predictor.get_output_names()
output_tensors = []
if mode == "rec" and args.rec_algorithm == "CRNN":
output_name = 'softmax_0.tmp_0'
if output_name in output_names:
return [predictor.get_output_handle(output_name)]
else:
for output_name in output_names: for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name) output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor) output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors, config return output_tensors
def get_infer_gpuid(): def get_infer_gpuid():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册