提交 4402e629 编写于 作者: W WenmuZhou

修正export_model里的bug,添加predict_det

上级 89e031f0
...@@ -12,6 +12,13 @@ ...@@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
import argparse import argparse
import paddle import paddle
...@@ -20,14 +27,11 @@ from paddle.jit import to_static ...@@ -20,14 +27,11 @@ from paddle.jit import to_static
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model
from ppocr.utils.logging import get_logger
from tools.program import load_config from tools.program import load_config
from tools.program import merge_config
def parse_args(): def parse_args():
def str2bool(v):
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", help="configuration file to use") parser.add_argument("-c", "--config", help="configuration file to use")
parser.add_argument( parser.add_argument(
...@@ -43,7 +47,7 @@ class Model(paddle.nn.Layer): ...@@ -43,7 +47,7 @@ class Model(paddle.nn.Layer):
# Please modify the 'shape' according to actual needs # Please modify the 'shape' according to actual needs
@to_static(input_spec=[ @to_static(input_spec=[
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 3, 32, None], dtype='float32') shape=[None, 3, 640, 640], dtype='float32')
]) ])
def forward(self, inputs): def forward(self, inputs):
x = self.pre_model(inputs) x = self.pre_model(inputs)
...@@ -53,14 +57,13 @@ class Model(paddle.nn.Layer): ...@@ -53,14 +57,13 @@ class Model(paddle.nn.Layer):
def main(): def main():
FLAGS = parse_args() FLAGS = parse_args()
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
merge_config(FLAGS.opt) logger = get_logger()
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(config['PostProcess'],
config['Global']) config['Global'])
# build model # build model
#for rec algorithm # for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
...@@ -69,7 +72,10 @@ def main(): ...@@ -69,7 +72,10 @@ def main():
model.eval() model.eval()
model = Model(model) model = Model(model)
paddle.jit.save(model, FLAGS.output_path) save_path = '{}/{}'.format(FLAGS.output_path,
config['Architecture']['model_type'])
paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -22,7 +22,6 @@ import cv2 ...@@ -22,7 +22,6 @@ import cv2
import numpy as np import numpy as np
import time import time
import sys import sys
import paddle import paddle
import tools.infer.utility as utility import tools.infer.utility as utility
...@@ -39,7 +38,7 @@ class TextDetector(object): ...@@ -39,7 +38,7 @@ class TextDetector(object):
postprocess_params = {} postprocess_params = {}
if self.det_algorithm == "DB": if self.det_algorithm == "DB":
pre_process_list = [{ pre_process_list = [{
'ResizeForTest': { 'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len, 'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type 'limit_type': args.det_limit_type
} }
...@@ -53,7 +52,7 @@ class TextDetector(object): ...@@ -53,7 +52,7 @@ class TextDetector(object):
}, { }, {
'ToCHWImage': None 'ToCHWImage': None
}, { }, {
'keepKeys': { 'KeepKeys': {
'keep_keys': ['image', 'shape'] 'keep_keys': ['image', 'shape']
} }
}] }]
...@@ -68,8 +67,9 @@ class TextDetector(object): ...@@ -68,8 +67,9 @@ class TextDetector(object):
self.preprocess_op = create_operators(pre_process_list) self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor = paddle.jit.load(args.det_model_dir) self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
self.predictor.eval() args, 'det', logger) # paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
def order_points_clockwise(self, pts): def order_points_clockwise(self, pts):
""" """
...@@ -133,11 +133,23 @@ class TextDetector(object): ...@@ -133,11 +133,23 @@ class TextDetector(object):
return None, 0 return None, 0
img = np.expand_dims(img, axis=0) img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0) shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
starttime = time.time() starttime = time.time()
preds = self.predictor(img) if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(img)
self.predictor.zero_copy_run()
else:
im = paddle.fluid.core.PaddleTensor(img)
self.predictor.run([im])
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
# preds = self.predictor(img)
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
elapse = time.time() - starttime elapse = time.time() - starttime
...@@ -146,8 +158,6 @@ class TextDetector(object): ...@@ -146,8 +158,6 @@ class TextDetector(object):
if __name__ == "__main__": if __name__ == "__main__":
args = utility.parse_args() args = utility.parse_args()
place = paddle.CPUPlace()
paddle.disable_static(place)
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
logger = get_logger() logger = get_logger()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册