Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
a0d3efa2
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
接近 2 年 前同步成功
通知
284
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a0d3efa2
编写于
11月 05, 2020
作者:
H
haoyuying
提交者:
GitHub
11月 05, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
revise hub2.0 module
上级
0bdbbd73
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
128 addition
and
157 deletion
+128
-157
demo/colorization/predict.py
demo/colorization/predict.py
+1
-1
demo/colorization/train.py
demo/colorization/train.py
+4
-7
demo/style_transfer/train.py
demo/style_transfer/train.py
+8
-8
modules/image/colorization/user_guided_colorization/data_feed.py
.../image/colorization/user_guided_colorization/data_feed.py
+10
-10
modules/image/colorization/user_guided_colorization/module.py
...les/image/colorization/user_guided_colorization/module.py
+18
-39
modules/image/keypoint_detection/openpose_body_estimation/module.py
...age/keypoint_detection/openpose_body_estimation/module.py
+2
-2
modules/image/keypoint_detection/openpose_body_estimation/processor.py
.../keypoint_detection/openpose_body_estimation/processor.py
+19
-0
modules/image/keypoint_detection/openpose_hands_estimation/module.py
...ge/keypoint_detection/openpose_hands_estimation/module.py
+5
-6
modules/image/keypoint_detection/openpose_hands_estimation/processor.py
...keypoint_detection/openpose_hands_estimation/processor.py
+28
-0
modules/image/style_transfer/msgnet/module.py
modules/image/style_transfer/msgnet/module.py
+3
-4
paddlehub/datasets/minicoco.py
paddlehub/datasets/minicoco.py
+2
-0
paddlehub/module/cv_module.py
paddlehub/module/cv_module.py
+21
-32
paddlehub/vision/transforms.py
paddlehub/vision/transforms.py
+7
-48
未找到文件。
demo/colorization/predict.py
浏览文件 @
a0d3efa2
...
...
@@ -2,5 +2,5 @@ import paddle
import
paddlehub
as
hub
if
__name__
==
'__main__'
:
model
=
hub
.
Module
(
name
=
'user_guided_colorization'
)
model
=
hub
.
Module
(
name
=
'user_guided_colorization'
,
load_checkpoint
=
'/PATH/TO/CHECKPOINT'
)
result
=
model
.
predict
(
images
=
'house.png'
)
demo/colorization/train.py
浏览文件 @
a0d3efa2
...
...
@@ -6,13 +6,10 @@ from paddlehub.datasets import Canvas
if
__name__
==
'__main__'
:
model
=
hub
.
Module
(
name
=
'user_guided_colorization'
)
transform
=
T
.
Compose
(
[
T
.
Resize
((
256
,
256
),
interpolation
=
'NEAREST'
),
T
.
RandomPaddingCrop
(
crop_size
=
176
),
T
.
RGB2LAB
()],
stay_rgb
=
True
,
is_permute
=
False
)
model
=
hub
.
Module
(
name
=
'user_guided_colorization'
,
classification
=
True
,
prob
=
0.125
)
transform
=
T
.
Compose
([
T
.
Resize
((
256
,
256
),
interpolation
=
'NEAREST'
),
T
.
RandomPaddingCrop
(
crop_size
=
176
),
T
.
RGB2LAB
()])
color_set
=
Canvas
(
transform
=
transform
,
mode
=
'train'
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.0001
,
parameters
=
model
.
parameters
())
...
...
demo/style_transfer/train.py
浏览文件 @
a0d3efa2
import
paddle
import
paddlehub
as
hub
import
paddlehub.vision.transforms
as
T
from
paddlehub.finetune.trainer
import
Trainer
from
paddlehub.datasets
import
MiniCOCO
from
paddlehub.datasets.minicoco
import
MiniCOCO
import
paddlehub.vision.transforms
as
T
if
__name__
==
"__main__"
:
model
=
hub
.
Module
(
name
=
'msgnet'
)
transform
=
T
.
Compose
([
T
.
Resize
(
(
256
,
256
),
interpolation
=
'LINEAR'
),
T
.
CenterCrop
(
crop_size
=
256
)],
T
.
SetType
(
datatype
=
'float32'
))
transform
=
T
.
Compose
([
T
.
Resize
((
256
,
256
),
interpolation
=
'LINEAR'
)])
styledata
=
MiniCOCO
(
transform
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.0001
,
parameters
=
model
.
parameters
())
trainer
=
Trainer
(
model
,
optimizer
,
checkpoint_dir
=
'img_style_transfer_ckpt'
)
trainer
.
train
(
styledata
,
epochs
=
5
,
batch_size
=
16
,
eval_dataset
=
styledata
,
log_interval
=
1
,
save_interval
=
1
)
scheduler
=
paddle
.
optimizer
.
lr
.
PolynomialDecay
(
learning_rate
=
0.001
,
power
=
0.9
,
decay_steps
=
100
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
scheduler
,
parameters
=
model
.
parameters
())
trainer
=
Trainer
(
model
,
optimizer
,
checkpoint_dir
=
'test_style_ckpt'
)
trainer
.
train
(
styledata
,
epochs
=
101
,
batch_size
=
4
,
eval_dataset
=
styledata
,
log_interval
=
10
,
save_interval
=
10
)
modules/image/colorization/user_guided_colorization/data_feed.py
浏览文件 @
a0d3efa2
...
...
@@ -18,7 +18,6 @@ class ColorizeHint:
hint(np.ndarray): hint images
mask(np.ndarray): mask images
"""
def
__init__
(
self
,
percent
:
float
,
num_points
:
int
=
None
,
samp
:
str
=
'normal'
,
use_avg
:
bool
=
True
):
self
.
percent
=
percent
self
.
num_points
=
num_points
...
...
@@ -37,7 +36,7 @@ class ColorizeHint:
while
cont_cond
:
if
self
.
num_points
is
None
:
# draw from geometric
# embed()
cont_cond
=
np
.
random
.
rand
()
>
(
1
-
self
.
percent
)
cont_cond
=
np
.
random
.
rand
()
<
(
1
-
self
.
percent
)
else
:
# add certain number of points
cont_cond
=
pp
<
self
.
num_points
if
not
cont_cond
:
# skip out of loop if condition not met
...
...
@@ -53,9 +52,11 @@ class ColorizeHint:
# add color point
if
self
.
use_avg
:
# embed()
hint
[
nn
,
:,
h
:
h
+
P
,
w
:
w
+
P
]
=
np
.
mean
(
np
.
mean
(
data
[
nn
,
:,
h
:
h
+
P
,
w
:
w
+
P
],
axis
=
2
,
keepdims
=
True
),
axis
=
1
,
keepdims
=
True
).
reshape
(
1
,
C
,
1
,
1
)
hint
[
nn
,
:,
h
:
h
+
P
,
w
:
w
+
P
]
=
np
.
mean
(
np
.
mean
(
data
[
nn
,
:,
h
:
h
+
P
,
w
:
w
+
P
],
axis
=
2
,
keepdims
=
True
),
axis
=
1
,
keepdims
=
True
).
reshape
(
1
,
C
,
1
,
1
)
else
:
hint
[
nn
,
:,
h
:
h
+
P
,
w
:
w
+
P
]
=
data
[
nn
,
:,
h
:
h
+
P
,
w
:
w
+
P
]
mask
[
nn
,
:,
h
:
h
+
P
,
w
:
w
+
P
]
=
1
...
...
@@ -81,16 +82,15 @@ class ColorizePreprocess:
data(dict):The preprocessed data for colorization.
"""
def
__init__
(
self
,
ab_thresh
:
float
=
0.
,
p
:
float
=
0.
,
num_
points
:
int
=
None
,
points
:
int
=
None
,
samp
:
str
=
'normal'
,
use_avg
:
bool
=
True
):
self
.
ab_thresh
=
ab_thresh
self
.
p
=
p
self
.
num_points
=
num_
points
self
.
num_points
=
points
self
.
samp
=
samp
self
.
use_avg
=
use_avg
self
.
gethint
=
ColorizeHint
(
percent
=
self
.
p
,
num_points
=
self
.
num_points
,
samp
=
self
.
samp
,
use_avg
=
self
.
use_avg
)
...
...
@@ -113,8 +113,8 @@ class ColorizePreprocess:
data
[
'B'
]
=
data_lab
[:,
1
:,
:,
:]
if
self
.
ab_thresh
>
0
:
# mask out grayscale images
thresh
=
1.
*
self
.
ab_thresh
/
110
mask
=
np
.
sum
(
np
.
abs
(
np
.
max
(
np
.
max
(
data
[
'B'
],
axis
=
3
),
axis
=
2
)
-
np
.
min
(
np
.
min
(
data
[
'B'
],
axis
=
3
),
axis
=
2
)),
axis
=
1
)
mask
=
np
.
sum
(
np
.
abs
(
np
.
max
(
np
.
max
(
data
[
'B'
],
axis
=
3
),
axis
=
2
)
-
np
.
min
(
np
.
min
(
data
[
'B'
],
axis
=
3
),
axis
=
2
)),
axis
=
1
)
mask
=
(
mask
>=
thresh
)
data
[
'A'
]
=
data
[
'A'
][
mask
,
:,
:,
:]
data
[
'B'
]
=
data
[
'B'
][
mask
,
:,
:,
:]
...
...
modules/image/colorization/user_guided_colorization/module.py
浏览文件 @
a0d3efa2
...
...
@@ -42,11 +42,13 @@ class UserGuidedColorization(nn.Layer):
"""
def
__init__
(
self
,
use_tanh
:
bool
=
True
,
classification
:
bool
=
True
,
load_checkpoint
:
str
=
None
):
def
__init__
(
self
,
use_tanh
:
bool
=
True
,
classification
:
bool
=
True
,
load_checkpoint
:
str
=
None
,
ab_thresh
:
float
=
0.
,
prob
:
float
=
1.
,
num_point
:
int
=
None
):
super
(
UserGuidedColorization
,
self
).
__init__
()
self
.
input_nc
=
4
self
.
output_nc
=
2
self
.
classification
=
classification
self
.
pre_func
=
ColorizePreprocess
(
ab_thresh
=
ab_thresh
,
p
=
prob
,
points
=
num_point
)
# Conv1
model1
=
(
Conv2D
(
self
.
input_nc
,
64
,
3
,
1
,
1
),
...
...
@@ -121,8 +123,8 @@ class UserGuidedColorization(nn.Layer):
)
# Conv8
model8up
=
(
Conv2DTranspose
(
512
,
256
,
kernel_size
=
4
,
stride
=
2
,
padding
=
1
),
)
model3short8
=
(
Conv2D
(
256
,
256
,
3
,
1
,
1
),
)
model8up
=
(
Conv2DTranspose
(
512
,
256
,
kernel_size
=
4
,
stride
=
2
,
padding
=
1
),)
model3short8
=
(
Conv2D
(
256
,
256
,
3
,
1
,
1
),)
model8
=
(
nn
.
ReLU
(),
Conv2D
(
256
,
256
,
3
,
1
,
1
),
...
...
@@ -133,26 +135,20 @@ class UserGuidedColorization(nn.Layer):
)
# Conv9
model9up
=
(
Conv2DTranspose
(
256
,
128
,
kernel_size
=
4
,
stride
=
2
,
padding
=
1
),
)
model2short9
=
(
Conv2D
(
128
,
128
,
3
,
1
,
1
,
),
)
model9up
=
(
Conv2DTranspose
(
256
,
128
,
kernel_size
=
4
,
stride
=
2
,
padding
=
1
),)
model2short9
=
(
Conv2D
(
128
,
128
,
3
,
1
,
1
,),)
model9
=
(
nn
.
ReLU
(),
Conv2D
(
128
,
128
,
3
,
1
,
1
),
nn
.
ReLU
(),
nn
.
BatchNorm
(
128
))
# Conv10
model10up
=
(
Conv2DTranspose
(
128
,
128
,
kernel_size
=
4
,
stride
=
2
,
padding
=
1
),
)
model1short10
=
(
Conv2D
(
64
,
128
,
3
,
1
,
1
),
)
model10up
=
(
Conv2DTranspose
(
128
,
128
,
kernel_size
=
4
,
stride
=
2
,
padding
=
1
),)
model1short10
=
(
Conv2D
(
64
,
128
,
3
,
1
,
1
),)
model10
=
(
nn
.
ReLU
(),
Conv2D
(
128
,
128
,
3
,
1
,
1
),
nn
.
LeakyReLU
(
negative_slope
=
0.2
))
model_class
=
(
Conv2D
(
256
,
529
,
1
),
)
model_class
=
(
Conv2D
(
256
,
529
,
1
),)
if
use_tanh
:
model_out
=
(
Conv2D
(
128
,
2
,
1
,
1
,
0
,
1
),
nn
.
Tanh
())
else
:
model_out
=
(
Conv2D
(
128
,
2
,
1
,
1
,
0
,
1
),
)
model_out
=
(
Conv2D
(
128
,
2
,
1
,
1
,
0
,
1
),)
self
.
model1
=
nn
.
Sequential
(
*
model1
)
self
.
model2
=
nn
.
Sequential
(
*
model2
)
...
...
@@ -183,24 +179,14 @@ class UserGuidedColorization(nn.Layer):
self
.
set_dict
(
model_dict
)
print
(
"load pretrained checkpoint success"
)
def
transforms
(
self
,
images
:
str
,
is_train
:
bool
=
True
)
->
callable
:
if
is_train
:
transform
=
T
.
Compose
(
[
T
.
Resize
((
256
,
256
),
interpolation
=
'NEAREST'
),
T
.
RandomPaddingCrop
(
crop_size
=
176
),
T
.
RGB2LAB
()],
stay_rgb
=
True
,
is_permute
=
False
)
else
:
transform
=
T
.
Compose
([
T
.
Resize
(
(
256
,
256
),
interpolation
=
'NEAREST'
),
T
.
RGB2LAB
()],
stay_rgb
=
True
,
is_permute
=
False
)
def
transforms
(
self
,
images
:
str
)
->
callable
:
transform
=
T
.
Compose
([
T
.
Resize
((
256
,
256
),
interpolation
=
'NEAREST'
),
T
.
RGB2LAB
()],
to_rgb
=
True
)
return
transform
(
images
)
def
preprocess
(
self
,
inputs
:
paddle
.
Tensor
,
ab_thresh
:
float
=
0.
,
prob
:
float
=
0.
):
self
.
preprocess
=
ColorizePreprocess
(
ab_thresh
=
ab_thresh
,
p
=
prob
)
return
self
.
preprocess
(
inputs
)
def
preprocess
(
self
,
inputs
:
paddle
.
Tensor
):
output
=
self
.
pre_func
(
inputs
)
return
output
def
forward
(
self
,
input_A
:
paddle
.
Tensor
,
...
...
@@ -233,11 +219,4 @@ class UserGuidedColorization(nn.Layer):
conv10_2
=
self
.
model10
(
conv10_up
)
out_reg
=
self
.
model_out
(
conv10_2
)
return
out_class
,
out_reg
if
__name__
==
"__main__"
:
place
=
paddle
.
CUDAPlace
(
0
)
paddle
.
disable_static
()
model
=
UserGuidedColorization
()
model
.
eval
()
return
out_class
,
out_reg
\ No newline at end of file
modules/image/keypoint_detection/openpose_body_estimation/module.py
浏览文件 @
a0d3efa2
...
...
@@ -23,7 +23,7 @@ import numpy as np
from
paddlehub.transforms.module
import
moduleinfo
import
paddlehub.transforms.transforms
as
T
import
openpose_body_estimation.processor
as
P
@
moduleinfo
(
name
=
"openpose_body_estimation"
,
...
...
@@ -45,7 +45,7 @@ class BodyPoseModel(nn.Layer):
def
__init__
(
self
,
load_checkpoint
:
str
=
None
,
visualization
:
bool
=
True
):
super
(
BodyPoseModel
,
self
).
__init__
()
self
.
resize_func
=
T
.
ResizeScaling
()
self
.
resize_func
=
P
.
ResizeScaling
()
self
.
norm_func
=
T
.
Normalize
(
std
=
[
1
,
1
,
1
])
self
.
pad_func
=
P
.
PadDownRight
()
self
.
remove_pad
=
P
.
RemovePadding
()
...
...
modules/image/keypoint_detection/openpose_body_estimation/processor.py
浏览文件 @
a0d3efa2
import
math
from
typing
import
Callable
import
cv2
import
numpy
as
np
...
...
@@ -309,3 +310,21 @@ class Candidate:
subset
=
np
.
delete
(
subset
,
deleteIdx
,
axis
=
0
)
return
candidate
,
subset
class
ResizeScaling
:
"""Resize images by scaling method.
Args:
target(int): Target image size.
interpolation(Callable): Interpolation method.
"""
def
__init__
(
self
,
target
:
int
=
368
,
interpolation
:
Callable
=
cv2
.
INTER_CUBIC
):
self
.
target
=
target
self
.
interpolation
=
interpolation
def
__call__
(
self
,
img
,
scale_search
):
scale
=
scale_search
*
self
.
target
/
img
.
shape
[
0
]
resize_img
=
cv2
.
resize
(
img
,
(
0
,
0
),
fx
=
scale
,
fy
=
scale
,
interpolation
=
self
.
interpolation
)
return
resize_img
modules/image/keypoint_detection/openpose_hands_estimation/module.py
浏览文件 @
a0d3efa2
...
...
@@ -24,8 +24,7 @@ import paddlehub as hub
from
skimage.measure
import
label
from
scipy.ndimage.filters
import
gaussian_filter
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.process.functional
import
npmax
import
paddlehub.transforms.transforms
as
T
import
paddlehub.vision.transforms
as
T
import
openpose_hands_estimation.processor
as
P
...
...
@@ -50,9 +49,9 @@ class HandPoseModel(nn.Layer):
def
__init__
(
self
,
load_checkpoint
:
str
=
None
,
visualization
:
bool
=
True
):
super
(
HandPoseModel
,
self
).
__init__
()
self
.
visualization
=
visualization
self
.
resize_func
=
T
.
ResizeScaling
()
self
.
norm_func
=
T
.
Normalize
(
std
=
[
1
,
1
,
1
])
self
.
resize_func
=
P
.
ResizeScaling
()
self
.
hand_detect
=
P
.
HandDetect
()
self
.
pad_func
=
P
.
PadDownRight
()
self
.
remove_pad
=
P
.
RemovePadding
()
...
...
@@ -164,7 +163,7 @@ class HandPoseModel(nn.Layer):
label_img
[
label_img
!=
max_index
]
=
0
map_ori
[
label_img
==
0
]
=
0
y
,
x
=
npmax
(
map_ori
)
y
,
x
=
P
.
npmax
(
map_ori
)
all_peaks
.
append
([
x
,
y
])
return
np
.
array
(
all_peaks
)
...
...
@@ -194,4 +193,4 @@ class HandPoseModel(nn.Layer):
os
.
mkdir
(
save_path
)
save_path
=
os
.
path
.
join
(
save_path
,
img_path
.
rsplit
(
"/"
,
1
)[
-
1
])
cv2
.
imwrite
(
save_path
,
canvas
)
return
all_hand_peaks
return
all_hand_peaks
\ No newline at end of file
modules/image/keypoint_detection/openpose_hands_estimation/processor.py
浏览文件 @
a0d3efa2
import
math
from
typing
import
Callable
import
cv2
import
numpy
as
np
...
...
@@ -210,3 +211,30 @@ class DrawHandPose:
bg
.
draw
()
canvas
=
np
.
frombuffer
(
bg
.
tostring_rgb
(),
dtype
=
'uint8'
).
reshape
(
int
(
height
),
int
(
width
),
3
)
return
canvas
class
ResizeScaling
:
"""Resize images by scaling method.
Args:
target(int): Target image size.
interpolation(Callable): Interpolation method.
"""
def
__init__
(
self
,
target
:
int
=
368
,
interpolation
:
Callable
=
cv2
.
INTER_CUBIC
):
self
.
target
=
target
self
.
interpolation
=
interpolation
def
__call__
(
self
,
img
,
scale_search
):
scale
=
scale_search
*
self
.
target
/
img
.
shape
[
0
]
resize_img
=
cv2
.
resize
(
img
,
(
0
,
0
),
fx
=
scale
,
fy
=
scale
,
interpolation
=
self
.
interpolation
)
return
resize_img
def
npmax
(
array
:
np
.
ndarray
):
"""Get max value and index."""
arrayindex
=
array
.
argmax
(
1
)
arrayvalue
=
array
.
max
(
1
)
i
=
arrayvalue
.
argmax
()
j
=
arrayindex
[
i
]
return
i
,
j
modules/image/style_transfer/msgnet/module.py
浏览文件 @
a0d3efa2
...
...
@@ -7,7 +7,7 @@ import paddle.nn.functional as F
from
paddlehub.env
import
MODULE_HOME
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.vision.transforms
import
Compose
,
Resize
,
CenterCrop
,
SetType
from
paddlehub.vision.transforms
import
Compose
,
Resize
,
CenterCrop
from
paddlehub.module.cv_module
import
StyleTransferModule
...
...
@@ -324,8 +324,7 @@ class MSGNet(nn.Layer):
self
.
_vgg
=
None
def
transform
(
self
,
path
:
str
):
transform
=
Compose
([
Resize
(
(
256
,
256
),
interpolation
=
'LINEAR'
),
CenterCrop
(
crop_size
=
256
)],
SetType
(
datatype
=
'float32'
))
transform
=
Compose
([
Resize
((
256
,
256
),
interpolation
=
'LINEAR'
)])
return
transform
(
path
)
def
setTarget
(
self
,
Xs
:
paddle
.
Tensor
):
...
...
@@ -340,4 +339,4 @@ class MSGNet(nn.Layer):
return
self
.
_vgg
(
input
)
def
forward
(
self
,
input
:
paddle
.
Tensor
):
return
self
.
model
(
input
)
return
self
.
model
(
input
)
\ No newline at end of file
paddlehub/datasets/minicoco.py
浏览文件 @
a0d3efa2
...
...
@@ -55,9 +55,11 @@ class MiniCOCO(paddle.io.Dataset):
img_path
=
self
.
data
[
idx
]
im
=
self
.
transform
(
img_path
)
im
=
im
.
astype
(
'float32'
)
style_idx
=
idx
%
len
(
self
.
style
)
style_path
=
self
.
style
[
style_idx
]
style
=
self
.
transform
(
style_path
)
style
=
style
.
astype
(
'float32'
)
return
im
,
style
def
__len__
(
self
):
...
...
paddlehub/module/cv_module.py
浏览文件 @
a0d3efa2
...
...
@@ -137,28 +137,14 @@ class ImageColorizeModule(RunModule, ImageServing):
loss_G_L1_reg
=
paddle
.
mean
(
loss_G_L1_reg
)
loss
=
loss_ce
+
loss_G_L1_reg
#calculate psnr
visual_ret
=
OrderedDict
()
psnrs
=
[]
lab2rgb
=
T
.
LAB2RGB
()
process
=
T
.
ColorPostprocess
()
for
i
in
range
(
img
[
'A'
].
numpy
().
shape
[
0
]):
# real = lab2rgb(np.concatenate((img['A'].numpy(), img['B'].numpy()), axis=1))[i]
# visual_ret['real'] = process(real)
# fake = lab2rgb(np.concatenate((img['A'].numpy(), out_reg.numpy()), axis=1))[i]
# visual_ret['fake_reg'] = process(fake)
# mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
# psnr_value = 20 * np.log10(255. / np.sqrt(mse))
psnrs
.
append
(
0
)
#psnr_value)
psnr
=
paddle
.
to_tensor
(
np
.
array
(
psnrs
))
return
{
'loss'
:
loss
,
'metrics'
:
{
'psnr'
:
psnr
}}
return
{
'loss'
:
loss
}
def
predict
(
self
,
images
:
str
,
visualization
:
bool
=
True
,
save_path
:
str
=
'result'
):
'''
Colorize images
Args:
images(str
) : Images path
to be colorized.
images(str
|np.ndarray) : Images path or BGR image
to be colorized.
visualization(bool): Whether to save colorized images.
save_path(str) : Path to save colorized images.
...
...
@@ -167,10 +153,11 @@ class ImageColorizeModule(RunModule, ImageServing):
'''
self
.
eval
()
lab2rgb
=
T
.
LAB2RGB
()
process
=
T
.
ColorPostprocess
()
resize
=
T
.
Resize
((
256
,
256
))
if
isinstance
(
images
,
str
):
images
=
cv2
.
imread
(
images
).
astype
(
'float32'
)
im
=
self
.
transforms
(
images
,
is_train
=
False
)
im
=
self
.
transforms
(
images
)
im
=
im
[
np
.
newaxis
,
:,
:,
:]
im
=
self
.
preprocess
(
im
)
out_class
,
out_reg
=
self
(
im
[
'A'
],
im
[
'hint_B'
],
im
[
'mask_B'
])
...
...
@@ -179,17 +166,20 @@ class ImageColorizeModule(RunModule, ImageServing):
visual_ret
=
OrderedDict
()
for
i
in
range
(
im
[
'A'
].
shape
[
0
]):
gray
=
lab2rgb
(
np
.
concatenate
((
im
[
'A'
].
numpy
(),
np
.
zeros
(
im
[
'B'
].
shape
)),
axis
=
1
))[
i
]
visual_ret
[
'gray'
]
=
resize
(
process
(
gray
))
gray
=
np
.
clip
(
np
.
transpose
(
gray
,
(
1
,
2
,
0
)),
0
,
1
)
*
255
visual_ret
[
'gray'
]
=
gray
.
astype
(
np
.
uint8
)
hint
=
lab2rgb
(
np
.
concatenate
((
im
[
'A'
].
numpy
(),
im
[
'hint_B'
].
numpy
()),
axis
=
1
))[
i
]
visual_ret
[
'hint'
]
=
resize
(
process
(
hint
))
hint
=
np
.
clip
(
np
.
transpose
(
hint
,
(
1
,
2
,
0
)),
0
,
1
)
*
255
visual_ret
[
'hint'
]
=
hint
.
astype
(
np
.
uint8
)
real
=
lab2rgb
(
np
.
concatenate
((
im
[
'A'
].
numpy
(),
im
[
'B'
].
numpy
()),
axis
=
1
))[
i
]
visual_ret
[
'real'
]
=
resize
(
process
(
real
))
real
=
np
.
clip
(
np
.
transpose
(
real
,
(
1
,
2
,
0
)),
0
,
1
)
*
255
visual_ret
[
'real'
]
=
real
.
astype
(
np
.
uint8
)
fake
=
lab2rgb
(
np
.
concatenate
((
im
[
'A'
].
numpy
(),
out_reg
.
numpy
()),
axis
=
1
))[
i
]
visual_ret
[
'fake_reg'
]
=
resize
(
process
(
fake
))
fake
=
np
.
clip
(
np
.
transpose
(
fake
,
(
1
,
2
,
0
)),
0
,
1
)
*
255
visual_ret
[
'fake_reg'
]
=
fake
.
astype
(
np
.
uint8
)
if
visualization
:
img
=
Image
.
open
(
images
)
w
,
h
=
img
.
size
[
0
],
img
.
size
[
1
]
h
,
w
,
c
=
images
.
shape
fake_name
=
"fake_"
+
str
(
time
.
time
())
+
".png"
if
not
os
.
path
.
exists
(
save_path
):
os
.
mkdir
(
save_path
)
...
...
@@ -198,8 +188,6 @@ class ImageColorizeModule(RunModule, ImageServing):
visual_gray
=
visual_gray
.
resize
((
w
,
h
),
Image
.
BILINEAR
)
visual_gray
.
save
(
fake_path
)
mse
=
np
.
mean
((
visual_ret
[
'real'
]
*
1.0
-
visual_ret
[
'fake_reg'
]
*
1.0
)
**
2
)
psnr_value
=
20
*
np
.
log10
(
255.
/
np
.
sqrt
(
mse
))
result
.
append
(
visual_ret
)
return
result
...
...
@@ -380,13 +368,13 @@ class StyleTransferModule(RunModule, ImageServing):
return
{
'loss'
:
loss
,
'metrics'
:
{
'content gap'
:
content_loss
,
'style gap'
:
style_loss
}}
def
predict
(
self
,
origin
_path
:
str
,
style_path
:
str
,
visualization
:
bool
=
True
,
save_path
:
str
=
'result'
):
def
predict
(
self
,
origin
:
str
,
style
:
str
,
visualization
:
bool
=
True
,
save_path
:
str
=
'result'
):
'''
Colorize images
Args:
origin
_path(str): Content image path
.
style
_path(str): Style image path
.
origin
(str|np.array): Content image path or BGR image
.
style
(str|np.array): Style image path or BGR image
.
visualization(bool): Whether to save colorized images.
save_path(str) : Path to save colorized images.
...
...
@@ -394,8 +382,9 @@ class StyleTransferModule(RunModule, ImageServing):
output(np.ndarray) : The style transformed images with bgr mode.
'''
self
.
eval
()
content
=
paddle
.
to_tensor
(
self
.
transform
(
origin_path
))
style
=
paddle
.
to_tensor
(
self
.
transform
(
style_path
))
content
=
paddle
.
to_tensor
(
self
.
transform
(
origin
).
astype
(
'float32'
))
style
=
paddle
.
to_tensor
(
self
.
transform
(
style
).
astype
(
'float32'
))
content
=
content
.
unsqueeze
(
0
)
style
=
style
.
unsqueeze
(
0
)
...
...
paddlehub/vision/transforms.py
浏览文件 @
a0d3efa2
...
...
@@ -23,7 +23,7 @@ import paddlehub.vision.functional as F
class
Compose
:
def
__init__
(
self
,
transforms
,
to_rgb
=
True
,
stay_rgb
=
False
,
is_permute
=
Tru
e
):
def
__init__
(
self
,
transforms
,
to_rgb
=
Fals
e
):
if
not
isinstance
(
transforms
,
list
):
raise
TypeError
(
'The transforms must be a list!'
)
if
len
(
transforms
)
<
1
:
...
...
@@ -31,8 +31,6 @@ class Compose:
'must be equal or larger than 1!'
)
self
.
transforms
=
transforms
self
.
to_rgb
=
to_rgb
self
.
stay_rgb
=
stay_rgb
self
.
is_permute
=
is_permute
def
__call__
(
self
,
im
):
if
isinstance
(
im
,
str
):
...
...
@@ -46,11 +44,7 @@ class Compose:
for
op
in
self
.
transforms
:
im
=
op
(
im
)
if
not
self
.
stay_rgb
:
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_RGB2BGR
)
if
self
.
is_permute
:
im
=
F
.
permute
(
im
)
im
=
F
.
permute
(
im
)
return
im
...
...
@@ -66,7 +60,7 @@ class RandomHorizontalFlip:
class
RandomVerticalFlip
:
def
__init__
(
self
,
prob
=
0.
1
):
def
__init__
(
self
,
prob
=
0.
5
):
self
.
prob
=
prob
def
__call__
(
self
,
im
):
...
...
@@ -85,7 +79,7 @@ class Resize:
'LANCZOS4'
:
cv2
.
INTER_LANCZOS4
}
def
__init__
(
self
,
target_size
=
512
,
interpolation
=
'LINEAR'
):
def
__init__
(
self
,
target_size
,
interpolation
=
'LINEAR'
):
self
.
interpolation
=
interpolation
if
not
(
interpolation
==
"RANDOM"
or
interpolation
in
self
.
interpolation_dict
):
raise
ValueError
(
"interpolation should be one of {}"
.
format
(
self
.
interpolation_dict
.
keys
()))
...
...
@@ -212,7 +206,7 @@ class Padding:
class
RandomPaddingCrop
:
def
__init__
(
self
,
crop_size
=
512
,
im_padding_value
=
[
127.5
,
127.5
,
127.5
]):
def
__init__
(
self
,
crop_size
,
im_padding_value
=
[
127.5
,
127.5
,
127.5
]):
if
isinstance
(
crop_size
,
list
)
or
isinstance
(
crop_size
,
tuple
):
if
len
(
crop_size
)
!=
2
:
raise
ValueError
(
...
...
@@ -469,7 +463,8 @@ class RGB2LAB:
def
__call__
(
self
,
img
:
np
.
ndarray
)
->
np
.
ndarray
:
img
=
img
/
255
img
=
np
.
array
(
img
).
transpose
(
2
,
0
,
1
)
return
self
.
rgb2lab
(
img
)
img
=
self
.
rgb2lab
(
img
)
return
np
.
array
(
img
).
transpose
(
1
,
2
,
0
)
class
LAB2RGB
:
...
...
@@ -585,39 +580,3 @@ class CenterCrop:
crop_left
=
int
((
img_width
-
self
.
crop_size
)
/
2.
)
return
img
[
crop_left
:
crop_left
+
self
.
crop_size
,
crop_top
:
crop_top
+
self
.
crop_size
,
:]
class
SetType
:
"""
Set image type.
Args:
type(type): Type of Image value.
Return:
img(np.ndarray): Transformed image.
"""
def
__init__
(
self
,
datatype
:
type
=
'float32'
):
self
.
type
=
datatype
def
__call__
(
self
,
img
:
np
.
ndarray
):
img
=
img
.
astype
(
self
.
type
)
return
img
class
ResizeScaling
:
"""Resize images by scaling method.
Args:
target(int): Target image size.
interpolation(Callable): Interpolation method.
"""
def
__init__
(
self
,
target
:
int
=
368
,
interpolation
:
Callable
=
cv2
.
INTER_CUBIC
):
self
.
target
=
target
self
.
interpolation
=
interpolation
def
__call__
(
self
,
img
,
scale_search
):
scale
=
scale_search
*
self
.
target
/
img
.
shape
[
0
]
resize_img
=
cv2
.
resize
(
img
,
(
0
,
0
),
fx
=
scale
,
fy
=
scale
,
interpolation
=
self
.
interpolation
)
return
resize_img
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录