提交 d50addb8 编写于 作者: A andyjpaddle

fix rare trt shape error

上级 a8318353
...@@ -37,7 +37,7 @@ export2:null ...@@ -37,7 +37,7 @@ export2:null
train_model:./inference/rec_r34_vd_tps_bilstm_att_v2.0_train/best_accuracy train_model:./inference/rec_r34_vd_tps_bilstm_att_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
infer_quant:False infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE" inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE" --min_subgraph_size=5
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:True|False --enable_mkldnn:True|False
--cpu_threads:1|6 --cpu_threads:1|6
......
...@@ -119,6 +119,10 @@ class TextRecognizer(object): ...@@ -119,6 +119,10 @@ class TextRecognizer(object):
resized_w = imgW resized_w = imgW
else: else:
resized_w = int(math.ceil(imgH * ratio)) resized_w = int(math.ceil(imgH * ratio))
if self.rec_algorithm == 'RARE':
if resized_w > self.rec_image_shape[2]:
resized_w = self.rec_image_shape[2]
imgW = self.rec_image_shape[2]
resized_image = cv2.resize(img, (resized_w, imgH)) resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32') resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255 resized_image = resized_image.transpose((2, 0, 1)) / 255
......
...@@ -323,6 +323,10 @@ def get_output_tensors(args, mode, predictor): ...@@ -323,6 +323,10 @@ def get_output_tensors(args, mode, predictor):
output_name = 'softmax_0.tmp_0' output_name = 'softmax_0.tmp_0'
if output_name in output_names: if output_name in output_names:
return [predictor.get_output_handle(output_name)] return [predictor.get_output_handle(output_name)]
else:
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
else: else:
for output_name in output_names: for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name) output_tensor = predictor.get_output_handle(output_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册