提交 a9f60a3f 编写于 作者: J Jiawei Wang 提交者: GitHub

Merge pull request #624 from MRXLT/ocr-demo

add OCR rec demo
......@@ -2,28 +2,27 @@
([简体中文](./README_CN.md)|English)
### Get model files and sample data
### Get Model
```
sh get_data.sh
python -m paddle_serving_app.package --get_model lac
tar -xzvf lac.tar.gz
```
the package downloaded contains lac model config along with lac dictionary.
#### Start RPC inference service
```
python -m paddle_serving_server.serve --model jieba_server_model/ --port 9292
python -m paddle_serving_server.serve --model lac_model/ --port 9292
```
### RPC Infer
```
echo "我爱北京天安门" | python lac_client.py jieba_client_conf/serving_client_conf.prototxt lac_dict/
echo "我爱北京天安门" | python lac_client.py lac_client/serving_client_conf.prototxt
```
it will get the segmentation result
It will get the segmentation result.
### Start HTTP inference service
```
python lac_web_service.py jieba_server_model/ lac_workdir 9292
python lac_web_service.py lac_model/ lac_workdir 9292
```
### HTTP Infer
......
......@@ -2,28 +2,27 @@
(简体中文|[English](./README.md))
### 获取模型和字典文件
### 获取模型
```
sh get_data.sh
python -m paddle_serving_app.package --get_model lac
tar -xzvf lac.tar.gz
```
下载包里包含了lac模型和lac模型预测需要的字典文件
#### 开启RPC预测服务
```
python -m paddle_serving_server.serve --model jieba_server_model/ --port 9292
python -m paddle_serving_server.serve --model lac_model/ --port 9292
```
### 执行RPC预测
```
echo "我爱北京天安门" | python lac_client.py jieba_client_conf/serving_client_conf.prototxt lac_dict/
echo "我爱北京天安门" | python lac_client.py lac_client/serving_client_conf.prototxt
```
我们就能得到分词结果
### 开启HTTP预测服务
```
python lac_web_service.py jieba_server_model/ lac_workdir 9292
python lac_web_service.py lac_model/ lac_workdir 9292
```
### 执行HTTP预测
......
......@@ -16,7 +16,7 @@
import sys
import time
import requests
from lac_reader import LACReader
from paddle_serving_app.reader import LACReader
from paddle_serving_client import Client
from paddle_serving_client.utils import MultiThreadRunner
from paddle_serving_client.utils import benchmark_args
......@@ -25,7 +25,7 @@ args = benchmark_args()
def single_func(idx, resource):
reader = LACReader("lac_dict")
reader = LACReader()
start = time.time()
if args.request == "rpc":
client = Client()
......
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/lac/lac_model_jieba_web.tar.gz
tar -zxvf lac_model_jieba_web.tar.gz
......@@ -15,7 +15,7 @@
# pylint: disable=doc-string-missing
from paddle_serving_client import Client
from lac_reader import LACReader
from paddle_serving_app.reader import LACReader
import sys
import os
import io
......@@ -24,7 +24,7 @@ client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9292"])
reader = LACReader(sys.argv[2])
reader = LACReader()
for line in sys.stdin:
if len(line) <= 0:
continue
......@@ -32,4 +32,8 @@ for line in sys.stdin:
if len(feed_data) <= 0:
continue
fetch_map = client.predict(feed={"words": feed_data}, fetch=["crf_decode"])
print(fetch_map)
begin = fetch_map['crf_decode.lod'][0]
end = fetch_map['crf_decode.lod'][1]
segs = reader.parse_result(line, fetch_map["crf_decode"][begin:end])
print({"word_seg": "|".join(segs)})
......@@ -14,7 +14,7 @@
from paddle_serving_server.web_service import WebService
import sys
from lac_reader import LACReader
from paddle_serving_app.reader import LACReader
class LACService(WebService):
......
# OCR
## Get Model
```
python -m paddle_serving_app.package --get_model ocr_rec
tar -xzvf ocr_rec.tar.gz
```
## RPC Service
### Start Service
```
python -m paddle_serving_server.serve --model ocr_rec_model --port 9292
```
### Client Prediction
```
python test_ocr_rec_client.py
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving_client import Client
from paddle_serving_app.reader import OCRReader
import cv2
client = Client()
client.load_client_config("ocr_rec_client/serving_client_conf.prototxt")
client.connect(["127.0.0.1:9292"])
image_file_list = ["./test_rec.jpg"]
img = cv2.imread(image_file_list[0])
ocr_reader = OCRReader()
feed = {"image": ocr_reader.preprocess([img])}
fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
fetch_map = client.predict(feed=feed, fetch=fetch)
rec_res = ocr_reader.postprocess(fetch_map)
print(image_file_list[0])
print(rec_res[0][0])
......@@ -31,10 +31,12 @@ class ServingModels(object):
self.model_dict["ImageClassification"] = [
"resnet_v2_50_imagenet", "mobilenet_v2_imagenet"
]
self.model_dict["OCR"] = ["ocr_rec"]
image_class_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageClassification/"
image_seg_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ImageSegmentation/"
object_detection_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/ObjectDetection/"
ocr_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/image/OCR/"
senta_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SentimentAnalysis/"
semantic_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/SemanticRepresentation/"
wordseg_url = "https://paddle-serving.bj.bcebos.com/paddle_hub_models/text/LexicalAnalysis/"
......@@ -52,6 +54,7 @@ class ServingModels(object):
pack_url(self.model_dict, "ObjectDetection", object_detection_url)
pack_url(self.model_dict, "ImageSegmentation", image_seg_url)
pack_url(self.model_dict, "ImageClassification", image_class_url)
pack_url(self.model_dict, "OCR", ocr_url)
def get_model_list(self):
return self.model_dict
......
......@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .chinese_bert_reader import ChineseBertReader
from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize, CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, RCNNPostprocess, SegPostprocess, PadStride
from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize
from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB
from .image_reader import RCNNPostprocess, SegPostprocess, PadStride
from .lac_reader import LACReader
from .senta_reader import SentaReader
from .imdb_reader import IMDBDataset
from .ocr_reader import OCRReader
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import copy
import numpy as np
import math
import re
import sys
import argparse
from paddle_serving_app.reader import Sequential, Resize, Transpose, Div, Normalize
class CharacterOps(object):
""" Convert between text-label and text-index """
def __init__(self, config):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif self.character_type == "ch":
character_dict_path = config['character_dict_path']
self.character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line
dict_character = list(self.character_str)
elif self.character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
else:
self.character_str = None
assert self.character_str is not None, \
"Nonsupport type of the character: {}".format(self.character_str)
self.beg_str = "sos"
self.end_str = "eos"
if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def encode(self, text):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if self.character_type == "en":
text = text.lower()
text_list = []
for char in text:
if char not in self.dict:
continue
text_list.append(self.dict[char])
text = np.array(text_list)
return text
def decode(self, text_index, is_remove_duplicate=False):
""" convert text-index into text-label. """
char_list = []
char_num = self.get_char_num()
if self.loss_type == "attention":
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
ignored_tokens = [beg_idx, end_idx]
else:
ignored_tokens = [char_num]
for idx in range(len(text_index)):
if text_index[idx] in ignored_tokens:
continue
if is_remove_duplicate:
if idx > 0 and text_index[idx - 1] == text_index[idx]:
continue
char_list.append(self.character[text_index[idx]])
text = ''.join(char_list)
return text
def get_char_num(self):
return len(self.character)
def get_beg_end_flag_idx(self, beg_or_end):
if self.loss_type == "attention":
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx"\
% beg_or_end
return idx
else:
err = "error in get_beg_end_flag_idx when using the loss %s"\
% (self.loss_type)
assert False, err
class OCRReader(object):
def __init__(self):
args = self.parse_args()
image_shape = [int(v) for v in args.rec_image_shape.split(",")]
self.rec_image_shape = image_shape
self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
char_ops_params = {}
char_ops_params["character_type"] = args.rec_char_type
char_ops_params["character_dict_path"] = args.rec_char_dict_path
char_ops_params['loss_type'] = 'ctc'
self.char_ops = CharacterOps(char_ops_params)
def parse_args(self):
parser = argparse.ArgumentParser()
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=1)
parser.add_argument(
"--rec_char_dict_path", type=str, default="./ppocr_keys_v1.txt")
return parser.parse_args()
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
if self.character_type == "ch":
imgW = int(32 * max_wh_ratio)
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
seq = Sequential([
Resize(imgH, resized_w), Transpose((2, 0, 1)), Div(255),
Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], True)
])
resized_image = seq(img)
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def preprocess(self, img_list):
img_num = len(img_list)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(img_num):
h, w = img_list[ino].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(img_num):
norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
return norm_img_batch[0]
def postprocess(self, outputs):
rec_res = []
rec_idx_lod = outputs["ctc_greedy_decoder_0.tmp_0.lod"]
predict_lod = outputs["softmax_0.tmp_0.lod"]
rec_idx_batch = outputs["ctc_greedy_decoder_0.tmp_0"]
for rno in range(len(rec_idx_lod) - 1):
beg = rec_idx_lod[rno]
end = rec_idx_lod[rno + 1]
rec_idx_tmp = rec_idx_batch[beg:end, 0]
preds_text = self.char_ops.decode(rec_idx_tmp)
beg = predict_lod[rno]
end = predict_lod[rno + 1]
probs = outputs["softmax_0.tmp_0"][beg:end, :]
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
return rec_res
......@@ -86,7 +86,7 @@ class WebService(object):
for key in fetch_map:
fetch_map[key] = fetch_map[key].tolist()
fetch_map = self.postprocess(
feed=feed, fetch=fetch, fetch_map=fetch_map)
feed=request.json["feed"], fetch=fetch, fetch_map=fetch_map)
result = {"result": fetch_map}
except ValueError:
result = {"result": "Request Value Error"}
......
#!/usr/bin/env bash
set -x
function unsetproxy() {
HTTP_PROXY_TEMP=$http_proxy
HTTPS_PROXY_TEMP=$https_proxy
......@@ -455,15 +455,16 @@ function python_test_lac() {
cd lac # pwd: /Serving/python/examples/lac
case $TYPE in
CPU)
sh get_data.sh
check_cmd "python -m paddle_serving_server.serve --model jieba_server_model/ --port 9292 &"
python -m paddle_serving_app.package --get_model lac
tar -xzvf lac.tar.gz
check_cmd "python -m paddle_serving_server.serve --model lac_model/ --port 9292 &"
sleep 5
check_cmd "echo \"我爱北京天安门\" | python lac_client.py jieba_client_conf/serving_client_conf.prototxt lac_dict/"
check_cmd "echo \"我爱北京天安门\" | python lac_client.py lac_client/serving_client_conf.prototxt "
echo "lac CPU RPC inference pass"
kill_server_process
unsetproxy # maybe the proxy is used on iPipe, which makes web-test failed.
check_cmd "python lac_web_service.py jieba_server_model/ lac_workdir 9292 &"
check_cmd "python lac_web_service.py lac_model/ lac_workdir 9292 &"
sleep 5
check_cmd "curl -H \"Content-Type:application/json\" -X POST -d '{\"feed\":[{\"words\": \"我爱北京天安门\"}], \"fetch\":[\"word_seg\"]}' http://127.0.0.1:9292/lac/prediction"
# check http code
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册