Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
dc6e724e
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
dc6e724e
编写于
11月 09, 2020
作者:
D
dyning
提交者:
GitHub
11月 09, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1139 from WenmuZhou/dygraph_rc
新增功能
上级
b863528d
b2004fe5
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
84 addition
and
74 deletion
+84
-74
ppocr/modeling/backbones/det_resnet_vd.py
ppocr/modeling/backbones/det_resnet_vd.py
+8
-5
ppocr/modeling/backbones/rec_resnet_vd.py
ppocr/modeling/backbones/rec_resnet_vd.py
+9
-6
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+5
-2
ppocr/postprocess/db_postprocess_torch.py
ppocr/postprocess/db_postprocess_torch.py
+5
-2
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+6
-6
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+4
-4
tools/export_model.py
tools/export_model.py
+15
-9
tools/infer/predict_det.py
tools/infer/predict_det.py
+19
-9
tools/infer_det.py
tools/infer_det.py
+7
-16
tools/infer_rec.py
tools/infer_rec.py
+6
-15
未找到文件。
ppocr/modeling/backbones/det_resnet_vd.py
浏览文件 @
dc6e724e
...
...
@@ -19,6 +19,7 @@ from __future__ import print_function
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
__all__
=
[
"ResNet"
]
...
...
@@ -37,9 +38,9 @@ class ConvBNLayer(nn.Layer):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
is_vd_mode
=
is_vd_mode
self
.
_pool2d_avg
=
nn
.
AvgPool2
d
(
self
.
_pool2d_avg
=
nn
.
AvgPool2
D
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
,
ceil_mode
=
True
)
self
.
_conv
=
nn
.
Conv2
d
(
self
.
_conv
=
nn
.
Conv2
D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
...
...
@@ -118,7 +119,8 @@ class BottleneckBlock(nn.Layer):
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
elementwise_add
(
x
=
short
,
y
=
conv2
,
act
=
'relu'
)
y
=
paddle
.
add
(
x
=
short
,
y
=
conv2
)
y
=
F
.
relu
(
y
)
return
y
...
...
@@ -165,7 +167,8 @@ class BasicBlock(nn.Layer):
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
elementwise_add
(
x
=
short
,
y
=
conv1
,
act
=
'relu'
)
y
=
paddle
.
add
(
x
=
short
,
y
=
conv1
)
y
=
F
.
relu
(
y
)
return
y
...
...
@@ -214,7 +217,7 @@ class ResNet(nn.Layer):
stride
=
1
,
act
=
'relu'
,
name
=
"conv1_3"
)
self
.
pool2d_max
=
nn
.
MaxPool2
d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
pool2d_max
=
nn
.
MaxPool2
D
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
stages
=
[]
self
.
out_channels
=
[]
...
...
ppocr/modeling/backbones/rec_resnet_vd.py
浏览文件 @
dc6e724e
...
...
@@ -19,6 +19,7 @@ from __future__ import print_function
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
__all__
=
[
"ResNet"
]
...
...
@@ -37,9 +38,9 @@ class ConvBNLayer(nn.Layer):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
is_vd_mode
=
is_vd_mode
self
.
_pool2d_avg
=
nn
.
AvgPool2
d
(
self
.
_pool2d_avg
=
nn
.
AvgPool2
D
(
kernel_size
=
stride
,
stride
=
stride
,
padding
=
0
,
ceil_mode
=
True
)
self
.
_conv
=
nn
.
Conv2
d
(
self
.
_conv
=
nn
.
Conv2
D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
...
...
@@ -119,7 +120,8 @@ class BottleneckBlock(nn.Layer):
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
elementwise_add
(
x
=
short
,
y
=
conv2
,
act
=
'relu'
)
y
=
paddle
.
add
(
x
=
short
,
y
=
conv2
)
y
=
F
.
relu
(
y
)
return
y
...
...
@@ -166,7 +168,8 @@ class BasicBlock(nn.Layer):
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
elementwise_add
(
x
=
short
,
y
=
conv1
,
act
=
'relu'
)
y
=
paddle
.
add
(
x
=
short
,
y
=
conv1
)
y
=
F
.
relu
(
y
)
return
y
...
...
@@ -215,7 +218,7 @@ class ResNet(nn.Layer):
stride
=
1
,
act
=
'relu'
,
name
=
"conv1_3"
)
self
.
pool2d_max
=
nn
.
MaxPool2
d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
pool2d_max
=
nn
.
MaxPool2
D
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
block_list
=
[]
if
layers
>=
50
:
...
...
@@ -270,7 +273,7 @@ class ResNet(nn.Layer):
shortcut
=
True
self
.
block_list
.
append
(
basic_block
)
self
.
out_channels
=
num_filters
[
block
]
self
.
out_pool
=
nn
.
MaxPool2
d
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
self
.
out_pool
=
nn
.
MaxPool2
D
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
inputs
):
y
=
self
.
conv1_1
(
inputs
)
...
...
ppocr/postprocess/db_postprocess.py
浏览文件 @
dc6e724e
...
...
@@ -18,6 +18,7 @@ from __future__ import print_function
import
numpy
as
np
import
cv2
import
paddle
from
shapely.geometry
import
Polygon
import
pyclipper
...
...
@@ -130,7 +131,9 @@ class DBPostProcess(object):
return
cv2
.
mean
(
bitmap
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
],
mask
)[
0
]
def
__call__
(
self
,
pred
,
shape_list
):
pred
=
pred
.
numpy
()[:,
0
,
:,
:]
if
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
pred
.
numpy
()
pred
=
pred
[:,
0
,
:,
:]
segmentation
=
pred
>
self
.
thresh
boxes_batch
=
[]
...
...
@@ -140,4 +143,4 @@ class DBPostProcess(object):
pred
[
batch_index
],
segmentation
[
batch_index
],
width
,
height
)
boxes_batch
.
append
({
'points'
:
boxes
})
return
boxes_batch
return
boxes_batch
\ No newline at end of file
ppocr/postprocess/db_postprocess_torch.py
浏览文件 @
dc6e724e
import
cv2
import
paddle
import
numpy
as
np
import
pyclipper
from
shapely.geometry
import
Polygon
...
...
@@ -23,7 +24,9 @@ class DBPostProcess():
pred:
binary: text region segmentation map, with shape (N, 1,H, W)
'''
pred
=
pred
.
numpy
()[:,
0
,
:,
:]
if
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
pred
.
numpy
()
pred
=
pred
[:,
0
,
:,
:]
segmentation
=
self
.
binarize
(
pred
)
batch_out
=
[]
for
batch_index
in
range
(
pred
.
shape
[
0
]):
...
...
@@ -130,4 +133,4 @@ class DBPostProcess():
box
[:,
0
]
=
box
[:,
0
]
-
xmin
box
[:,
1
]
=
box
[:,
1
]
-
ymin
cv2
.
fillPoly
(
mask
,
box
.
reshape
(
1
,
-
1
,
2
).
astype
(
np
.
int32
),
1
)
return
cv2
.
mean
(
bitmap
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
],
mask
)[
0
]
return
cv2
.
mean
(
bitmap
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
],
mask
)[
0
]
\ No newline at end of file
ppocr/postprocess/rec_postprocess.py
浏览文件 @
dc6e724e
...
...
@@ -100,9 +100,10 @@ class CTCLabelDecode(BaseRecLabelDecode):
character_type
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
# out = self.decode_preds(preds)
preds
=
F
.
softmax
(
preds
,
axis
=
2
).
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
...
...
@@ -116,19 +117,18 @@ class CTCLabelDecode(BaseRecLabelDecode):
return
dict_character
def
decode_preds
(
self
,
preds
):
probs
=
F
.
softmax
(
preds
,
axis
=
2
).
numpy
()
probs_ind
=
np
.
argmax
(
probs
,
axis
=
2
)
probs_ind
=
np
.
argmax
(
preds
,
axis
=
2
)
B
,
N
,
_
=
preds
.
shape
l
=
np
.
ones
(
B
).
astype
(
np
.
int64
)
*
N
length
=
paddle
.
to_
variable
(
l
)
length
=
paddle
.
to_
tensor
(
l
)
out
=
paddle
.
fluid
.
layers
.
ctc_greedy_decoder
(
preds
,
0
,
length
)
batch_res
=
[
x
[:
idx
[
0
]]
for
x
,
idx
in
zip
(
out
[
0
].
numpy
(),
out
[
1
].
numpy
())
]
result_list
=
[]
for
sample_idx
,
ind
,
prob
in
zip
(
batch_res
,
probs_ind
,
pr
ob
s
):
for
sample_idx
,
ind
,
prob
in
zip
(
batch_res
,
probs_ind
,
pr
ed
s
):
char_list
=
[
self
.
character
[
idx
]
for
idx
in
sample_idx
]
valid_ind
=
np
.
where
(
ind
!=
0
)[
0
]
if
len
(
valid_ind
)
==
0
:
...
...
@@ -172,4 +172,4 @@ class AttnLabelDecode(BaseRecLabelDecode):
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
return
idx
\ No newline at end of file
ppocr/utils/save_load.py
浏览文件 @
dc6e724e
...
...
@@ -68,11 +68,11 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
param_state_dict
[
key
]
=
pre_state_dict
[
weight_name
]
else
:
param_state_dict
[
key
]
=
model_dict
[
key
]
model
.
set_dict
(
param_state_dict
)
model
.
set_
state_
dict
(
param_state_dict
)
return
param_state_dict
,
optim_state_dict
=
paddle
.
load
(
path
)
model
.
set_dict
(
param_state_dict
)
param_state_dict
=
paddle
.
load
(
path
+
'.pdparams'
)
model
.
set_
state_
dict
(
param_state_dict
)
return
...
...
@@ -91,7 +91,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
"Given dir {}.pdopt not exist."
.
format
(
checkpoints
)
para_dict
=
paddle
.
load
(
checkpoints
+
'.pdparams'
)
opti_dict
=
paddle
.
load
(
checkpoints
+
'.pdopt'
)
model
.
set_dict
(
para_dict
)
model
.
set_
state_
dict
(
para_dict
)
if
optimizer
is
not
None
:
optimizer
.
set_state_dict
(
opti_dict
)
...
...
tools/export_model.py
浏览文件 @
dc6e724e
...
...
@@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
import
argparse
import
paddle
...
...
@@ -20,14 +27,11 @@ from paddle.jit import to_static
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.logging
import
get_logger
from
tools.program
import
load_config
from
tools.program
import
merge_config
def
parse_args
():
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-c"
,
"--config"
,
help
=
"configuration file to use"
)
parser
.
add_argument
(
...
...
@@ -43,7 +47,7 @@ class Model(paddle.nn.Layer):
# Please modify the 'shape' according to actual needs
@
to_static
(
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
32
,
None
],
dtype
=
'float32'
)
shape
=
[
None
,
3
,
640
,
640
],
dtype
=
'float32'
)
])
def
forward
(
self
,
inputs
):
x
=
self
.
pre_model
(
inputs
)
...
...
@@ -53,14 +57,13 @@ class Model(paddle.nn.Layer):
def
main
():
FLAGS
=
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
logger
=
get_logger
()
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
config
[
'Global'
])
# build model
#for rec algorithm
#
for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
...
...
@@ -69,7 +72,10 @@ def main():
model
.
eval
()
model
=
Model
(
model
)
paddle
.
jit
.
save
(
model
,
FLAGS
.
output_path
)
save_path
=
'{}/{}'
.
format
(
FLAGS
.
output_path
,
config
[
'Architecture'
][
'model_type'
])
paddle
.
jit
.
save
(
model
,
save_path
)
logger
.
info
(
'inference model is saved to {}'
.
format
(
save_path
))
if
__name__
==
"__main__"
:
...
...
tools/infer/predict_det.py
浏览文件 @
dc6e724e
...
...
@@ -22,7 +22,6 @@ import cv2
import
numpy
as
np
import
time
import
sys
import
paddle
import
tools.infer.utility
as
utility
...
...
@@ -39,7 +38,7 @@ class TextDetector(object):
postprocess_params
=
{}
if
self
.
det_algorithm
==
"DB"
:
pre_process_list
=
[{
'ResizeForTest'
:
{
'
Det
ResizeForTest'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
}
...
...
@@ -53,7 +52,7 @@ class TextDetector(object):
},
{
'ToCHWImage'
:
None
},
{
'
k
eepKeys'
:
{
'
K
eepKeys'
:
{
'keep_keys'
:
[
'image'
,
'shape'
]
}
}]
...
...
@@ -68,8 +67,9 @@ class TextDetector(object):
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
=
paddle
.
jit
.
load
(
args
.
det_model_dir
)
self
.
predictor
.
eval
()
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
utility
.
create_predictor
(
args
,
'det'
,
logger
)
# paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
def
order_points_clockwise
(
self
,
pts
):
"""
...
...
@@ -133,11 +133,23 @@ class TextDetector(object):
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
shape_list
=
np
.
expand_dims
(
shape_list
,
axis
=
0
)
img
=
img
.
copy
()
starttime
=
time
.
time
()
preds
=
self
.
predictor
(
img
)
if
self
.
use_zero_copy_run
:
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
zero_copy_run
()
else
:
im
=
paddle
.
fluid
.
core
.
PaddleTensor
(
img
)
self
.
predictor
.
run
([
im
])
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
outputs
[
0
]
# preds = self.predictor(img)
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
dt_boxes
=
post_result
[
0
][
'points'
]
dt_boxes
=
self
.
filter_tag_det_res
(
dt_boxes
,
ori_im
.
shape
)
elapse
=
time
.
time
()
-
starttime
...
...
@@ -146,8 +158,6 @@ class TextDetector(object):
if
__name__
==
"__main__"
:
args
=
utility
.
parse_args
()
place
=
paddle
.
CPUPlace
()
paddle
.
disable_static
(
place
)
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
logger
=
get_logger
()
...
...
tools/infer_det.py
浏览文件 @
dc6e724e
...
...
@@ -29,12 +29,11 @@ import cv2
import
json
import
paddle
from
ppocr.utils.logging
import
get_logger
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling
import
build_model
from
ppocr.modeling
.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.utility
import
print_dict
,
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
...
...
@@ -67,11 +66,11 @@ def main():
# create data ops
transforms
=
[]
for
op
in
config
[
'E
VAL
'
][
'dataset'
][
'transforms'
]:
for
op
in
config
[
'E
val
'
][
'dataset'
][
'transforms'
]:
op_name
=
list
(
op
)[
0
]
if
'Label'
in
op_name
:
continue
elif
op_name
==
'
k
eepKeys'
:
elif
op_name
==
'
K
eepKeys'
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
,
'shape'
]
transforms
.
append
(
op
)
...
...
@@ -92,8 +91,7 @@ def main():
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
shape_list
=
np
.
expand_dims
(
batch
[
1
],
axis
=
0
)
images
=
paddle
.
to_variable
(
images
)
print
(
images
.
shape
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
,
shape_list
)
boxes
=
post_result
[
0
][
'points'
]
...
...
@@ -109,14 +107,7 @@ def main():
draw_det_res
(
boxes
,
config
,
src_img
,
file
)
logger
.
info
(
"success!"
)
# save inference model
# paddle.jit.save(model, 'output/model')
if
__name__
==
'__main__'
:
place
,
config
=
program
.
preprocess
()
paddle
.
disable_static
(
place
)
logger
=
get_logger
()
print_dict
(
config
,
logger
)
main
()
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
()
\ No newline at end of file
tools/infer_rec.py
浏览文件 @
dc6e724e
...
...
@@ -27,12 +27,11 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
import
paddle
from
ppocr.utils.logging
import
get_logger
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling
import
build_model
from
ppocr.modeling
.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.utility
import
print_dict
,
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
...
...
@@ -54,13 +53,13 @@ def main():
# create data ops
transforms
=
[]
for
op
in
config
[
'E
VAL
'
][
'dataset'
][
'transforms'
]:
for
op
in
config
[
'E
val
'
][
'dataset'
][
'transforms'
]:
op_name
=
list
(
op
)[
0
]
if
'Label'
in
op_name
:
continue
elif
op_name
in
[
'RecResizeImg'
]:
op
[
op_name
][
'infer_mode'
]
=
True
elif
op_name
==
'
k
eepKeys'
:
elif
op_name
==
'
K
eepKeys'
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
transforms
.
append
(
op
)
global_config
[
'infer_mode'
]
=
True
...
...
@@ -75,22 +74,14 @@ def main():
batch
=
transform
(
data
,
ops
)
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
paddle
.
to_
variable
(
images
)
images
=
paddle
.
to_
tensor
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
)
for
rec_reuslt
in
post_result
:
logger
.
info
(
'
\t
result: {}'
.
format
(
rec_reuslt
))
logger
.
info
(
"success!"
)
# save inference model
# currently, paddle.jit.to_static not support rnn
# paddle.jit.save(model, 'output/rec/model')
if
__name__
==
'__main__'
:
place
,
config
=
program
.
preprocess
()
paddle
.
disable_static
(
place
)
logger
=
get_logger
()
print_dict
(
config
,
logger
)
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录