Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
52b40f36
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
52b40f36
编写于
10月 16, 2020
作者:
Z
zhoujun
提交者:
GitHub
10月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #952 from WenmuZhou/dygraph
使用PaddleClass的resnet_vd
上级
7d3ba1f1
bdad0cef
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
528 addition
and
505 deletion
+528
-505
configs/det/det_mv3_db.yml
configs/det/det_mv3_db.yml
+7
-7
configs/det/det_r50_vd_db.yml
configs/det/det_r50_vd_db.yml
+4
-4
configs/rec/rec_mv3_none_bilstm_ctc.yml
configs/rec/rec_mv3_none_bilstm_ctc.yml
+2
-2
configs/rec/rec_mv3_none_bilstm_ctc_lmdb.yml
configs/rec/rec_mv3_none_bilstm_ctc_lmdb.yml
+2
-2
configs/rec/rec_mv3_none_none_ctc_lmdb.yml
configs/rec/rec_mv3_none_none_ctc_lmdb.yml
+105
-0
configs/rec/rec_r34_vd_none_bilstm_ctc.yml
configs/rec/rec_r34_vd_none_bilstm_ctc.yml
+3
-3
ppocr/modeling/architectures/model.py
ppocr/modeling/architectures/model.py
+8
-12
ppocr/modeling/backbones/det_resnet_vd.py
ppocr/modeling/backbones/det_resnet_vd.py
+188
-240
ppocr/modeling/backbones/rec_resnet_vd.py
ppocr/modeling/backbones/rec_resnet_vd.py
+199
-228
ppocr/modeling/necks/rnn.py
ppocr/modeling/necks/rnn.py
+1
-1
tools/train.py
tools/train.py
+9
-6
未找到文件。
configs/det/det_mv3_db.yml
浏览文件 @
52b40f36
...
...
@@ -3,7 +3,7 @@ Global:
epoch_num
:
1200
log_smooth_window
:
20
print_batch_step
:
2
save_model_dir
:
./output/
20201010
/
save_model_dir
:
./output/
db_mv3
/
save_epoch_step
:
1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
8
...
...
@@ -66,9 +66,9 @@ Metric:
TRAIN
:
dataset
:
name
:
SimpleDataSet
data_dir
:
/home/zhoujun20
/detection/
data_dir
:
.
/detection/
file_list
:
-
/home/zhoujun20
/detection/train_icdar2015_label.txt
# dataset1
-
.
/detection/train_icdar2015_label.txt
# dataset1
ratio_list
:
[
1.0
]
transforms
:
-
DecodeImage
:
# load image
...
...
@@ -103,14 +103,14 @@ TRAIN:
shuffle
:
True
drop_last
:
False
batch_size
:
16
num_workers
:
6
num_workers
:
8
EVAL
:
dataset
:
name
:
SimpleDataSet
data_dir
:
/home/zhoujun20
/detection/
data_dir
:
.
/detection/
file_list
:
-
/home/zhoujun20
/detection/test_icdar2015_label.txt
-
.
/detection/test_icdar2015_label.txt
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
...
...
@@ -130,4 +130,4 @@ EVAL:
shuffle
:
False
drop_last
:
False
batch_size
:
1
# must be 1
num_workers
:
6
\ No newline at end of file
num_workers
:
8
\ No newline at end of file
configs/det/det_r50_vd_db.yml
浏览文件 @
52b40f36
...
...
@@ -3,14 +3,14 @@ Global:
epoch_num
:
1200
log_smooth_window
:
20
print_batch_step
:
2
save_model_dir
:
./output/20201010/
save_model_dir
:
./output/2020101
5_r5
0/
save_epoch_step
:
1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
8
# if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights
:
True
cal_metric_during_train
:
False
pretrained_model
:
/home/zhoujun20/pretrain_models/
MobileNetV3_large_x0_5_pretrained
pretrained_model
:
/home/zhoujun20/pretrain_models/
ResNet50_vd_ssld_pretrained/
checkpoints
:
#./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy
save_inference_dir
:
use_visualdl
:
True
...
...
@@ -102,7 +102,7 @@ TRAIN:
shuffle
:
True
drop_last
:
False
batch_size
:
16
num_workers
:
6
num_workers
:
8
EVAL
:
dataset
:
...
...
@@ -129,4 +129,4 @@ EVAL:
shuffle
:
False
drop_last
:
False
batch_size
:
1
# must be 1
num_workers
:
6
\ No newline at end of file
num_workers
:
8
\ No newline at end of file
configs/rec/rec_mv3_none_bilstm_ctc.yml
浏览文件 @
52b40f36
...
...
@@ -84,7 +84,7 @@ TRAIN:
batch_size
:
256
shuffle
:
True
drop_last
:
True
num_workers
:
6
num_workers
:
8
EVAL
:
dataset
:
...
...
@@ -105,4 +105,4 @@ EVAL:
shuffle
:
False
drop_last
:
False
batch_size
:
256
num_workers
:
6
num_workers
:
8
configs/rec/rec_mv3_none_bilstm_ctc_lmdb.yml
浏览文件 @
52b40f36
...
...
@@ -83,7 +83,7 @@ TRAIN:
batch_size
:
256
shuffle
:
True
drop_last
:
True
num_workers
:
6
num_workers
:
8
EVAL
:
dataset
:
...
...
@@ -103,4 +103,4 @@ EVAL:
shuffle
:
False
drop_last
:
False
batch_size
:
256
num_workers
:
6
num_workers
:
8
configs/rec/rec_mv3_none_none_ctc_lmdb.yml
0 → 100644
浏览文件 @
52b40f36
Global
:
use_gpu
:
false
epoch_num
:
500
log_smooth_window
:
20
print_batch_step
:
1
save_model_dir
:
./output/rec/test/
save_epoch_step
:
500
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
1016
# if pretrained_model is saved in static mode, load_static_weights must set to True
load_static_weights
:
True
cal_metric_during_train
:
True
pretrained_model
:
checkpoints
:
#output/rec/rec_crnn/best_accuracy
save_inference_dir
:
use_visualdl
:
True
infer_img
:
doc/imgs_words/ch/word_1.jpg
# for data or label process
max_text_length
:
80
character_dict_path
:
/home/zhoujun20/rec/lmdb/dict.txt
character_type
:
'
en'
use_space_char
:
True
infer_mode
:
False
use_tps
:
False
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
learning_rate
:
name
:
Cosine
lr
:
0.0005
warmup_epoch
:
1
regularizer
:
name
:
'
L2'
factor
:
0.00001
Architecture
:
type
:
rec
algorithm
:
CRNN
Transform
:
Backbone
:
name
:
MobileNetV3
scale
:
0.5
model_name
:
small
small_stride
:
[
1
,
2
,
2
,
2
]
Neck
:
name
:
SequenceEncoder
encoder_type
:
reshape
Head
:
name
:
CTC
fc_decay
:
0.00001
Loss
:
name
:
CTCLoss
PostProcess
:
name
:
CTCLabelDecode
Metric
:
name
:
RecMetric
main_indicator
:
acc
TRAIN
:
dataset
:
name
:
LMDBDateSet
file_list
:
-
/Users/zhoujun20/Downloads/evaluation_new
# dataset1
ratio_list
:
[
0.4
,
0.6
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
CTCLabelEncode
:
# Class handling label
-
RecAug
:
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
320
]
-
keepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader将按照此顺序返回list
loader
:
batch_size
:
256
shuffle
:
True
drop_last
:
True
num_workers
:
8
EVAL
:
dataset
:
name
:
LMDBDateSet
file_list
:
-
/home/zhoujun20/rec/lmdb/val
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
CTCLabelEncode
:
# Class handling label
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
320
]
-
keepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader将按照此顺序返回list
loader
:
shuffle
:
False
drop_last
:
False
batch_size
:
256
num_workers
:
8
configs/rec/rec_r34_vd_none_bilstm_ctc.yml
浏览文件 @
52b40f36
...
...
@@ -42,7 +42,7 @@ Architecture:
Transform
:
Backbone
:
name
:
ResNet
layers
:
200
layers
:
34
Neck
:
name
:
SequenceEncoder
encoder_type
:
fc
...
...
@@ -82,7 +82,7 @@ TRAIN:
batch_size
:
256
shuffle
:
True
drop_last
:
True
num_workers
:
6
num_workers
:
8
EVAL
:
dataset
:
...
...
@@ -103,4 +103,4 @@ EVAL:
shuffle
:
False
drop_last
:
False
batch_size
:
256
num_workers
:
6
num_workers
:
8
ppocr/modeling/architectures/model.py
浏览文件 @
52b40f36
...
...
@@ -94,13 +94,11 @@ def check_static():
from
ppocr.utils.logging
import
get_logger
from
tools
import
program
config
=
program
.
load_config
(
'configs/
det/det_r50_vd_db
.yml'
)
config
=
program
.
load_config
(
'configs/
rec/rec_r34_vd_none_bilstm_ctc
.yml'
)
# import cv2
# data = cv2.imread('doc/imgs/1.jpg')
# data = normalize(data)
logger
=
get_logger
()
data
=
np
.
zeros
((
1
,
3
,
640
,
640
),
dtype
=
np
.
float32
)
np
.
random
.
seed
(
0
)
data
=
np
.
random
.
rand
(
1
,
3
,
32
,
320
).
astype
(
np
.
float32
)
paddle
.
disable_static
()
config
[
'Architecture'
][
'in_channels'
]
=
3
...
...
@@ -110,17 +108,15 @@ def check_static():
load_dygraph_pretrain
(
model
,
logger
,
'/Users/zhoujun20/Desktop/code/PaddleOCR/
db/db
'
,
'/Users/zhoujun20/Desktop/code/PaddleOCR/
cnn_ctc/cnn_ctc
'
,
load_static_weights
=
True
)
x
=
paddle
.
to_
variable
(
data
)
x
=
paddle
.
to_
tensor
(
data
)
y
=
model
(
x
)
for
y1
in
y
:
print
(
y1
.
shape
)
#
# # from matplotlib import pyplot as plt
# # plt.imshow(y.numpy())
# # plt.show()
static_out
=
np
.
load
(
'/Users/zhoujun20/Desktop/code/PaddleOCR/db/db.npy'
)
static_out
=
np
.
load
(
'/Users/zhoujun20/Desktop/code/PaddleOCR/output/conv.npy'
)
diff
=
y
.
numpy
()
-
static_out
print
(
y
.
shape
,
static_out
.
shape
,
diff
.
mean
())
...
...
ppocr/modeling/backbones/det_resnet_vd.py
浏览文件 @
52b40f36
...
...
@@ -16,143 +16,30 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
__all__
=
[
"ResNet"
]
class
ResNet
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
3
,
layers
=
50
,
**
kwargs
):
"""
the Resnet backbone network for detection module.
Args:
params(dict): the super parameters for network build
"""
super
(
ResNet
,
self
).
__init__
()
supported_layers
=
{
18
:
{
'depth'
:
[
2
,
2
,
2
,
2
],
'block_class'
:
BasicBlock
},
34
:
{
'depth'
:
[
3
,
4
,
6
,
3
],
'block_class'
:
BasicBlock
},
50
:
{
'depth'
:
[
3
,
4
,
6
,
3
],
'block_class'
:
BottleneckBlock
},
101
:
{
'depth'
:
[
3
,
4
,
23
,
3
],
'block_class'
:
BottleneckBlock
},
152
:
{
'depth'
:
[
3
,
8
,
36
,
3
],
'block_class'
:
BottleneckBlock
},
200
:
{
'depth'
:
[
3
,
12
,
48
,
3
],
'block_class'
:
BottleneckBlock
}
}
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
.
keys
(),
layers
)
is_3x3
=
True
depth
=
supported_layers
[
layers
][
'depth'
]
block_class
=
supported_layers
[
layers
][
'block_class'
]
num_filters
=
[
64
,
128
,
256
,
512
]
conv
=
[]
if
is_3x3
==
False
:
conv
.
append
(
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
64
,
kernel_size
=
7
,
stride
=
2
,
act
=
'relu'
))
else
:
conv
.
append
(
ConvBNLayer
(
in_channels
=
3
,
out_channels
=
32
,
kernel_size
=
3
,
stride
=
2
,
act
=
'relu'
,
name
=
'conv1_1'
))
conv
.
append
(
ConvBNLayer
(
in_channels
=
32
,
out_channels
=
32
,
kernel_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
'conv1_2'
))
conv
.
append
(
ConvBNLayer
(
in_channels
=
32
,
out_channels
=
64
,
kernel_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
'conv1_3'
))
self
.
conv1
=
nn
.
Sequential
(
*
conv
)
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
stages
=
[]
self
.
out_channels
=
[]
in_ch
=
64
for
block_index
in
range
(
len
(
depth
)):
block_list
=
[]
for
i
in
range
(
depth
[
block_index
]):
if
layers
>=
50
:
if
layers
in
[
101
,
152
,
200
]
and
block_index
==
2
:
if
i
==
0
:
conv_name
=
"res"
+
str
(
block_index
+
2
)
+
"a"
else
:
conv_name
=
"res"
+
str
(
block_index
+
2
)
+
"b"
+
str
(
i
)
else
:
conv_name
=
"res"
+
str
(
block_index
+
2
)
+
chr
(
97
+
i
)
else
:
conv_name
=
"res"
+
str
(
block_index
+
2
)
+
chr
(
97
+
i
)
block_list
.
append
(
block_class
(
in_channels
=
in_ch
,
out_channels
=
num_filters
[
block_index
],
stride
=
2
if
i
==
0
and
block_index
!=
0
else
1
,
if_first
=
block_index
==
i
==
0
,
name
=
conv_name
))
in_ch
=
block_list
[
-
1
].
out_channels
self
.
out_channels
.
append
(
in_ch
)
self
.
stages
.
append
(
nn
.
Sequential
(
*
block_list
))
for
i
,
stage
in
enumerate
(
self
.
stages
):
self
.
add_sublayer
(
sublayer
=
stage
,
name
=
"stage{}"
.
format
(
i
))
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
pool
(
x
)
out_list
=
[]
for
stage
in
self
.
stages
:
x
=
stage
(
x
)
out_list
.
append
(
x
)
return
out_list
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
None
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
is_vd_mode
=
False
,
act
=
None
,
name
=
None
,
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2d
(
self
.
is_vd_mode
=
is_vd_mode
self
.
_pool2d_avg
=
nn
.
AvgPool2d
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
,
ceil_mode
=
True
)
self
.
_conv
=
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
...
...
@@ -165,87 +52,32 @@ class ConvBNLayer(nn.Layer):
bn_name
=
"bn_"
+
name
else
:
bn_name
=
"bn"
+
name
[
3
:]
self
.
bn
=
nn
.
BatchNorm
(
num_channels
=
out_channels
,
self
.
_batch_norm
=
nn
.
BatchNorm
(
out_channels
,
act
=
act
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
"_scale"
),
bias_attr
=
ParamAttr
(
name
=
bn_name
+
"_offset"
),
moving_mean_name
=
bn_name
+
"_mean"
,
moving_variance_name
=
bn_name
+
"_variance"
)
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
)
def
__call__
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
return
x
def
forward
(
self
,
inputs
):
if
self
.
is_vd_mode
:
inputs
=
self
.
_pool2d_avg
(
inputs
)
y
=
self
.
_conv
(
inputs
)
y
=
self
.
_batch_norm
(
y
)
return
y
class
ConvBNLayerNew
(
nn
.
Layer
):
class
BottleneckBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
stride
,
shortcut
=
True
,
if_first
=
False
,
name
=
None
):
super
(
ConvBNLayerNew
,
self
).
__init__
()
self
.
pool
=
nn
.
AvgPool2d
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
,
ceil_mode
=
True
)
self
.
conv
=
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
1
,
padding
=
(
kernel_size
-
1
)
//
2
,
groups
=
groups
,
weight_attr
=
ParamAttr
(
name
=
name
+
"_weights"
),
bias_attr
=
False
)
if
name
==
"conv1"
:
bn_name
=
"bn_"
+
name
else
:
bn_name
=
"bn"
+
name
[
3
:]
self
.
bn
=
nn
.
BatchNorm
(
num_channels
=
out_channels
,
act
=
act
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
"_scale"
),
bias_attr
=
ParamAttr
(
name
=
bn_name
+
"_offset"
),
moving_mean_name
=
bn_name
+
"_mean"
,
moving_variance_name
=
bn_name
+
"_variance"
)
def
__call__
(
self
,
x
):
x
=
self
.
pool
(
x
)
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
return
x
class
ShortCut
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
name
,
if_first
=
False
):
super
(
ShortCut
,
self
).
__init__
()
self
.
use_conv
=
True
if
in_channels
!=
out_channels
or
stride
!=
1
:
if
if_first
:
self
.
conv
=
ConvBNLayer
(
in_channels
,
out_channels
,
1
,
stride
,
name
=
name
)
else
:
self
.
conv
=
ConvBNLayerNew
(
in_channels
,
out_channels
,
1
,
stride
,
name
=
name
)
elif
if_first
:
self
.
conv
=
ConvBNLayer
(
in_channels
,
out_channels
,
1
,
stride
,
name
=
name
)
else
:
self
.
use_conv
=
False
def
forward
(
self
,
x
):
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
return
x
class
BottleneckBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
name
,
if_first
):
super
(
BottleneckBlock
,
self
).
__init__
()
self
.
conv0
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
...
...
@@ -266,32 +98,46 @@ class BottleneckBlock(nn.Layer):
act
=
None
,
name
=
name
+
"_branch2c"
)
self
.
short
=
ShortCut
(
in_channels
=
in_channels
,
out_channels
=
out_channels
*
4
,
stride
=
stride
,
if_first
=
if_first
,
name
=
name
+
"_branch1"
)
self
.
out_channels
=
out_channels
*
4
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
*
4
,
kernel_size
=
1
,
stride
=
1
,
is_vd_mode
=
False
if
if_first
else
True
,
name
=
name
+
"_branch1"
)
self
.
shortcut
=
shortcut
def
forward
(
self
,
inputs
):
y
=
self
.
conv0
(
inputs
)
conv1
=
self
.
conv1
(
y
)
conv2
=
self
.
conv2
(
conv1
)
def
forward
(
self
,
x
):
y
=
self
.
conv0
(
x
)
y
=
self
.
conv1
(
y
)
y
=
self
.
conv2
(
y
)
y
=
y
+
self
.
short
(
x
)
y
=
F
.
relu
(
y
)
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
elementwise_add
(
x
=
short
,
y
=
conv2
,
act
=
'relu'
)
return
y
class
BasicBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
name
,
if_first
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
shortcut
=
True
,
if_first
=
False
,
name
=
None
):
super
(
BasicBlock
,
self
).
__init__
()
self
.
stride
=
stride
self
.
conv0
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
act
=
'relu'
,
stride
=
stride
,
act
=
'relu'
,
name
=
name
+
"_branch2a"
)
self
.
conv1
=
ConvBNLayer
(
in_channels
=
out_channels
,
...
...
@@ -299,31 +145,133 @@ class BasicBlock(nn.Layer):
kernel_size
=
3
,
act
=
None
,
name
=
name
+
"_branch2b"
)
self
.
short
=
ShortCut
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
stride
=
stride
,
if_first
=
if_first
,
name
=
name
+
"_branch1"
)
self
.
out_channels
=
out_channels
def
forward
(
self
,
x
):
y
=
self
.
conv0
(
x
)
y
=
self
.
conv1
(
y
)
y
=
y
+
self
.
short
(
x
)
return
F
.
relu
(
y
)
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
stride
=
1
,
is_vd_mode
=
False
if
if_first
else
True
,
name
=
name
+
"_branch1"
)
self
.
shortcut
=
shortcut
def
forward
(
self
,
inputs
):
y
=
self
.
conv0
(
inputs
)
conv1
=
self
.
conv1
(
y
)
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
elementwise_add
(
x
=
short
,
y
=
conv1
,
act
=
'relu'
)
return
y
if
__name__
==
'__main__'
:
import
paddle
paddle
.
disable_static
()
x
=
paddle
.
zeros
([
1
,
3
,
640
,
640
])
x
=
paddle
.
to_variable
(
x
)
print
(
x
.
shape
)
net
=
ResNet
(
layers
=
18
)
y
=
net
(
x
)
class
ResNet
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
3
,
layers
=
50
,
**
kwargs
):
super
(
ResNet
,
self
).
__init__
()
self
.
layers
=
layers
supported_layers
=
[
18
,
34
,
50
,
101
,
152
,
200
]
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
,
layers
)
if
layers
==
18
:
depth
=
[
2
,
2
,
2
,
2
]
elif
layers
==
34
or
layers
==
50
:
depth
=
[
3
,
4
,
6
,
3
]
elif
layers
==
101
:
depth
=
[
3
,
4
,
23
,
3
]
elif
layers
==
152
:
depth
=
[
3
,
8
,
36
,
3
]
elif
layers
==
200
:
depth
=
[
3
,
12
,
48
,
3
]
num_channels
=
[
64
,
256
,
512
,
1024
]
if
layers
>=
50
else
[
64
,
64
,
128
,
256
]
num_filters
=
[
64
,
128
,
256
,
512
]
self
.
conv1_1
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
32
,
kernel_size
=
3
,
stride
=
2
,
act
=
'relu'
,
name
=
"conv1_1"
)
self
.
conv1_2
=
ConvBNLayer
(
in_channels
=
32
,
out_channels
=
32
,
kernel_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
"conv1_2"
)
self
.
conv1_3
=
ConvBNLayer
(
in_channels
=
32
,
out_channels
=
64
,
kernel_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
"conv1_3"
)
self
.
pool2d_max
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
stages
=
[]
self
.
out_channels
=
[]
if
layers
>=
50
:
for
block
in
range
(
len
(
depth
)):
block_list
=
[]
shortcut
=
False
for
i
in
range
(
depth
[
block
]):
if
layers
in
[
101
,
152
]
and
block
==
2
:
if
i
==
0
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"a"
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"b"
+
str
(
i
)
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
bottleneck_block
=
self
.
add_sublayer
(
'bb_%d_%d'
%
(
block
,
i
),
BottleneckBlock
(
in_channels
=
num_channels
[
block
]
if
i
==
0
else
num_filters
[
block
]
*
4
,
out_channels
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
shortcut
=
shortcut
,
if_first
=
block
==
i
==
0
,
name
=
conv_name
))
shortcut
=
True
block_list
.
append
(
bottleneck_block
)
self
.
out_channels
.
append
(
num_filters
[
block
]
*
4
)
self
.
stages
.
append
(
nn
.
Sequential
(
*
block_list
))
else
:
for
block
in
range
(
len
(
depth
)):
block_list
=
[]
shortcut
=
False
for
i
in
range
(
depth
[
block
]):
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
basic_block
=
self
.
add_sublayer
(
'bb_%d_%d'
%
(
block
,
i
),
BasicBlock
(
in_channels
=
num_channels
[
block
]
if
i
==
0
else
num_filters
[
block
],
out_channels
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
shortcut
=
shortcut
,
if_first
=
block
==
i
==
0
,
name
=
conv_name
))
shortcut
=
True
block_list
.
append
(
basic_block
)
self
.
out_channels
.
append
(
num_filters
[
block
])
self
.
stages
.
append
(
nn
.
Sequential
(
*
block_list
))
for
stage
in
y
:
print
(
stage
.
shape
)
# paddle.save(net.state_dict(),'1.pth')
def
forward
(
self
,
inputs
):
y
=
self
.
conv1_1
(
inputs
)
y
=
self
.
conv1_2
(
y
)
y
=
self
.
conv1_3
(
y
)
y
=
self
.
pool2d_max
(
y
)
out
=
[]
for
block
in
self
.
stages
:
y
=
block
(
y
)
out
.
append
(
y
)
return
out
ppocr/modeling/backbones/rec_resnet_vd.py
浏览文件 @
52b40f36
...
...
@@ -16,144 +16,34 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
nn
,
ParamAttr
from
paddle.nn
import
functional
as
F
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
__all__
=
[
"ResNet"
]
class
ResNet
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
3
,
layers
=
34
):
super
(
ResNet
,
self
).
__init__
()
supported_layers
=
{
18
:
{
'depth'
:
[
2
,
2
,
2
,
2
],
'block_class'
:
BasicBlock
},
34
:
{
'depth'
:
[
3
,
4
,
6
,
3
],
'block_class'
:
BasicBlock
},
50
:
{
'depth'
:
[
3
,
4
,
6
,
3
],
'block_class'
:
BottleneckBlock
},
101
:
{
'depth'
:
[
3
,
4
,
23
,
3
],
'block_class'
:
BottleneckBlock
},
152
:
{
'depth'
:
[
3
,
8
,
36
,
3
],
'block_class'
:
BottleneckBlock
},
200
:
{
'depth'
:
[
3
,
12
,
48
,
3
],
'block_class'
:
BottleneckBlock
}
}
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
.
keys
(),
layers
)
is_3x3
=
True
num_filters
=
[
64
,
128
,
256
,
512
]
depth
=
supported_layers
[
layers
][
'depth'
]
block_class
=
supported_layers
[
layers
][
'block_class'
]
conv
=
[]
if
is_3x3
==
False
:
conv
.
append
(
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
64
,
kernel_size
=
7
,
stride
=
1
,
act
=
'relu'
))
else
:
conv
.
append
(
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
32
,
kernel_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
'conv1_1'
))
conv
.
append
(
ConvBNLayer
(
in_channels
=
32
,
out_channels
=
32
,
kernel_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
'conv1_2'
))
conv
.
append
(
ConvBNLayer
(
in_channels
=
32
,
out_channels
=
64
,
kernel_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
'conv1_3'
))
self
.
conv1
=
nn
.
Sequential
(
*
conv
)
self
.
pool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
)
block_list
=
[]
in_ch
=
64
for
block_index
in
range
(
len
(
depth
)):
for
i
in
range
(
depth
[
block_index
]):
if
layers
>=
50
:
if
layers
in
[
101
,
152
,
200
]
and
block_index
==
2
:
if
i
==
0
:
conv_name
=
"res"
+
str
(
block_index
+
2
)
+
"a"
else
:
conv_name
=
"res"
+
str
(
block_index
+
2
)
+
"b"
+
str
(
i
)
else
:
conv_name
=
"res"
+
str
(
block_index
+
2
)
+
chr
(
97
+
i
)
else
:
conv_name
=
"res"
+
str
(
block_index
+
2
)
+
chr
(
97
+
i
)
if
i
==
0
and
block_index
!=
0
:
stride
=
(
2
,
1
)
else
:
stride
=
(
1
,
1
)
block_list
.
append
(
block_class
(
in_channels
=
in_ch
,
out_channels
=
num_filters
[
block_index
],
stride
=
stride
,
if_first
=
block_index
==
i
==
0
,
name
=
conv_name
))
in_ch
=
block_list
[
-
1
].
out_channels
self
.
block_list
=
nn
.
Sequential
(
*
block_list
)
self
.
add_sublayer
(
sublayer
=
self
.
block_list
,
name
=
"block_list"
)
self
.
pool_out
=
nn
.
MaxPool2d
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
self
.
out_channels
=
in_ch
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
pool
(
x
)
x
=
self
.
block_list
(
x
)
x
=
self
.
pool_out
(
x
)
return
x
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
None
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
is_vd_mode
=
False
,
act
=
None
,
name
=
None
,
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2d
(
self
.
is_vd_mode
=
is_vd_mode
self
.
_pool2d_avg
=
nn
.
AvgPool2d
(
kernel_size
=
stride
,
stride
=
stride
,
padding
=
0
,
ceil_mode
=
True
)
self
.
_conv
=
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
stride
=
1
if
is_vd_mode
else
stride
,
padding
=
(
kernel_size
-
1
)
//
2
,
groups
=
groups
,
weight_attr
=
ParamAttr
(
name
=
name
+
"_weights"
),
...
...
@@ -162,88 +52,32 @@ class ConvBNLayer(nn.Layer):
bn_name
=
"bn_"
+
name
else
:
bn_name
=
"bn"
+
name
[
3
:]
self
.
bn
=
nn
.
BatchNorm
(
num_channels
=
out_channels
,
self
.
_batch_norm
=
nn
.
BatchNorm
(
out_channels
,
act
=
act
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
"_scale"
),
bias_attr
=
ParamAttr
(
name
=
bn_name
+
"_offset"
),
moving_mean_name
=
bn_name
+
"_mean"
,
moving_variance_name
=
bn_name
+
"_variance"
)
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
)
def
__call__
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
return
x
def
forward
(
self
,
inputs
):
if
self
.
is_vd_mode
:
inputs
=
self
.
_pool2d_avg
(
inputs
)
y
=
self
.
_conv
(
inputs
)
y
=
self
.
_batch_norm
(
y
)
return
y
class
ConvBNLayerNew
(
nn
.
Layer
):
class
BottleneckBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
stride
,
shortcut
=
True
,
if_first
=
False
,
name
=
None
):
super
(
ConvBNLayerNew
,
self
).
__init__
()
self
.
pool
=
nn
.
AvgPool2d
(
kernel_size
=
stride
,
stride
=
stride
,
padding
=
0
,
ceil_mode
=
True
)
self
.
conv
=
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
1
,
padding
=
(
kernel_size
-
1
)
//
2
,
groups
=
groups
,
weight_attr
=
ParamAttr
(
name
=
name
+
"_weights"
),
bias_attr
=
False
)
if
name
==
"conv1"
:
bn_name
=
"bn_"
+
name
else
:
bn_name
=
"bn"
+
name
[
3
:]
self
.
bn
=
nn
.
BatchNorm
(
num_channels
=
out_channels
,
act
=
act
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
"_scale"
),
bias_attr
=
ParamAttr
(
name
=
bn_name
+
"_offset"
),
moving_mean_name
=
bn_name
+
"_mean"
,
moving_variance_name
=
bn_name
+
"_variance"
)
def
__call__
(
self
,
x
):
x
=
self
.
pool
(
x
)
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
return
x
class
ShortCut
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
name
,
if_first
=
False
):
super
(
ShortCut
,
self
).
__init__
()
self
.
use_conv
=
True
if
in_channels
!=
out_channels
or
stride
[
0
]
!=
1
:
if
if_first
:
self
.
conv
=
ConvBNLayer
(
in_channels
,
out_channels
,
1
,
stride
,
name
=
name
)
else
:
self
.
conv
=
ConvBNLayerNew
(
in_channels
,
out_channels
,
1
,
stride
,
name
=
name
)
elif
if_first
:
self
.
conv
=
ConvBNLayer
(
in_channels
,
out_channels
,
1
,
stride
,
name
=
name
)
else
:
self
.
use_conv
=
False
def
forward
(
self
,
x
):
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
return
x
class
BottleneckBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
name
,
if_first
):
super
(
BottleneckBlock
,
self
).
__init__
()
self
.
conv0
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
...
...
@@ -264,32 +98,47 @@ class BottleneckBlock(nn.Layer):
act
=
None
,
name
=
name
+
"_branch2c"
)
self
.
short
=
ShortCut
(
in_channels
=
in_channels
,
out_channels
=
out_channels
*
4
,
stride
=
stride
,
if_first
=
if_first
,
name
=
name
+
"_branch1"
)
self
.
out_channels
=
out_channels
*
4
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
*
4
,
kernel_size
=
1
,
stride
=
stride
,
is_vd_mode
=
not
if_first
and
stride
[
0
]
!=
1
,
name
=
name
+
"_branch1"
)
self
.
shortcut
=
shortcut
def
forward
(
self
,
inputs
):
y
=
self
.
conv0
(
inputs
)
conv1
=
self
.
conv1
(
y
)
conv2
=
self
.
conv2
(
conv1
)
def
forward
(
self
,
x
):
y
=
self
.
conv0
(
x
)
y
=
self
.
conv1
(
y
)
y
=
self
.
conv2
(
y
)
y
=
y
+
self
.
short
(
x
)
y
=
F
.
relu
(
y
)
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
elementwise_add
(
x
=
short
,
y
=
conv2
,
act
=
'relu'
)
return
y
class
BasicBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
name
,
if_first
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
shortcut
=
True
,
if_first
=
False
,
name
=
None
):
super
(
BasicBlock
,
self
).
__init__
()
self
.
stride
=
stride
self
.
conv0
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
act
=
'relu'
,
stride
=
stride
,
act
=
'relu'
,
name
=
name
+
"_branch2a"
)
self
.
conv1
=
ConvBNLayer
(
in_channels
=
out_channels
,
...
...
@@ -297,16 +146,138 @@ class BasicBlock(nn.Layer):
kernel_size
=
3
,
act
=
None
,
name
=
name
+
"_branch2b"
)
self
.
short
=
ShortCut
(
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
stride
=
stride
,
is_vd_mode
=
not
if_first
and
stride
[
0
]
!=
1
,
name
=
name
+
"_branch1"
)
self
.
shortcut
=
shortcut
def
forward
(
self
,
inputs
):
y
=
self
.
conv0
(
inputs
)
conv1
=
self
.
conv1
(
y
)
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
paddle
.
elementwise_add
(
x
=
short
,
y
=
conv1
,
act
=
'relu'
)
return
y
class
ResNet
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
3
,
layers
=
50
,
**
kwargs
):
super
(
ResNet
,
self
).
__init__
()
self
.
layers
=
layers
supported_layers
=
[
18
,
34
,
50
,
101
,
152
,
200
]
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
,
layers
)
if
layers
==
18
:
depth
=
[
2
,
2
,
2
,
2
]
elif
layers
==
34
or
layers
==
50
:
depth
=
[
3
,
4
,
6
,
3
]
elif
layers
==
101
:
depth
=
[
3
,
4
,
23
,
3
]
elif
layers
==
152
:
depth
=
[
3
,
8
,
36
,
3
]
elif
layers
==
200
:
depth
=
[
3
,
12
,
48
,
3
]
num_channels
=
[
64
,
256
,
512
,
1024
]
if
layers
>=
50
else
[
64
,
64
,
128
,
256
]
num_filters
=
[
64
,
128
,
256
,
512
]
self
.
conv1_1
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
stride
=
stride
,
if_first
=
if_first
,
name
=
name
+
"_branch1"
)
self
.
out_channels
=
out_channels
out_channels
=
32
,
kernel_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
"conv1_1"
)
self
.
conv1_2
=
ConvBNLayer
(
in_channels
=
32
,
out_channels
=
32
,
kernel_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
"conv1_2"
)
self
.
conv1_3
=
ConvBNLayer
(
in_channels
=
32
,
out_channels
=
64
,
kernel_size
=
3
,
stride
=
1
,
act
=
'relu'
,
name
=
"conv1_3"
)
self
.
pool2d_max
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
self
.
block_list
=
[]
if
layers
>=
50
:
for
block
in
range
(
len
(
depth
)):
shortcut
=
False
for
i
in
range
(
depth
[
block
]):
if
layers
in
[
101
,
152
,
200
]
and
block
==
2
:
if
i
==
0
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"a"
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"b"
+
str
(
i
)
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
def
forward
(
self
,
x
):
y
=
self
.
conv0
(
x
)
y
=
self
.
conv1
(
y
)
y
=
y
+
self
.
short
(
x
)
return
F
.
relu
(
y
)
if
i
==
0
and
block
!=
0
:
stride
=
(
2
,
1
)
else
:
stride
=
(
1
,
1
)
bottleneck_block
=
self
.
add_sublayer
(
'bb_%d_%d'
%
(
block
,
i
),
BottleneckBlock
(
in_channels
=
num_channels
[
block
]
if
i
==
0
else
num_filters
[
block
]
*
4
,
out_channels
=
num_filters
[
block
],
stride
=
stride
,
shortcut
=
shortcut
,
if_first
=
block
==
i
==
0
,
name
=
conv_name
))
shortcut
=
True
self
.
block_list
.
append
(
bottleneck_block
)
self
.
out_channels
=
num_filters
[
block
]
else
:
for
block
in
range
(
len
(
depth
)):
shortcut
=
False
for
i
in
range
(
depth
[
block
]):
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
if
i
==
0
and
block
!=
0
:
stride
=
(
2
,
1
)
else
:
stride
=
(
1
,
1
)
basic_block
=
self
.
add_sublayer
(
'bb_%d_%d'
%
(
block
,
i
),
BasicBlock
(
in_channels
=
num_channels
[
block
]
if
i
==
0
else
num_filters
[
block
],
out_channels
=
num_filters
[
block
],
stride
=
stride
,
shortcut
=
shortcut
,
if_first
=
block
==
i
==
0
,
name
=
conv_name
))
shortcut
=
True
self
.
block_list
.
append
(
basic_block
)
self
.
out_channels
=
num_filters
[
block
]
self
.
out_pool
=
nn
.
MaxPool2d
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
inputs
):
y
=
self
.
conv1_1
(
inputs
)
y
=
self
.
conv1_2
(
y
)
y
=
self
.
conv1_3
(
y
)
y
=
self
.
pool2d_max
(
y
)
for
block
in
self
.
block_list
:
y
=
block
(
y
)
y
=
self
.
out_pool
(
y
)
return
y
ppocr/modeling/necks/rnn.py
浏览文件 @
52b40f36
...
...
@@ -116,7 +116,7 @@ class EncoderWithFC(nn.Layer):
class
SequenceEncoder
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
encoder_type
,
hidden_size
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
encoder_type
,
hidden_size
=
48
,
**
kwargs
):
super
(
SequenceEncoder
,
self
).
__init__
()
self
.
encoder_reshape
=
EncoderWithReshape
(
in_channels
)
self
.
out_channels
=
self
.
encoder_reshape
.
out_channels
...
...
tools/train.py
浏览文件 @
52b40f36
...
...
@@ -88,20 +88,23 @@ def main(config, device, logger, vdl_writer):
best_model_dict
,
logger
,
vdl_writer
)
def
test_reader
(
config
,
place
,
logger
):
train_loader
=
build_dataloader
(
config
[
'TRAIN'
],
place
)
def
test_reader
(
config
,
place
,
logger
,
global_config
):
train_loader
,
_
=
build_dataloader
(
config
[
'TRAIN'
],
place
,
global_config
=
global_config
)
import
time
starttime
=
time
.
time
()
count
=
0
try
:
for
data
in
train_loader
()
:
for
data
in
train_loader
:
count
+=
1
if
count
%
1
==
0
:
batch_time
=
time
.
time
()
-
starttime
starttime
=
time
.
time
()
logger
.
info
(
"reader: {}, {}, {}"
.
format
(
count
,
len
(
data
),
batch_time
))
logger
.
info
(
"reader: {}, {}, {}"
.
format
(
count
,
len
(
data
[
0
]
),
batch_time
))
except
Exception
as
e
:
import
traceback
traceback
.
print_exc
()
logger
.
info
(
e
)
logger
.
info
(
"finish reader: {}, Success!"
.
format
(
count
))
...
...
@@ -130,7 +133,7 @@ def dis_main():
device
))
main
(
config
,
device
,
logger
,
vdl_writer
)
# test_reader(config,
place, logger
)
# test_reader(config,
device, logger, config['Global']
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录