Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
c4720557
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c4720557
编写于
9月 22, 2020
作者:
W
wangjiawei04
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix all minor bugs
上级
7cacfc97
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
22 addition
and
33 deletion
+22
-33
deploy/pdserving/clas_local_server.py
deploy/pdserving/clas_local_server.py
+1
-1
deploy/pdserving/clas_web_client.py
deploy/pdserving/clas_web_client.py
+0
-1
deploy/pdserving/det_local_server.py
deploy/pdserving/det_local_server.py
+1
-1
deploy/pdserving/det_rpc_server.py
deploy/pdserving/det_rpc_server.py
+1
-2
deploy/pdserving/ocr_local_server.py
deploy/pdserving/ocr_local_server.py
+2
-5
deploy/pdserving/ocr_rpc_server.py
deploy/pdserving/ocr_rpc_server.py
+4
-3
deploy/pdserving/ocr_web_client.py
deploy/pdserving/ocr_web_client.py
+0
-3
deploy/pdserving/rec_local_server.py
deploy/pdserving/rec_local_server.py
+0
-2
deploy/pdserving/rec_rpc_server.py
deploy/pdserving/rec_rpc_server.py
+0
-2
deploy/pdserving/rec_web_client.py
deploy/pdserving/rec_web_client.py
+0
-1
tools/infer/predict_cls.py
tools/infer/predict_cls.py
+1
-1
tools/infer/predict_det.py
tools/infer/predict_det.py
+1
-1
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+1
-1
tools/infer/predict_system.py
tools/infer/predict_system.py
+6
-1
tools/infer/utility.py
tools/infer/utility.py
+4
-8
未找到文件。
deploy/pdserving/clas_local_server.py
浏览文件 @
c4720557
...
@@ -117,7 +117,7 @@ class OCRService(WebService):
...
@@ -117,7 +117,7 @@ class OCRService(WebService):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
ocr_service
=
OCRService
(
name
=
"ocr"
)
ocr_service
=
OCRService
(
name
=
"ocr"
)
ocr_service
.
load_model_config
(
"cls_server"
)
ocr_service
.
load_model_config
(
global_args
.
cls_model_dir
)
ocr_service
.
init_rec
()
ocr_service
.
init_rec
()
if
global_args
.
use_gpu
:
if
global_args
.
use_gpu
:
ocr_service
.
prepare_server
(
ocr_service
.
prepare_server
(
...
...
deploy/pdserving/clas_web_client.py
浏览文件 @
c4720557
...
@@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
...
@@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
data
=
{
"feed"
:
[{
"image"
:
image
}],
"fetch"
:
[
"res"
]}
data
=
{
"feed"
:
[{
"image"
:
image
}],
"fetch"
:
[
"res"
]}
r
=
requests
.
post
(
url
=
url
,
headers
=
headers
,
data
=
json
.
dumps
(
data
))
r
=
requests
.
post
(
url
=
url
,
headers
=
headers
,
data
=
json
.
dumps
(
data
))
print
(
r
.
json
())
print
(
r
.
json
())
break
deploy/pdserving/det_local_server.py
浏览文件 @
c4720557
...
@@ -96,7 +96,7 @@ class DetService(WebService):
...
@@ -96,7 +96,7 @@ class DetService(WebService):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
ocr_service
=
DetService
(
name
=
"ocr"
)
ocr_service
=
DetService
(
name
=
"ocr"
)
ocr_service
.
load_model_config
(
"serving_server_dir"
)
ocr_service
.
load_model_config
(
global_args
.
det_model_dir
)
ocr_service
.
init_det
()
ocr_service
.
init_det
()
if
global_args
.
use_gpu
:
if
global_args
.
use_gpu
:
ocr_service
.
prepare_server
(
ocr_service
.
prepare_server
(
...
...
deploy/pdserving/det_rpc_server.py
浏览文件 @
c4720557
...
@@ -79,7 +79,6 @@ class TextDetectorHelper(TextDetector):
...
@@ -79,7 +79,6 @@ class TextDetectorHelper(TextDetector):
class
DetService
(
WebService
):
class
DetService
(
WebService
):
def
init_det
(
self
):
def
init_det
(
self
):
self
.
text_detector
=
TextDetectorHelper
(
global_args
)
self
.
text_detector
=
TextDetectorHelper
(
global_args
)
print
(
"init finish"
)
def
preprocess
(
self
,
feed
=
[],
fetch
=
[]):
def
preprocess
(
self
,
feed
=
[],
fetch
=
[]):
data
=
base64
.
b64decode
(
feed
[
0
][
"image"
].
encode
(
'utf8'
))
data
=
base64
.
b64decode
(
feed
[
0
][
"image"
].
encode
(
'utf8'
))
...
@@ -96,7 +95,7 @@ class DetService(WebService):
...
@@ -96,7 +95,7 @@ class DetService(WebService):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
ocr_service
=
DetService
(
name
=
"ocr"
)
ocr_service
=
DetService
(
name
=
"ocr"
)
ocr_service
.
load_model_config
(
"serving_server_dir"
)
ocr_service
.
load_model_config
(
global_args
.
det_model_dir
)
ocr_service
.
init_det
()
ocr_service
.
init_det
()
if
global_args
.
use_gpu
:
if
global_args
.
use_gpu
:
ocr_service
.
prepare_server
(
ocr_service
.
prepare_server
(
...
...
deploy/pdserving/ocr_local_server.py
浏览文件 @
c4720557
...
@@ -44,17 +44,16 @@ class TextSystemHelper(TextSystem):
...
@@ -44,17 +44,16 @@ class TextSystemHelper(TextSystem):
if
self
.
use_angle_cls
:
if
self
.
use_angle_cls
:
self
.
clas_client
=
Debugger
()
self
.
clas_client
=
Debugger
()
self
.
clas_client
.
load_model_config
(
self
.
clas_client
.
load_model_config
(
"ocr_clas_server"
,
gpu
=
True
,
profile
=
False
)
global_args
.
cls_model_dir
,
gpu
=
True
,
profile
=
False
)
self
.
text_classifier
=
TextClassifierHelper
(
args
)
self
.
text_classifier
=
TextClassifierHelper
(
args
)
self
.
det_client
=
Debugger
()
self
.
det_client
=
Debugger
()
self
.
det_client
.
load_model_config
(
self
.
det_client
.
load_model_config
(
"serving_server_dir"
,
gpu
=
True
,
profile
=
False
)
global_args
.
det_model_dir
,
gpu
=
True
,
profile
=
False
)
self
.
fetch
=
[
"ctc_greedy_decoder_0.tmp_0"
,
"softmax_0.tmp_0"
]
self
.
fetch
=
[
"ctc_greedy_decoder_0.tmp_0"
,
"softmax_0.tmp_0"
]
def
preprocess
(
self
,
img
):
def
preprocess
(
self
,
img
):
feed
,
fetch
,
self
.
tmp_args
=
self
.
text_detector
.
preprocess
(
img
)
feed
,
fetch
,
self
.
tmp_args
=
self
.
text_detector
.
preprocess
(
img
)
fetch_map
=
self
.
det_client
.
predict
(
feed
,
fetch
)
fetch_map
=
self
.
det_client
.
predict
(
feed
,
fetch
)
print
(
"det fetch_map"
,
fetch_map
)
outputs
=
[
fetch_map
[
x
]
for
x
in
fetch
]
outputs
=
[
fetch_map
[
x
]
for
x
in
fetch
]
dt_boxes
=
self
.
text_detector
.
postprocess
(
outputs
,
self
.
tmp_args
)
dt_boxes
=
self
.
text_detector
.
postprocess
(
outputs
,
self
.
tmp_args
)
if
dt_boxes
is
None
:
if
dt_boxes
is
None
:
...
@@ -90,12 +89,10 @@ class OCRService(WebService):
...
@@ -90,12 +89,10 @@ class OCRService(WebService):
def
preprocess
(
self
,
feed
=
[],
fetch
=
[]):
def
preprocess
(
self
,
feed
=
[],
fetch
=
[]):
# TODO: to handle batch rec images
# TODO: to handle batch rec images
print
(
"start preprocess"
)
data
=
base64
.
b64decode
(
feed
[
0
][
"image"
].
encode
(
'utf8'
))
data
=
base64
.
b64decode
(
feed
[
0
][
"image"
].
encode
(
'utf8'
))
data
=
np
.
fromstring
(
data
,
np
.
uint8
)
data
=
np
.
fromstring
(
data
,
np
.
uint8
)
im
=
cv2
.
imdecode
(
data
,
cv2
.
IMREAD_COLOR
)
im
=
cv2
.
imdecode
(
data
,
cv2
.
IMREAD_COLOR
)
feed
,
fetch
,
self
.
tmp_args
=
self
.
text_system
.
preprocess
(
im
)
feed
,
fetch
,
self
.
tmp_args
=
self
.
text_system
.
preprocess
(
im
)
print
(
"ocr preprocess done"
)
return
feed
,
fetch
return
feed
,
fetch
def
postprocess
(
self
,
feed
=
{},
fetch
=
[],
fetch_map
=
None
):
def
postprocess
(
self
,
feed
=
{},
fetch
=
[],
fetch_map
=
None
):
...
...
deploy/pdserving/ocr_rpc_server.py
浏览文件 @
c4720557
...
@@ -25,7 +25,7 @@ from clas_rpc_server import TextClassifierHelper
...
@@ -25,7 +25,7 @@ from clas_rpc_server import TextClassifierHelper
from
det_rpc_server
import
TextDetectorHelper
from
det_rpc_server
import
TextDetectorHelper
from
rec_rpc_server
import
TextRecognizerHelper
from
rec_rpc_server
import
TextRecognizerHelper
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
from
tools.infer.predict_system
import
TextSystem
from
tools.infer.predict_system
import
TextSystem
,
sorted_boxes
import
copy
import
copy
global_args
=
utility
.
parse_args
()
global_args
=
utility
.
parse_args
()
...
@@ -48,7 +48,7 @@ class TextSystemHelper(TextSystem):
...
@@ -48,7 +48,7 @@ class TextSystemHelper(TextSystem):
self
.
text_classifier
=
TextClassifierHelper
(
args
)
self
.
text_classifier
=
TextClassifierHelper
(
args
)
self
.
det_client
=
Client
()
self
.
det_client
=
Client
()
self
.
det_client
.
load_client_config
(
self
.
det_client
.
load_client_config
(
"
ocr_det_server
/serving_client_conf.prototxt"
)
"
det_db_client
/serving_client_conf.prototxt"
)
self
.
det_client
.
connect
([
"127.0.0.1:9293"
])
self
.
det_client
.
connect
([
"127.0.0.1:9293"
])
self
.
fetch
=
[
"ctc_greedy_decoder_0.tmp_0"
,
"softmax_0.tmp_0"
]
self
.
fetch
=
[
"ctc_greedy_decoder_0.tmp_0"
,
"softmax_0.tmp_0"
]
...
@@ -57,10 +57,10 @@ class TextSystemHelper(TextSystem):
...
@@ -57,10 +57,10 @@ class TextSystemHelper(TextSystem):
fetch_map
=
self
.
det_client
.
predict
(
feed
,
fetch
)
fetch_map
=
self
.
det_client
.
predict
(
feed
,
fetch
)
outputs
=
[
fetch_map
[
x
]
for
x
in
fetch
]
outputs
=
[
fetch_map
[
x
]
for
x
in
fetch
]
dt_boxes
=
self
.
text_detector
.
postprocess
(
outputs
,
self
.
tmp_args
)
dt_boxes
=
self
.
text_detector
.
postprocess
(
outputs
,
self
.
tmp_args
)
print
(
dt_boxes
)
if
dt_boxes
is
None
:
if
dt_boxes
is
None
:
return
None
,
None
return
None
,
None
img_crop_list
=
[]
img_crop_list
=
[]
sorted_boxes
=
SortedBoxes
()
dt_boxes
=
sorted_boxes
(
dt_boxes
)
dt_boxes
=
sorted_boxes
(
dt_boxes
)
for
bno
in
range
(
len
(
dt_boxes
)):
for
bno
in
range
(
len
(
dt_boxes
)):
tmp_box
=
copy
.
deepcopy
(
dt_boxes
[
bno
])
tmp_box
=
copy
.
deepcopy
(
dt_boxes
[
bno
])
...
@@ -70,6 +70,7 @@ class TextSystemHelper(TextSystem):
...
@@ -70,6 +70,7 @@ class TextSystemHelper(TextSystem):
feed
,
fetch
,
self
.
tmp_args
=
self
.
text_classifier
.
preprocess
(
feed
,
fetch
,
self
.
tmp_args
=
self
.
text_classifier
.
preprocess
(
img_crop_list
)
img_crop_list
)
fetch_map
=
self
.
clas_client
.
predict
(
feed
,
fetch
)
fetch_map
=
self
.
clas_client
.
predict
(
feed
,
fetch
)
print
(
fetch_map
)
outputs
=
[
fetch_map
[
x
]
for
x
in
self
.
text_classifier
.
fetch
]
outputs
=
[
fetch_map
[
x
]
for
x
in
self
.
text_classifier
.
fetch
]
for
x
in
fetch_map
.
keys
():
for
x
in
fetch_map
.
keys
():
if
".lod"
in
x
:
if
".lod"
in
x
:
...
...
deploy/pdserving/ocr_web_client.py
浏览文件 @
c4720557
...
@@ -36,8 +36,5 @@ for img_file in os.listdir(test_img_dir):
...
@@ -36,8 +36,5 @@ for img_file in os.listdir(test_img_dir):
image
=
cv2_to_base64
(
image_data1
)
image
=
cv2_to_base64
(
image_data1
)
data
=
{
"feed"
:
[{
"image"
:
image
}],
"fetch"
:
[
"res"
]}
data
=
{
"feed"
:
[{
"image"
:
image
}],
"fetch"
:
[
"res"
]}
r
=
requests
.
post
(
url
=
url
,
headers
=
headers
,
data
=
json
.
dumps
(
data
))
r
=
requests
.
post
(
url
=
url
,
headers
=
headers
,
data
=
json
.
dumps
(
data
))
print
(
r
)
rjson
=
r
.
json
()
rjson
=
r
.
json
()
print
(
rjson
)
print
(
rjson
)
#for x in rjson["result"]["pred_text"]:
# print(x)
deploy/pdserving/rec_local_server.py
浏览文件 @
c4720557
...
@@ -85,7 +85,6 @@ class TextRecognizerHelper(TextRecognizer):
...
@@ -85,7 +85,6 @@ class TextRecognizerHelper(TextRecognizer):
rec_idx_lod
=
args
[
"ctc_greedy_decoder_0.tmp_0.lod"
]
rec_idx_lod
=
args
[
"ctc_greedy_decoder_0.tmp_0.lod"
]
predict_lod
=
args
[
"softmax_0.tmp_0.lod"
]
predict_lod
=
args
[
"softmax_0.tmp_0.lod"
]
indices
=
args
[
"indices"
]
indices
=
args
[
"indices"
]
print
(
"indices"
,
indices
,
rec_idx_lod
)
rec_res
=
[[
''
,
0.0
]]
*
(
len
(
rec_idx_lod
)
-
1
)
rec_res
=
[[
''
,
0.0
]]
*
(
len
(
rec_idx_lod
)
-
1
)
for
rno
in
range
(
len
(
rec_idx_lod
)
-
1
):
for
rno
in
range
(
len
(
rec_idx_lod
)
-
1
):
beg
=
rec_idx_lod
[
rno
]
beg
=
rec_idx_lod
[
rno
]
...
@@ -155,7 +154,6 @@ class OCRService(WebService):
...
@@ -155,7 +154,6 @@ class OCRService(WebService):
if
".lod"
in
x
:
if
".lod"
in
x
:
self
.
tmp_args
[
x
]
=
fetch_map
[
x
]
self
.
tmp_args
[
x
]
=
fetch_map
[
x
]
rec_res
=
self
.
text_recognizer
.
postprocess
(
outputs
,
self
.
tmp_args
)
rec_res
=
self
.
text_recognizer
.
postprocess
(
outputs
,
self
.
tmp_args
)
print
(
"rec_res"
,
rec_res
)
res
=
{
res
=
{
"pred_text"
:
[
x
[
0
]
for
x
in
rec_res
],
"pred_text"
:
[
x
[
0
]
for
x
in
rec_res
],
"score"
:
[
str
(
x
[
1
])
for
x
in
rec_res
]
"score"
:
[
str
(
x
[
1
])
for
x
in
rec_res
]
...
...
deploy/pdserving/rec_rpc_server.py
浏览文件 @
c4720557
...
@@ -91,7 +91,6 @@ class TextRecognizerHelper(TextRecognizer):
...
@@ -91,7 +91,6 @@ class TextRecognizerHelper(TextRecognizer):
rec_idx_lod
=
args
[
"ctc_greedy_decoder_0.tmp_0.lod"
]
rec_idx_lod
=
args
[
"ctc_greedy_decoder_0.tmp_0.lod"
]
predict_lod
=
args
[
"softmax_0.tmp_0.lod"
]
predict_lod
=
args
[
"softmax_0.tmp_0.lod"
]
indices
=
args
[
"indices"
]
indices
=
args
[
"indices"
]
print
(
"indices"
,
indices
,
rec_idx_lod
)
rec_res
=
[[
''
,
0.0
]]
*
(
len
(
rec_idx_lod
)
-
1
)
rec_res
=
[[
''
,
0.0
]]
*
(
len
(
rec_idx_lod
)
-
1
)
for
rno
in
range
(
len
(
rec_idx_lod
)
-
1
):
for
rno
in
range
(
len
(
rec_idx_lod
)
-
1
):
beg
=
rec_idx_lod
[
rno
]
beg
=
rec_idx_lod
[
rno
]
...
@@ -161,7 +160,6 @@ class OCRService(WebService):
...
@@ -161,7 +160,6 @@ class OCRService(WebService):
if
".lod"
in
x
:
if
".lod"
in
x
:
self
.
tmp_args
[
x
]
=
fetch_map
[
x
]
self
.
tmp_args
[
x
]
=
fetch_map
[
x
]
rec_res
=
self
.
text_recognizer
.
postprocess
(
outputs
,
self
.
tmp_args
)
rec_res
=
self
.
text_recognizer
.
postprocess
(
outputs
,
self
.
tmp_args
)
print
(
"rec_res"
,
rec_res
)
res
=
{
res
=
{
"pred_text"
:
[
x
[
0
]
for
x
in
rec_res
],
"pred_text"
:
[
x
[
0
]
for
x
in
rec_res
],
"score"
:
[
str
(
x
[
1
])
for
x
in
rec_res
]
"score"
:
[
str
(
x
[
1
])
for
x
in
rec_res
]
...
...
deploy/pdserving/rec_web_client.py
浏览文件 @
c4720557
...
@@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
...
@@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
data
=
{
"feed"
:
[{
"image"
:
image
}],
"fetch"
:
[
"res"
]}
data
=
{
"feed"
:
[{
"image"
:
image
}],
"fetch"
:
[
"res"
]}
r
=
requests
.
post
(
url
=
url
,
headers
=
headers
,
data
=
json
.
dumps
(
data
))
r
=
requests
.
post
(
url
=
url
,
headers
=
headers
,
data
=
json
.
dumps
(
data
))
print
(
r
.
json
())
print
(
r
.
json
())
break
tools/infer/predict_cls.py
浏览文件 @
c4720557
...
@@ -33,7 +33,7 @@ from paddle import fluid
...
@@ -33,7 +33,7 @@ from paddle import fluid
class
TextClassifier
(
object
):
class
TextClassifier
(
object
):
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
if
args
.
use_serving
is
False
:
if
args
.
use_
pd
serving
is
False
:
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
mode
=
"cls"
)
utility
.
create_predictor
(
args
,
mode
=
"cls"
)
self
.
cls_image_shape
=
[
int
(
v
)
for
v
in
args
.
cls_image_shape
.
split
(
","
)]
self
.
cls_image_shape
=
[
int
(
v
)
for
v
in
args
.
cls_image_shape
.
split
(
","
)]
...
...
tools/infer/predict_det.py
浏览文件 @
c4720557
...
@@ -75,7 +75,7 @@ class TextDetector(object):
...
@@ -75,7 +75,7 @@ class TextDetector(object):
else
:
else
:
logger
.
info
(
"unknown det_algorithm:{}"
.
format
(
self
.
det_algorithm
))
logger
.
info
(
"unknown det_algorithm:{}"
.
format
(
self
.
det_algorithm
))
sys
.
exit
(
0
)
sys
.
exit
(
0
)
if
args
.
use_
gpu
is
False
:
if
args
.
use_
pdserving
is
False
:
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
mode
=
"det"
)
utility
.
create_predictor
(
args
,
mode
=
"det"
)
...
...
tools/infer/predict_rec.py
浏览文件 @
c4720557
...
@@ -34,7 +34,7 @@ from ppocr.utils.character import CharacterOps
...
@@ -34,7 +34,7 @@ from ppocr.utils.character import CharacterOps
class
TextRecognizer
(
object
):
class
TextRecognizer
(
object
):
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
if
args
.
use_serving
is
False
:
if
args
.
use_
pd
serving
is
False
:
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
mode
=
"rec"
)
utility
.
create_predictor
(
args
,
mode
=
"rec"
)
self
.
rec_image_shape
=
[
int
(
v
)
for
v
in
args
.
rec_image_shape
.
split
(
","
)]
self
.
rec_image_shape
=
[
int
(
v
)
for
v
in
args
.
rec_image_shape
.
split
(
","
)]
...
...
tools/infer/predict_system.py
浏览文件 @
c4720557
...
@@ -161,7 +161,12 @@ def main(args):
...
@@ -161,7 +161,12 @@ def main(args):
scores
=
[
rec_res
[
i
][
1
]
for
i
in
range
(
len
(
rec_res
))]
scores
=
[
rec_res
[
i
][
1
]
for
i
in
range
(
len
(
rec_res
))]
draw_img
=
draw_ocr
(
draw_img
=
draw_ocr
(
image
,
boxes
,
txts
,
scores
,
drop_score
=
drop_score
,
font_path
=
font_path
)
image
,
boxes
,
txts
,
scores
,
drop_score
=
drop_score
,
font_path
=
font_path
)
draw_img_save
=
"./inference_results/"
draw_img_save
=
"./inference_results/"
if
not
os
.
path
.
exists
(
draw_img_save
):
if
not
os
.
path
.
exists
(
draw_img_save
):
os
.
makedirs
(
draw_img_save
)
os
.
makedirs
(
draw_img_save
)
...
...
tools/infer/utility.py
浏览文件 @
c4720557
...
@@ -37,7 +37,7 @@ def parse_args():
...
@@ -37,7 +37,7 @@ def parse_args():
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
8000
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
8000
)
parser
.
add_argument
(
"--use_serving"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_
pd
serving"
,
type
=
str2bool
,
default
=
False
)
# params for text detector
# params for text detector
parser
.
add_argument
(
"--image_dir"
,
type
=
str
)
parser
.
add_argument
(
"--image_dir"
,
type
=
str
)
...
@@ -73,9 +73,7 @@ def parse_args():
...
@@ -73,9 +73,7 @@ def parse_args():
default
=
"./ppocr/utils/ppocr_keys_v1.txt"
)
default
=
"./ppocr/utils/ppocr_keys_v1.txt"
)
parser
.
add_argument
(
"--use_space_char"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_space_char"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
parser
.
add_argument
(
"--vis_font_path"
,
"--vis_font_path"
,
type
=
str
,
default
=
"./doc/simfang.ttf"
)
type
=
str
,
default
=
"./doc/simfang.ttf"
)
# params for text classifier
# params for text classifier
parser
.
add_argument
(
"--use_angle_cls"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_angle_cls"
,
type
=
str2bool
,
default
=
False
)
...
@@ -230,8 +228,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
...
@@ -230,8 +228,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
1
])
**
2
)
1
])
**
2
)
if
box_height
>
2
*
box_width
:
if
box_height
>
2
*
box_width
:
font_size
=
max
(
int
(
box_width
*
0.9
),
10
)
font_size
=
max
(
int
(
box_width
*
0.9
),
10
)
font
=
ImageFont
.
truetype
(
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
font_path
,
font_size
,
encoding
=
"utf-8"
)
cur_y
=
box
[
0
][
1
]
cur_y
=
box
[
0
][
1
]
for
c
in
txt
:
for
c
in
txt
:
char_size
=
font
.
getsize
(
c
)
char_size
=
font
.
getsize
(
c
)
...
@@ -240,8 +237,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
...
@@ -240,8 +237,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
cur_y
+=
char_size
[
1
]
cur_y
+=
char_size
[
1
]
else
:
else
:
font_size
=
max
(
int
(
box_height
*
0.8
),
10
)
font_size
=
max
(
int
(
box_height
*
0.8
),
10
)
font
=
ImageFont
.
truetype
(
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
font_path
,
font_size
,
encoding
=
"utf-8"
)
draw_right
.
text
(
draw_right
.
text
(
[
box
[
0
][
0
],
box
[
0
][
1
]],
txt
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
[
box
[
0
][
0
],
box
[
0
][
1
]],
txt
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
img_left
=
Image
.
blend
(
image
,
img_left
,
0.5
)
img_left
=
Image
.
blend
(
image
,
img_left
,
0.5
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录