Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
9f581156
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9f581156
编写于
1月 24, 2021
作者:
littletomatodonkey
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix data replication for multi-cards sampling
上级
8216aa9e
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
9 addition
and
6 deletion
+9
-6
ppocr/data/__init__.py
ppocr/data/__init__.py
+2
-2
ppocr/data/lmdb_dataset.py
ppocr/data/lmdb_dataset.py
+1
-1
ppocr/data/simple_dataset.py
ppocr/data/simple_dataset.py
+4
-1
tools/program.py
tools/program.py
+2
-2
未找到文件。
ppocr/data/__init__.py
浏览文件 @
9f581156
...
@@ -51,7 +51,7 @@ signal.signal(signal.SIGINT, term_mp)
...
@@ -51,7 +51,7 @@ signal.signal(signal.SIGINT, term_mp)
signal
.
signal
(
signal
.
SIGTERM
,
term_mp
)
signal
.
signal
(
signal
.
SIGTERM
,
term_mp
)
def
build_dataloader
(
config
,
mode
,
device
,
logger
):
def
build_dataloader
(
config
,
mode
,
device
,
logger
,
seed
=
None
):
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
support_dict
=
[
'SimpleDataSet'
,
'LMDBDateSet'
]
support_dict
=
[
'SimpleDataSet'
,
'LMDBDateSet'
]
...
@@ -61,7 +61,7 @@ def build_dataloader(config, mode, device, logger):
...
@@ -61,7 +61,7 @@ def build_dataloader(config, mode, device, logger):
assert
mode
in
[
'Train'
,
'Eval'
,
'Test'
assert
mode
in
[
'Train'
,
'Eval'
,
'Test'
],
"Mode should be Train, Eval or Test."
],
"Mode should be Train, Eval or Test."
dataset
=
eval
(
module_name
)(
config
,
mode
,
logger
)
dataset
=
eval
(
module_name
)(
config
,
mode
,
logger
,
seed
)
loader_config
=
config
[
mode
][
'loader'
]
loader_config
=
config
[
mode
][
'loader'
]
batch_size
=
loader_config
[
'batch_size_per_card'
]
batch_size
=
loader_config
[
'batch_size_per_card'
]
drop_last
=
loader_config
[
'drop_last'
]
drop_last
=
loader_config
[
'drop_last'
]
...
...
ppocr/data/lmdb_dataset.py
浏览文件 @
9f581156
...
@@ -21,7 +21,7 @@ from .imaug import transform, create_operators
...
@@ -21,7 +21,7 @@ from .imaug import transform, create_operators
class
LMDBDateSet
(
Dataset
):
class
LMDBDateSet
(
Dataset
):
def
__init__
(
self
,
config
,
mode
,
logger
):
def
__init__
(
self
,
config
,
mode
,
logger
,
seed
=
None
):
super
(
LMDBDateSet
,
self
).
__init__
()
super
(
LMDBDateSet
,
self
).
__init__
()
global_config
=
config
[
'Global'
]
global_config
=
config
[
'Global'
]
...
...
ppocr/data/simple_dataset.py
浏览文件 @
9f581156
...
@@ -20,7 +20,7 @@ from .imaug import transform, create_operators
...
@@ -20,7 +20,7 @@ from .imaug import transform, create_operators
class
SimpleDataSet
(
Dataset
):
class
SimpleDataSet
(
Dataset
):
def
__init__
(
self
,
config
,
mode
,
logger
):
def
__init__
(
self
,
config
,
mode
,
logger
,
seed
=
None
):
super
(
SimpleDataSet
,
self
).
__init__
()
super
(
SimpleDataSet
,
self
).
__init__
()
self
.
logger
=
logger
self
.
logger
=
logger
...
@@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
...
@@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
self
.
data_dir
=
dataset_config
[
'data_dir'
]
self
.
data_dir
=
dataset_config
[
'data_dir'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
self
.
seed
=
seed
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
)
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
)
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
...
@@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
...
@@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
for
idx
,
file
in
enumerate
(
file_list
):
for
idx
,
file
in
enumerate
(
file_list
):
with
open
(
file
,
"rb"
)
as
f
:
with
open
(
file
,
"rb"
)
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
random
.
seed
(
self
.
seed
)
lines
=
random
.
sample
(
lines
,
lines
=
random
.
sample
(
lines
,
round
(
len
(
lines
)
*
ratio_list
[
idx
]))
round
(
len
(
lines
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
lines
)
data_lines
.
extend
(
lines
)
...
@@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
...
@@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
def
shuffle_data_random
(
self
):
def
shuffle_data_random
(
self
):
if
self
.
do_shuffle
:
if
self
.
do_shuffle
:
random
.
seed
(
self
.
seed
)
random
.
shuffle
(
self
.
data_lines
)
random
.
shuffle
(
self
.
data_lines
)
return
return
...
...
tools/program.py
浏览文件 @
9f581156
...
@@ -182,8 +182,8 @@ def train(config,
...
@@ -182,8 +182,8 @@ def train(config,
start_epoch
=
1
start_epoch
=
1
for
epoch
in
range
(
start_epoch
,
epoch_num
+
1
):
for
epoch
in
range
(
start_epoch
,
epoch_num
+
1
):
if
epoch
>
0
:
train_dataloader
=
build_dataloader
(
train_dataloader
=
build_dataloader
(
config
,
'Train'
,
device
,
logger
)
config
,
'Train'
,
device
,
logger
,
seed
=
epoch
)
train_batch_cost
=
0.0
train_batch_cost
=
0.0
train_reader_cost
=
0.0
train_reader_cost
=
0.0
batch_sum
=
0
batch_sum
=
0
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录