Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
6832ca02
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6832ca02
编写于
8月 15, 2020
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update config
上级
09d8cb6d
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
197 addition
and
132 deletion
+197
-132
configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml
configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml
+3
-2
ppocr/data/rec/dataset_traversal.py
ppocr/data/rec/dataset_traversal.py
+25
-24
ppocr/modeling/architectures/rec_model.py
ppocr/modeling/architectures/rec_model.py
+79
-20
ppocr/modeling/heads/self_attention/model.py
ppocr/modeling/heads/self_attention/model.py
+66
-73
tools/eval_utils/eval_rec_utils.py
tools/eval_utils/eval_rec_utils.py
+13
-8
tools/program.py
tools/program.py
+10
-5
train_data
train_data
+1
-0
未找到文件。
configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml
浏览文件 @
6832ca02
...
...
@@ -17,10 +17,11 @@ Global:
average_window
:
0.15
max_average_window
:
15625
min_average_window
:
10000
reader_yml
:
./configs/rec/rec_
srn
_reader.yml
reader_yml
:
./configs/rec/rec_
benchmark
_reader.yml
pretrain_weights
:
checkpoints
:
save_inference_dir
:
infer_img
:
Architecture
:
function
:
ppocr.modeling.architectures.rec_model,RecModel
...
...
ppocr/data/rec/dataset_traversal.py
浏览文件 @
6832ca02
...
...
@@ -118,15 +118,14 @@ class LMDBReader(object):
image_file_list
=
get_image_file_list
(
self
.
infer_img
)
for
single_img
in
image_file_list
:
img
=
cv2
.
imread
(
single_img
)
if
img
.
shape
[
-
1
]
==
1
or
len
(
list
(
img
.
shape
))
==
2
:
if
img
.
shape
[
-
1
]
==
1
or
len
(
list
(
img
.
shape
))
==
2
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
if
self
.
loss_type
==
'srn'
:
norm_img
=
process_image_srn
(
img
=
img
,
image_shape
=
self
.
image_shape
,
num_heads
=
self
.
num_heads
,
max_text_length
=
self
.
max_text_length
)
max_text_length
=
self
.
max_text_length
)
else
:
norm_img
=
process_image
(
img
=
img
,
...
...
@@ -135,20 +134,20 @@ class LMDBReader(object):
tps
=
self
.
use_tps
,
infer_mode
=
True
)
yield
norm_img
elif
self
.
mode
==
'test
'
:
image_file_list
=
get_image_file_list
(
self
.
infer_img
)
for
single_img
in
image_file_list
:
img
=
cv2
.
imread
(
single_img
)
if
img
.
shape
[
-
1
]
==
1
or
len
(
list
(
img
.
shape
))
==
2
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
norm_img
=
process_image
(
img
=
img
,
image_shape
=
self
.
image_shape
,
char_ops
=
self
.
char_ops
,
tps
=
self
.
use_tps
,
infer_mode
=
True
)
yield
norm_img
#elif self.mode == 'eval
':
#
image_file_list = get_image_file_list(self.infer_img)
#
for single_img in image_file_list:
#
img = cv2.imread(single_img)
#
if img.shape[-1]==1 or len(list(img.shape))==2:
#
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
#
norm_img = process_image(
#
img=img,
#
image_shape=self.image_shape,
#
char_ops=self.char_ops,
#
tps=self.use_tps,
#
infer_mode=True
#
)
#
yield norm_img
else
:
lmdb_sets
=
self
.
load_hierarchical_lmdb_dataset
()
if
process_id
==
0
:
...
...
@@ -169,14 +168,15 @@ class LMDBReader(object):
img
,
label
=
sample_info
outs
=
[]
if
self
.
loss_type
==
"srn"
:
outs
=
process_image_srn
(
img
,
self
.
image_shape
,
self
.
num_heads
,
self
.
max_text_length
,
label
,
self
.
char_ops
,
self
.
loss_type
)
outs
=
process_image_srn
(
img
,
self
.
image_shape
,
self
.
num_heads
,
self
.
max_text_length
,
label
,
self
.
char_ops
,
self
.
loss_type
)
else
:
outs
=
process_image
(
img
,
self
.
image_shape
,
label
,
self
.
char_ops
,
self
.
loss_type
,
self
.
max_text_length
)
outs
=
process_image
(
img
,
self
.
image_shape
,
label
,
self
.
char_ops
,
self
.
loss_type
,
self
.
max_text_length
)
if
outs
is
None
:
continue
yield
outs
...
...
@@ -184,6 +184,7 @@ class LMDBReader(object):
if
finish_read_num
==
len
(
lmdb_sets
):
break
self
.
close_lmdb_dataset
(
lmdb_sets
)
def
batch_iter_reader
():
batch_outs
=
[]
for
outs
in
sample_iter_reader
():
...
...
ppocr/modeling/architectures/rec_model.py
浏览文件 @
6832ca02
...
...
@@ -79,17 +79,45 @@ class RecModel(object):
feed_list
=
[
image
,
label_in
,
label_out
]
labels
=
{
'label_in'
:
label_in
,
'label_out'
:
label_out
}
elif
self
.
loss_type
==
"srn"
:
encoder_word_pos
=
fluid
.
data
(
name
=
"encoder_word_pos"
,
shape
=
[
-
1
,
int
((
image_shape
[
-
2
]
/
8
)
*
(
image_shape
[
-
1
]
/
8
)),
1
],
dtype
=
"int64"
)
gsrm_word_pos
=
fluid
.
data
(
name
=
"gsrm_word_pos"
,
shape
=
[
-
1
,
self
.
max_text_length
,
1
],
dtype
=
"int64"
)
gsrm_slf_attn_bias1
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias1"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
])
gsrm_slf_attn_bias2
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias2"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
])
lbl_weight
=
fluid
.
layers
.
data
(
name
=
"lbl_weight"
,
shape
=
[
-
1
,
1
],
dtype
=
'int64'
)
encoder_word_pos
=
fluid
.
data
(
name
=
"encoder_word_pos"
,
shape
=
[
-
1
,
int
((
image_shape
[
-
2
]
/
8
)
*
(
image_shape
[
-
1
]
/
8
)),
1
],
dtype
=
"int64"
)
gsrm_word_pos
=
fluid
.
data
(
name
=
"gsrm_word_pos"
,
shape
=
[
-
1
,
self
.
max_text_length
,
1
],
dtype
=
"int64"
)
gsrm_slf_attn_bias1
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias1"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
])
gsrm_slf_attn_bias2
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias2"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
])
lbl_weight
=
fluid
.
layers
.
data
(
name
=
"lbl_weight"
,
shape
=
[
-
1
,
1
],
dtype
=
'int64'
)
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
-
1
,
1
],
dtype
=
'int32'
,
lod_level
=
1
)
feed_list
=
[
image
,
label
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
,
lbl_weight
]
labels
=
{
'label'
:
label
,
'encoder_word_pos'
:
encoder_word_pos
,
'gsrm_word_pos'
:
gsrm_word_pos
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2
,
'lbl_weight'
:
lbl_weight
}
feed_list
=
[
image
,
label
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
,
lbl_weight
]
labels
=
{
'label'
:
label
,
'encoder_word_pos'
:
encoder_word_pos
,
'gsrm_word_pos'
:
gsrm_word_pos
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2
,
'lbl_weight'
:
lbl_weight
}
else
:
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int32'
,
lod_level
=
1
)
...
...
@@ -114,13 +142,39 @@ class RecModel(object):
image_shape
=
deepcopy
(
self
.
image_shape
)
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
if
self
.
loss_type
==
"srn"
:
encoder_word_pos
=
fluid
.
data
(
name
=
"encoder_word_pos"
,
shape
=
[
-
1
,
int
((
image_shape
[
-
2
]
/
8
)
*
(
image_shape
[
-
1
]
/
8
)),
1
],
dtype
=
"int64"
)
gsrm_word_pos
=
fluid
.
data
(
name
=
"gsrm_word_pos"
,
shape
=
[
-
1
,
self
.
max_text_length
,
1
],
dtype
=
"int64"
)
gsrm_slf_attn_bias1
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias1"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
])
gsrm_slf_attn_bias2
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias2"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
])
feed_list
=
[
image
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
labels
=
{
'encoder_word_pos'
:
encoder_word_pos
,
'gsrm_word_pos'
:
gsrm_word_pos
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2
}
encoder_word_pos
=
fluid
.
data
(
name
=
"encoder_word_pos"
,
shape
=
[
-
1
,
int
((
image_shape
[
-
2
]
/
8
)
*
(
image_shape
[
-
1
]
/
8
)),
1
],
dtype
=
"int64"
)
gsrm_word_pos
=
fluid
.
data
(
name
=
"gsrm_word_pos"
,
shape
=
[
-
1
,
self
.
max_text_length
,
1
],
dtype
=
"int64"
)
gsrm_slf_attn_bias1
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias1"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
])
gsrm_slf_attn_bias2
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias2"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
])
feed_list
=
[
image
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
labels
=
{
'encoder_word_pos'
:
encoder_word_pos
,
'gsrm_word_pos'
:
gsrm_word_pos
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2
}
return
image
,
labels
,
loader
def
__call__
(
self
,
mode
):
...
...
@@ -140,8 +194,13 @@ class RecModel(object):
label
=
labels
[
'label'
]
if
self
.
loss_type
==
'srn'
:
total_loss
,
img_loss
,
word_loss
=
self
.
loss
(
predicts
,
labels
)
outputs
=
{
'total_loss'
:
total_loss
,
'img_loss'
:
img_loss
,
'word_loss'
:
word_loss
,
'decoded_out'
:
decoded_out
,
'label'
:
label
}
outputs
=
{
'total_loss'
:
total_loss
,
'img_loss'
:
img_loss
,
'word_loss'
:
word_loss
,
'decoded_out'
:
decoded_out
,
'label'
:
label
}
else
:
outputs
=
{
'total_loss'
:
loss
,
'decoded_out'
:
\
decoded_out
,
'label'
:
label
}
...
...
ppocr/modeling/heads/self_attention/model.py
浏览文件 @
6832ca02
...
...
@@ -4,8 +4,9 @@ import numpy as np
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
from
.desc
import
*
from
.config
import
ModelHyperParams
,
TrainTaskConfig
# Set seed for CE
dropout_seed
=
None
def
wrap_layer_with_block
(
layer
,
block_idx
):
"""
...
...
@@ -269,7 +270,8 @@ pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer
=
pre_post_process_layer
def
prepare_encoder
(
src_word
,
#[b,t,c]
def
prepare_encoder
(
src_word
,
#[b,t,c]
src_pos
,
src_vocab_size
,
src_emb_dim
,
...
...
@@ -284,8 +286,8 @@ def prepare_encoder(src_word,#[b,t,c]
This module is used at the bottom of the encoder stacks.
"""
src_word_emb
=
src_word
#layers.concat(res,axis=1)
src_word_emb
=
layers
.
cast
(
src_word_emb
,
'float32'
)
src_word_emb
=
src_word
#layers.concat(res,axis=1)
src_word_emb
=
layers
.
cast
(
src_word_emb
,
'float32'
)
# print("src_word_emb",src_word_emb)
src_word_emb
=
layers
.
scale
(
x
=
src_word_emb
,
scale
=
src_emb_dim
**
0.5
)
...
...
@@ -323,7 +325,7 @@ def prepare_decoder(src_word,
name
=
word_emb_param_name
,
initializer
=
fluid
.
initializer
.
Normal
(
0.
,
src_emb_dim
**-
0.5
)))
# print("target_word_emb",src_word_emb)
src_word_emb
=
layers
.
scale
(
x
=
src_word_emb
,
scale
=
src_emb_dim
**
0.5
)
src_word_emb
=
layers
.
scale
(
x
=
src_word_emb
,
scale
=
src_emb_dim
**
0.5
)
src_pos_enc
=
layers
.
embedding
(
src_pos
,
size
=
[
src_max_len
,
src_emb_dim
],
...
...
@@ -335,6 +337,7 @@ def prepare_decoder(src_word,
enc_input
,
dropout_prob
=
dropout_rate
,
seed
=
dropout_seed
,
is_test
=
False
)
if
dropout_rate
else
enc_input
# prepare_encoder = partial(
# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0])
# prepare_decoder = partial(
...
...
@@ -595,21 +598,9 @@ def transformer(src_vocab_size,
weights
=
all_inputs
[
-
1
]
enc_output
=
wrap_encoder
(
src_vocab_size
,
ModelHyperParams
.
src_seq_len
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
enc_inputs
)
src_vocab_size
,
64
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
enc_inputs
)
predict
=
wrap_decoder
(
trg_vocab_size
,
...
...
@@ -676,8 +667,8 @@ def wrap_encoder_forFeature(src_vocab_size,
conv_features
,
src_pos
,
src_slf_attn_bias
=
make_all_inputs
(
encoder_data_input_fields
)
else
:
conv_features
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
#
b
,
t
,
c
=
conv_features
.
shape
conv_features
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
#
b
,
t
,
c
=
conv_features
.
shape
#"""
# insert cnn
#"""
...
...
@@ -718,7 +709,7 @@ def wrap_encoder_forFeature(src_vocab_size,
max_length
,
prepostprocess_dropout
,
bos_idx
=
bos_idx
,
word_emb_param_name
=
word_emb_param_names
[
0
]
)
word_emb_param_name
=
"src_word_emb_table"
)
enc_output
=
encoder
(
enc_input
,
...
...
@@ -736,6 +727,7 @@ def wrap_encoder_forFeature(src_vocab_size,
postprocess_cmd
,
)
return
enc_output
def
wrap_encoder
(
src_vocab_size
,
max_length
,
n_layer
,
...
...
@@ -762,7 +754,7 @@ def wrap_encoder(src_vocab_size,
src_word
,
src_pos
,
src_slf_attn_bias
=
make_all_inputs
(
encoder_data_input_fields
)
else
:
src_word
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
#
src_word
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
#
#"""
# insert cnn
#"""
...
...
@@ -802,7 +794,7 @@ def wrap_encoder(src_vocab_size,
max_length
,
prepostprocess_dropout
,
bos_idx
=
bos_idx
,
word_emb_param_name
=
word_emb_param_names
[
0
]
)
word_emb_param_name
=
"src_word_emb_table"
)
enc_output
=
encoder
(
enc_input
,
...
...
@@ -858,8 +850,8 @@ def wrap_decoder(trg_vocab_size,
max_length
,
prepostprocess_dropout
,
bos_idx
=
bos_idx
,
word_emb_param_name
=
word_emb_param_names
[
0
]
if
weight_sharing
else
word_emb_param_names
[
1
]
)
word_emb_param_name
=
"src_word_emb_table"
if
weight_sharing
else
"trg_word_emb_table"
)
dec_output
=
decoder
(
dec_input
,
enc_output
,
...
...
@@ -886,7 +878,7 @@ def wrap_decoder(trg_vocab_size,
predict
=
layers
.
matmul
(
x
=
dec_output
,
y
=
fluid
.
default_main_program
().
global_block
().
var
(
word_emb_param_names
[
0
]
),
"trg_word_emb_table"
),
transpose_y
=
True
)
else
:
predict
=
layers
.
fc
(
input
=
dec_output
,
...
...
@@ -931,12 +923,13 @@ def fast_decode(src_vocab_size,
enc_inputs_len
=
len
(
encoder_data_input_fields
)
dec_inputs_len
=
len
(
fast_decoder_data_input_fields
)
enc_inputs
=
all_inputs
[
0
:
enc_inputs_len
]
#enc_inputs tensor
dec_inputs
=
all_inputs
[
enc_inputs_len
:
enc_inputs_len
+
dec_inputs_len
]
#dec_inputs tensor
enc_inputs
=
all_inputs
[
0
:
enc_inputs_len
]
#enc_inputs tensor
dec_inputs
=
all_inputs
[
enc_inputs_len
:
enc_inputs_len
+
dec_inputs_len
]
#dec_inputs tensor
enc_output
=
wrap_encoder
(
src_vocab_size
,
ModelHyperParams
.
src_seq_len
,
##to do !!!!!????
64
,
##to do !!!!!????
n_layer
,
n_head
,
d_key
,
...
...
tools/eval_utils/eval_rec_utils.py
浏览文件 @
6832ca02
...
...
@@ -75,7 +75,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
preds_lod
=
outs
[
0
].
lod
()[
0
]
labels
,
labels_lod
=
convert_rec_label_to_lod
(
label_list
)
acc
,
acc_num
,
sample_num
=
cal_predicts_accuracy
(
char_ops
,
preds
,
preds_lod
,
labels
,
labels_lod
,
is_remove_duplicate
)
char_ops
,
preds
,
preds_lod
,
labels
,
labels_lod
,
is_remove_duplicate
)
else
:
encoder_word_pos_list
=
[]
gsrm_word_pos_list
=
[]
...
...
@@ -89,10 +90,14 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
img_list
=
np
.
concatenate
(
img_list
,
axis
=
0
)
label_list
=
np
.
concatenate
(
label_list
,
axis
=
0
)
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
gsrm_slf_attn_bias1_list
,
axis
=
0
).
astype
(
np
.
float32
)
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
,
axis
=
0
).
astype
(
np
.
float32
)
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
gsrm_slf_attn_bias1_list
,
axis
=
0
).
astype
(
np
.
float32
)
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
,
axis
=
0
).
astype
(
np
.
float32
)
labels
=
label_list
...
...
@@ -108,7 +113,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
total_acc_num
+=
acc_num
total_sample_num
+=
sample_num
logger
.
info
(
"eval batch id: {}, acc: {}"
.
format
(
total_batch_num
,
acc
))
#
logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
total_batch_num
+=
1
avg_acc
=
total_acc_num
*
1.0
/
total_sample_num
metrics
=
{
'avg_acc'
:
avg_acc
,
"total_acc_num"
:
total_acc_num
,
\
...
...
tools/program.py
浏览文件 @
6832ca02
...
...
@@ -34,6 +34,7 @@ from ppocr.utils.save_load import save_model
import
numpy
as
np
from
ppocr.utils.character
import
cal_predicts_accuracy
,
cal_predicts_accuracy_srn
,
CharacterOps
class
ArgsParser
(
ArgumentParser
):
def
__init__
(
self
):
super
(
ArgsParser
,
self
).
__init__
(
...
...
@@ -196,10 +197,13 @@ def build(config, main_prog, startup_prog, mode):
if
config
[
'Global'
][
"loss_type"
]
==
'srn'
:
model_average
=
fluid
.
optimizer
.
ModelAverage
(
config
[
'Global'
][
'average_window'
],
min_average_window
=
config
[
'Global'
][
'min_average_window'
],
max_average_window
=
config
[
'Global'
][
'max_average_window'
])
min_average_window
=
config
[
'Global'
][
'min_average_window'
],
max_average_window
=
config
[
'Global'
][
'max_average_window'
])
return
(
dataloader
,
fetch_name_list
,
fetch_varname_list
,
opt_loss_name
,
model_average
)
return
(
dataloader
,
fetch_name_list
,
fetch_varname_list
,
opt_loss_name
,
model_average
)
def
build_export
(
config
,
main_prog
,
startup_prog
):
...
...
@@ -398,6 +402,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
save_model
(
train_info_dict
[
'train_program'
],
save_path
)
return
def
preprocess
():
FLAGS
=
ArgsParser
().
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
...
...
@@ -409,8 +414,8 @@ def preprocess():
check_gpu
(
use_gpu
)
alg
=
config
[
'Global'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]:
assert
alg
in
[
'EAST'
,
'DB'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
]
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
]:
config
[
'Global'
][
'char_ops'
]
=
CharacterOps
(
config
[
'Global'
])
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
...
...
train_data
0 → 120000
浏览文件 @
6832ca02
/workspace/PaddleOCR/train_data/
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录