Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PLSC
提交
cececbbf
P
PLSC
项目概览
PaddlePaddle
/
PLSC
通知
10
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
5
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PLSC
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
5
Issue
5
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
cececbbf
编写于
2月 06, 2020
作者:
D
danleifeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add reader data_format
上级
b85e4841
变更
3
显示空白变更内容
内联
并排
Showing
3 changed files
with
37 additions
and
18 deletions
+37
-18
plsc/entry.py
plsc/entry.py
+3
-2
plsc/utils/base64_reader.py
plsc/utils/base64_reader.py
+17
-8
plsc/utils/jpeg_reader.py
plsc/utils/jpeg_reader.py
+17
-8
未找到文件。
plsc/entry.py
浏览文件 @
cececbbf
...
...
@@ -693,7 +693,8 @@ class Entry(object):
if
self
.
predict_reader
is
None
:
predict_reader
=
paddle
.
batch
(
reader
.
arc_train
(
self
.
dataset_dir
,
self
.
num_classes
),
self
.
num_classes
,
data_format
=
self
.
data_format
),
batch_size
=
self
.
train_batch_size
)
else
:
predict_reader
=
self
.
predict_reader
...
...
@@ -925,7 +926,7 @@ class Entry(object):
if
self
.
train_reader
is
None
:
train_reader
=
paddle
.
batch
(
reader
.
arc_train
(
self
.
dataset_dir
,
self
.
num_classes
),
self
.
dataset_dir
,
self
.
num_classes
,
data_format
=
self
.
data_format
),
batch_size
=
self
.
train_batch_size
)
else
:
train_reader
=
self
.
train_reader
...
...
plsc/utils/base64_reader.py
浏览文件 @
cececbbf
...
...
@@ -172,7 +172,8 @@ def process_image(sample,
color_jitter
,
rotate
,
rand_mirror
,
normalize
):
normalize
,
data_format
=
'NCHW'
):
img_data
=
base64
.
b64decode
(
sample
[
0
])
img
=
Image
.
open
(
StringIO
(
img_data
))
...
...
@@ -199,6 +200,9 @@ def process_image(sample,
assert
sample
[
1
]
<
class_dim
,
\
"label of train dataset should be less than the class_dim."
if
data_format
==
'NHWC'
:
img
=
img
.
transpose
((
1
,
2
,
0
))
return
img
,
sample
[
1
]
...
...
@@ -208,7 +212,8 @@ def arc_iterator(data_dir,
color_jitter
=
False
,
rotate
=
False
,
rand_mirror
=
False
,
normalize
=
False
):
normalize
=
False
,
data_format
=
'NCHW'
):
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
"0"
))
num_trainers
=
int
(
os
.
getenv
(
"PADDLE_TRAINERS_NUM"
,
"1"
))
...
...
@@ -237,11 +242,12 @@ def arc_iterator(data_dir,
color_jitter
=
color_jitter
,
rotate
=
rotate
,
rand_mirror
=
rand_mirror
,
normalize
=
normalize
)
normalize
=
normalize
,
data_format
=
data_format
)
return
paddle
.
reader
.
xmap_readers
(
mapper
,
reader
,
THREAD
,
BUF_SIZE
)
def
load_bin
(
path
,
image_size
):
def
load_bin
(
path
,
image_size
,
data_format
=
'NCHW'
):
if
six
.
PY2
:
bins
,
issame_list
=
pickle
.
load
(
open
(
path
,
'rb'
))
else
:
...
...
@@ -267,6 +273,8 @@ def load_bin(path, image_size):
img
=
np
.
array
(
img
).
astype
(
'float32'
).
transpose
((
2
,
0
,
1
))
img
-=
img_mean
img
/=
img_std
if
data_format
==
'NHWC'
:
img
=
img
.
transpose
((
1
,
2
,
0
))
data_list
[
flip
][
i
][:]
=
img
if
i
%
1000
==
0
:
print
(
'loading bin'
,
i
)
...
...
@@ -274,7 +282,7 @@ def load_bin(path, image_size):
return
data_list
,
issame_list
def
train
(
data_dir
,
num_classes
):
def
train
(
data_dir
,
num_classes
,
data_format
=
'NCHW'
):
file_path
=
os
.
path
.
join
(
data_dir
,
'file_list.txt'
)
return
arc_iterator
(
data_dir
,
file_path
,
...
...
@@ -282,16 +290,17 @@ def train(data_dir, num_classes):
color_jitter
=
False
,
rotate
=
False
,
rand_mirror
=
True
,
normalize
=
True
)
normalize
=
True
,
data_format
=
data_format
)
def
test
(
data_dir
,
datasets
):
def
test
(
data_dir
,
datasets
,
data_format
=
'NCHW'
):
test_list
=
[]
test_name_list
=
[]
for
name
in
datasets
.
split
(
','
):
path
=
os
.
path
.
join
(
data_dir
,
name
+
".bin"
)
if
os
.
path
.
exists
(
path
):
data_set
=
load_bin
(
path
,
(
DATA_DIM
,
DATA_DIM
))
data_set
=
load_bin
(
path
,
(
DATA_DIM
,
DATA_DIM
)
,
data_format
=
data_format
)
test_list
.
append
(
data_set
)
test_name_list
.
append
(
name
)
print
(
'test'
,
name
)
...
...
plsc/utils/jpeg_reader.py
浏览文件 @
cececbbf
...
...
@@ -184,7 +184,8 @@ def process_image_imagepath(sample,
color_jitter
,
rotate
,
rand_mirror
,
normalize
):
normalize
,
data_format
=
'NCHW'
):
imgpath
=
sample
[
0
]
img
=
Image
.
open
(
imgpath
)
...
...
@@ -211,6 +212,9 @@ def process_image_imagepath(sample,
assert
sample
[
1
]
<
class_dim
,
\
"label of train dataset should be less than the class_dim."
if
data_format
==
'NHWC'
:
img
=
img
.
transpose
((
1
,
2
,
0
))
return
img
,
sample
[
1
]
...
...
@@ -221,7 +225,8 @@ def arc_iterator(data,
color_jitter
=
False
,
rotate
=
False
,
rand_mirror
=
False
,
normalize
=
False
):
normalize
=
False
,
data_format
=
'NCHW'
):
def
reader
():
if
shuffle
:
random
.
shuffle
(
data
)
...
...
@@ -235,11 +240,12 @@ def arc_iterator(data,
color_jitter
=
color_jitter
,
rotate
=
rotate
,
rand_mirror
=
rand_mirror
,
normalize
=
normalize
)
normalize
=
normalize
,
data_format
=
data_format
)
return
paddle
.
reader
.
xmap_readers
(
mapper
,
reader
,
THREAD
,
BUF_SIZE
)
def
load_bin
(
path
,
image_size
):
def
load_bin
(
path
,
image_size
,
data_format
=
'NCHW'
):
if
six
.
PY2
:
bins
,
issame_list
=
pickle
.
load
(
open
(
path
,
'rb'
))
else
:
...
...
@@ -265,6 +271,8 @@ def load_bin(path, image_size):
img
=
np
.
array
(
img
).
astype
(
'float32'
).
transpose
((
2
,
0
,
1
))
img
-=
img_mean
img
/=
img_std
if
data_format
==
'NHWC'
:
img
=
img
.
transpose
((
1
,
2
,
0
))
data_list
[
flip
][
i
][:]
=
img
if
i
%
1000
==
0
:
print
(
'loading bin'
,
i
)
...
...
@@ -272,7 +280,7 @@ def load_bin(path, image_size):
return
data_list
,
issame_list
def
arc_train
(
data_dir
,
class_dim
):
def
arc_train
(
data_dir
,
class_dim
,
data_format
=
'NCHW'
):
train_image_list
=
get_train_image_list
(
data_dir
)
return
arc_iterator
(
train_image_list
,
shuffle
=
True
,
...
...
@@ -281,16 +289,17 @@ def arc_train(data_dir, class_dim):
color_jitter
=
False
,
rotate
=
False
,
rand_mirror
=
True
,
normalize
=
True
)
normalize
=
True
,
data_format
=
data_format
)
def
test
(
data_dir
,
datasets
):
def
test
(
data_dir
,
datasets
,
data_format
=
'NCHW'
):
test_list
=
[]
test_name_list
=
[]
for
name
in
datasets
.
split
(
','
):
path
=
os
.
path
.
join
(
data_dir
,
name
+
".bin"
)
if
os
.
path
.
exists
(
path
):
data_set
=
load_bin
(
path
,
(
DATA_DIM
,
DATA_DIM
))
data_set
=
load_bin
(
path
,
(
DATA_DIM
,
DATA_DIM
)
,
data_format
=
data_format
)
test_list
.
append
(
data_set
)
test_name_list
.
append
(
name
)
print
(
'test'
,
name
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录