Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
d81baf51
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d81baf51
编写于
4月 12, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
image-classification reader add standardization operations
上级
c8404395
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
36 addition
and
8 deletion
+36
-8
demo/image-classification/retrain.py
demo/image-classification/retrain.py
+6
-1
paddlehub/reader/cv_reader.py
paddlehub/reader/cv_reader.py
+30
-7
未找到文件。
demo/image-classification/retrain.py
浏览文件 @
d81baf51
...
@@ -10,7 +10,11 @@ def train():
...
@@ -10,7 +10,11 @@ def train():
sign_name
=
"feature_map"
,
trainable
=
True
)
sign_name
=
"feature_map"
,
trainable
=
True
)
dataset
=
hub
.
dataset
.
Flowers
()
dataset
=
hub
.
dataset
.
Flowers
()
data_reader
=
hub
.
reader
.
ImageClassificationReader
(
data_reader
=
hub
.
reader
.
ImageClassificationReader
(
image_width
=
224
,
image_height
=
224
,
dataset
=
dataset
)
image_width
=
resnet_module
.
get_excepted_image_width
(),
image_height
=
resnet_module
.
get_excepted_image_height
(),
images_mean
=
resnet_module
.
get_pretrained_images_mean
(),
images_std
=
resnet_module
.
get_pretrained_images_std
(),
dataset
=
dataset
)
with
fluid
.
program_guard
(
program
):
with
fluid
.
program_guard
(
program
):
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
dtype
=
"int64"
,
shape
=
[
1
])
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
dtype
=
"int64"
,
shape
=
[
1
])
img
=
input_dict
[
0
]
img
=
input_dict
[
0
]
...
@@ -20,6 +24,7 @@ def train():
...
@@ -20,6 +24,7 @@ def train():
use_cuda
=
True
,
use_cuda
=
True
,
num_epoch
=
10
,
num_epoch
=
10
,
batch_size
=
32
,
batch_size
=
32
,
enable_memory_optim
=
False
,
strategy
=
hub
.
finetune
.
strategy
.
DefaultFinetuneStrategy
())
strategy
=
hub
.
finetune
.
strategy
.
DefaultFinetuneStrategy
())
feed_list
=
[
img
.
name
,
label
.
name
]
feed_list
=
[
img
.
name
,
label
.
name
]
...
...
paddlehub/reader/cv_reader.py
浏览文件 @
d81baf51
...
@@ -22,7 +22,7 @@ from PIL import Image
...
@@ -22,7 +22,7 @@ from PIL import Image
import
paddlehub.io.augmentation
as
image_augmentation
import
paddlehub.io.augmentation
as
image_augmentation
c
olor_mode
_dict
=
{
c
hannel_order
_dict
=
{
"RGB"
:
[
0
,
1
,
2
],
"RGB"
:
[
0
,
1
,
2
],
"RBG"
:
[
0
,
2
,
1
],
"RBG"
:
[
0
,
2
,
1
],
"GBR"
:
[
1
,
2
,
0
],
"GBR"
:
[
1
,
2
,
0
],
...
@@ -37,16 +37,35 @@ class ImageClassificationReader(object):
...
@@ -37,16 +37,35 @@ class ImageClassificationReader(object):
image_width
,
image_width
,
image_height
,
image_height
,
dataset
,
dataset
,
color_mode
=
"RGB"
,
channel_order
=
"RGB"
,
images_mean
=
None
,
images_std
=
None
,
data_augmentation
=
False
):
data_augmentation
=
False
):
self
.
image_width
=
image_width
self
.
image_width
=
image_width
self
.
image_height
=
image_height
self
.
image_height
=
image_height
self
.
c
olor_mode
=
color_mode
self
.
c
hannel_order
=
channel_order
self
.
dataset
=
dataset
self
.
dataset
=
dataset
self
.
data_augmentation
=
data_augmentation
self
.
data_augmentation
=
data_augmentation
if
self
.
color_mode
not
in
color_mode_dict
:
self
.
images_std
=
images_std
self
.
images_mean
=
images_mean
if
self
.
images_mean
is
None
:
try
:
self
.
images_mean
=
self
.
dataset
.
images_mean
except
:
self
.
images_mean
=
[
0
,
0
,
0
]
self
.
images_mean
=
np
.
array
(
self
.
images_mean
).
reshape
(
3
,
1
,
1
)
if
self
.
images_std
is
None
:
try
:
self
.
images_std
=
self
.
dataset
.
images_std
except
:
self
.
images_std
=
[
1
,
1
,
1
]
self
.
images_std
=
np
.
array
(
self
.
images_std
).
reshape
(
3
,
1
,
1
)
if
self
.
channel_order
not
in
channel_order_dict
:
raise
ValueError
(
raise
ValueError
(
"
Color_mode should in %s."
%
color_mode
_dict
.
keys
())
"
The channel_order should in %s."
%
channel_order
_dict
.
keys
())
if
self
.
image_width
<=
0
or
self
.
image_height
<=
0
:
if
self
.
image_width
<=
0
or
self
.
image_height
<=
0
:
raise
ValueError
(
"Image width and height should not be negative."
)
raise
ValueError
(
"Image width and height should not be negative."
)
...
@@ -74,12 +93,16 @@ class ImageClassificationReader(object):
...
@@ -74,12 +93,16 @@ class ImageClassificationReader(object):
image
=
image
.
convert
(
'RGB'
)
image
=
image
.
convert
(
'RGB'
)
# HWC to CHW
# HWC to CHW
image
=
np
.
array
(
image
)
image
=
np
.
array
(
image
)
.
astype
(
'float32'
)
if
len
(
image
.
shape
)
==
3
:
if
len
(
image
.
shape
)
==
3
:
image
=
np
.
swapaxes
(
image
,
1
,
2
)
image
=
np
.
swapaxes
(
image
,
1
,
2
)
image
=
np
.
swapaxes
(
image
,
1
,
0
)
image
=
np
.
swapaxes
(
image
,
1
,
0
)
image
=
image
[
color_mode_dict
[
self
.
color_mode
],
:,
:]
# standardization
image
/=
255
image
-=
self
.
images_mean
image
/=
self
.
images_std
image
=
image
[
channel_order_dict
[
self
.
channel_order
],
:,
:]
yield
((
image
,
label
))
yield
((
image
,
label
))
return
paddle
.
batch
(
_data_reader
,
batch_size
=
batch_size
)
return
paddle
.
batch
(
_data_reader
,
batch_size
=
batch_size
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录