Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
6335b0d8
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6335b0d8
编写于
2月 24, 2021
作者:
D
Double_V
提交者:
GitHub
2月 24, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'dygraph' into fix2013
上级
a898305c
3ce97f18
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
118 addition
and
49 deletion
+118
-49
StyleText/engine/predictors.py
StyleText/engine/predictors.py
+24
-1
StyleText/engine/synthesisers.py
StyleText/engine/synthesisers.py
+13
-7
StyleText/engine/text_drawers.py
StyleText/engine/text_drawers.py
+44
-17
deploy/hubserving/ocr_det/params.py
deploy/hubserving/ocr_det/params.py
+2
-1
deploy/hubserving/ocr_system/params.py
deploy/hubserving/ocr_system/params.py
+2
-1
paddleocr.py
paddleocr.py
+4
-2
ppocr/losses/det_basic_loss.py
ppocr/losses/det_basic_loss.py
+1
-1
ppocr/modeling/heads/rec_att_head.py
ppocr/modeling/heads/rec_att_head.py
+3
-3
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+25
-16
未找到文件。
StyleText/engine/predictors.py
浏览文件 @
6335b0d8
...
...
@@ -38,7 +38,15 @@ class StyleTextRecPredictor(object):
self
.
std
=
config
[
"Predictor"
][
"std"
]
self
.
expand_result
=
config
[
"Predictor"
][
"expand_result"
]
def
predict
(
self
,
style_input
,
text_input
):
def
reshape_to_same_height
(
self
,
img_list
):
h
=
img_list
[
0
].
shape
[
0
]
for
idx
in
range
(
1
,
len
(
img_list
)):
new_w
=
round
(
1.0
*
img_list
[
idx
].
shape
[
1
]
/
img_list
[
idx
].
shape
[
0
]
*
h
)
img_list
[
idx
]
=
cv2
.
resize
(
img_list
[
idx
],
(
new_w
,
h
))
return
img_list
def
predict_single_image
(
self
,
style_input
,
text_input
):
style_input
=
self
.
rep_style_input
(
style_input
,
text_input
)
tensor_style_input
=
self
.
preprocess
(
style_input
)
tensor_text_input
=
self
.
preprocess
(
text_input
)
...
...
@@ -64,6 +72,21 @@ class StyleTextRecPredictor(object):
"fake_bg"
:
fake_bg
,
}
def
predict
(
self
,
style_input
,
text_input_list
):
if
not
isinstance
(
text_input_list
,
(
tuple
,
list
)):
return
self
.
predict_single_image
(
style_input
,
text_input_list
)
synth_result_list
=
[]
for
text_input
in
text_input_list
:
synth_result
=
self
.
predict_single_image
(
style_input
,
text_input
)
synth_result_list
.
append
(
synth_result
)
for
key
in
synth_result
:
res
=
[
r
[
key
]
for
r
in
synth_result_list
]
res
=
self
.
reshape_to_same_height
(
res
)
synth_result
[
key
]
=
np
.
concatenate
(
res
,
axis
=
1
)
return
synth_result
def
preprocess
(
self
,
img
):
img
=
(
img
.
astype
(
'float32'
)
*
self
.
scale
-
self
.
mean
)
/
self
.
std
img_height
,
img_width
,
channel
=
img
.
shape
...
...
StyleText/engine/synthesisers.py
浏览文件 @
6335b0d8
...
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
numpy
as
np
import
cv2
from
utils.config
import
ArgsParser
,
load_config
,
override_config
from
utils.logging
import
get_logger
...
...
@@ -36,8 +38,9 @@ class ImageSynthesiser(object):
self
.
predictor
=
getattr
(
predictors
,
predictor_method
)(
self
.
config
)
def
synth_image
(
self
,
corpus
,
style_input
,
language
=
"en"
):
corpus
,
text_input
=
self
.
text_drawer
.
draw_text
(
corpus
,
language
)
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input
)
corpus_list
,
text_input_list
=
self
.
text_drawer
.
draw_text
(
corpus
,
language
,
style_input_width
=
style_input
.
shape
[
1
])
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input_list
)
return
synth_result
...
...
@@ -59,12 +62,15 @@ class DatasetSynthesiser(ImageSynthesiser):
for
i
in
range
(
self
.
output_num
):
style_data
=
self
.
style_sampler
.
sample
()
style_input
=
style_data
[
"image"
]
corpus_language
,
text_input_label
=
self
.
corpus_generator
.
generate
(
)
text_input_label
,
text_input
=
self
.
text_drawer
.
draw_text
(
text_input_label
,
corpus_language
)
corpus_language
,
text_input_label
=
self
.
corpus_generator
.
generate
()
text_input_label_list
,
text_input_list
=
self
.
text_drawer
.
draw_text
(
text_input_label
,
corpus_language
,
style_input_width
=
style_input
.
shape
[
1
])
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input
)
text_input_label
=
""
.
join
(
text_input_label_list
)
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input_list
)
fake_fusion
=
synth_result
[
"fake_fusion"
]
self
.
writer
.
save_image
(
fake_fusion
,
text_input_label
)
self
.
writer
.
save_label
()
...
...
StyleText/engine/text_drawers.py
浏览文件 @
6335b0d8
from
PIL
import
Image
,
ImageDraw
,
ImageFont
import
numpy
as
np
import
cv2
from
utils.logging
import
get_logger
...
...
@@ -28,7 +29,11 @@ class StdTextDrawer(object):
else
:
return
int
((
self
.
height
-
4
)
**
2
/
font_height
)
def
draw_text
(
self
,
corpus
,
language
=
"en"
,
crop
=
True
):
def
draw_text
(
self
,
corpus
,
language
=
"en"
,
crop
=
True
,
style_input_width
=
None
):
if
language
not
in
self
.
support_languages
:
self
.
logger
.
warning
(
"language {} not supported, use en instead."
.
format
(
language
))
...
...
@@ -37,21 +42,43 @@ class StdTextDrawer(object):
width
=
min
(
self
.
max_width
,
len
(
corpus
)
*
self
.
height
)
+
4
else
:
width
=
len
(
corpus
)
*
self
.
height
+
4
if
style_input_width
is
not
None
:
width
=
min
(
width
,
style_input_width
)
corpus_list
=
[]
text_input_list
=
[]
while
len
(
corpus
)
!=
0
:
bg
=
Image
.
new
(
"RGB"
,
(
width
,
self
.
height
),
color
=
(
127
,
127
,
127
))
draw
=
ImageDraw
.
Draw
(
bg
)
char_x
=
2
font
=
self
.
font_dict
[
language
]
for
i
,
char_i
in
enumerate
(
corpus
):
i
=
0
while
i
<
len
(
corpus
):
char_i
=
corpus
[
i
]
char_size
=
font
.
getsize
(
char_i
)[
0
]
# split when char_x exceeds char size and index is not 0 (at least 1 char should be wroten on the image)
if
char_x
+
char_size
>=
width
and
i
!=
0
:
text_input
=
np
.
array
(
bg
).
astype
(
np
.
uint8
)
text_input
=
text_input
[:,
0
:
char_x
,
:]
corpus_list
.
append
(
corpus
[
0
:
i
])
text_input_list
.
append
(
text_input
)
corpus
=
corpus
[
i
:]
break
draw
.
text
((
char_x
,
2
),
char_i
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
char_x
+=
char_size
if
char_x
>=
width
:
corpus
=
corpus
[
0
:
i
+
1
]
self
.
logger
.
warning
(
"corpus length exceed limit: {}"
.
format
(
corpus
))
break
i
+=
1
# the whole text is shorter than style input
if
i
==
len
(
corpus
):
text_input
=
np
.
array
(
bg
).
astype
(
np
.
uint8
)
text_input
=
text_input
[:,
0
:
char_x
,
:]
return
corpus
,
text_input
corpus_list
.
append
(
corpus
[
0
:
i
])
text_input_list
.
append
(
text_input
)
corpus
=
corpus
[
i
:]
break
return
corpus_list
,
text_input_list
deploy/hubserving/ocr_det/params.py
浏览文件 @
6335b0d8
...
...
@@ -20,7 +20,8 @@ def read_params():
#DB parmas
cfg
.
det_db_thresh
=
0.3
cfg
.
det_db_box_thresh
=
0.5
cfg
.
det_db_unclip_ratio
=
2.0
cfg
.
det_db_unclip_ratio
=
1.6
cfg
.
use_dilation
=
False
# #EAST parmas
# cfg.det_east_score_thresh = 0.8
...
...
deploy/hubserving/ocr_system/params.py
浏览文件 @
6335b0d8
...
...
@@ -20,7 +20,8 @@ def read_params():
#DB parmas
cfg
.
det_db_thresh
=
0.3
cfg
.
det_db_box_thresh
=
0.5
cfg
.
det_db_unclip_ratio
=
2.0
cfg
.
det_db_unclip_ratio
=
1.6
cfg
.
use_dilation
=
False
#EAST parmas
cfg
.
det_east_score_thresh
=
0.8
...
...
paddleocr.py
浏览文件 @
6335b0d8
...
...
@@ -146,7 +146,8 @@ def parse_args(mMain=True, add_help=True):
# DB parmas
parser
.
add_argument
(
"--det_db_thresh"
,
type
=
float
,
default
=
0.3
)
parser
.
add_argument
(
"--det_db_box_thresh"
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
"--det_db_unclip_ratio"
,
type
=
float
,
default
=
2.0
)
parser
.
add_argument
(
"--det_db_unclip_ratio"
,
type
=
float
,
default
=
1.6
)
parser
.
add_argument
(
"--use_dilation"
,
type
=
bool
,
default
=
False
)
# EAST parmas
parser
.
add_argument
(
"--det_east_score_thresh"
,
type
=
float
,
default
=
0.8
)
...
...
@@ -193,7 +194,8 @@ def parse_args(mMain=True, add_help=True):
det_limit_type
=
'max'
,
det_db_thresh
=
0.3
,
det_db_box_thresh
=
0.5
,
det_db_unclip_ratio
=
2.0
,
det_db_unclip_ratio
=
1.6
,
use_dilation
=
False
,
det_east_score_thresh
=
0.8
,
det_east_cover_thresh
=
0.1
,
det_east_nms_thresh
=
0.2
,
...
...
ppocr/losses/det_basic_loss.py
浏览文件 @
6335b0d8
...
...
@@ -200,6 +200,6 @@ def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
i
,
:,
:],
ohem_ratio
))
selected_masks
=
np
.
concatenate
(
selected_masks
,
0
)
selected_masks
=
paddle
.
to_
variable
(
selected_masks
)
selected_masks
=
paddle
.
to_
tensor
(
selected_masks
)
return
selected_masks
ppocr/modeling/heads/rec_att_head.py
浏览文件 @
6335b0d8
...
...
@@ -57,6 +57,9 @@ class AttentionHead(nn.Layer):
else
:
targets
=
paddle
.
zeros
(
shape
=
[
batch_size
],
dtype
=
"int32"
)
probs
=
None
char_onehots
=
None
outputs
=
None
alpha
=
None
for
i
in
range
(
num_steps
):
char_onehots
=
self
.
_char_to_onehot
(
...
...
@@ -146,9 +149,6 @@ class AttentionLSTM(nn.Layer):
else
:
targets
=
paddle
.
zeros
(
shape
=
[
batch_size
],
dtype
=
"int32"
)
probs
=
None
char_onehots
=
None
outputs
=
None
alpha
=
None
for
i
in
range
(
num_steps
):
char_onehots
=
self
.
_char_to_onehot
(
...
...
tools/infer/predict_rec.py
浏览文件 @
6335b0d8
...
...
@@ -248,9 +248,11 @@ class TextRecognizer(object):
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
text_recognizer
=
TextRecognizer
(
args
)
total_run_time
=
0.0
total_images_num
=
0
valid_image_file_list
=
[]
img_list
=
[]
for
i
mage_file
in
image_file_list
:
for
i
dx
,
image_file
in
enumerate
(
image_file_list
)
:
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
...
...
@@ -259,8 +261,11 @@ def main(args):
continue
valid_image_file_list
.
append
(
image_file
)
img_list
.
append
(
img
)
if
len
(
img_list
)
>=
args
.
rec_batch_num
or
idx
==
len
(
image_file_list
)
-
1
:
try
:
rec_res
,
predict_time
=
text_recognizer
(
img_list
)
total_run_time
+=
predict_time
except
:
logger
.
info
(
traceback
.
format_exc
())
logger
.
info
(
...
...
@@ -268,13 +273,17 @@ def main(args):
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq
\n
"
"If your model has tps module: "
"TPS does not support variable shape.
\n
"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
)
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
)
exit
()
for
ino
in
range
(
len
(
img_list
)):
logger
.
info
(
"Predicts of {}:{}"
.
format
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
logger
.
info
(
"Predicts of {}:{}"
.
format
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
total_images_num
+=
len
(
valid_image_file_list
)
valid_image_file_list
=
[]
img_list
=
[]
logger
.
info
(
"Total predict time for {} images, cost: {:.3f}"
.
format
(
len
(
img_list
),
predict
_time
))
total_images_num
,
total_run
_time
))
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录