Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
6f153ddc
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6f153ddc
编写于
10月 30, 2020
作者:
H
haoyuying
提交者:
GitHub
10月 30, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
adapt rc for colorization and style transfer
上级
332f3a0c
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
33 addition
and
39 deletion
+33
-39
demo/colorization/predict.py
demo/colorization/predict.py
+2
-4
demo/style_transfer/predict.py
demo/style_transfer/predict.py
+1
-3
modules/image/colorization/user_guided_colorization/module.py
...les/image/colorization/user_guided_colorization/module.py
+0
-4
modules/image/style_transfer/msgnet/module.py
modules/image/style_transfer/msgnet/module.py
+0
-3
paddlehub/module/cv_module.py
paddlehub/module/cv_module.py
+30
-25
未找到文件。
demo/colorization/predict.py
浏览文件 @
6f153ddc
...
...
@@ -2,7 +2,5 @@ import paddle
import
paddlehub
as
hub
if
__name__
==
'__main__'
:
model
=
hub
.
Module
(
name
=
'user_guided_colorization'
)
state_dict
=
paddle
.
load
(
'img_colorization_ckpt'
)
model
.
set_dict
(
state_dict
)
result
=
model
.
predict
(
'house.png'
)
model
=
hub
.
Module
(
name
=
'user_guided_colorization'
,
load_checkpoint
=
'/PATH/TO/CHECKPOINT'
)
result
=
model
.
predict
(
images
=
'house.png'
)
demo/style_transfer/predict.py
浏览文件 @
6f153ddc
...
...
@@ -2,7 +2,5 @@ import paddle
import
paddlehub
as
hub
if
__name__
==
'__main__'
:
model
=
hub
.
Module
(
name
=
'msgnet'
)
state_dict
=
paddle
.
load
(
'img_style_transfer_ckpt'
)
model
.
set_dict
(
state_dict
)
model
=
hub
.
Module
(
name
=
'msgnet'
,
load_checkpoint
=
'/PATH/TO/CHECKPOINT'
)
result
=
model
.
predict
(
"venice-boat.jpg"
,
"candy.jpg"
)
modules/image/colorization/user_guided_colorization/module.py
浏览文件 @
6f153ddc
...
...
@@ -179,11 +179,7 @@ class UserGuidedColorization(nn.Layer):
print
(
"load custom checkpoint success"
)
else
:
checkpoint
=
os
.
path
.
join
(
self
.
directory
,
'user_guided.pdparams'
)
if
not
os
.
path
.
exists
(
checkpoint
):
os
.
system
(
'wget https://paddlehub.bj.bcebos.com/dygraph/image_colorization/user_guided.pdparams -O '
+
checkpoint
)
model_dict
=
paddle
.
load
(
checkpoint
)
self
.
set_dict
(
model_dict
)
print
(
"load pretrained checkpoint success"
)
...
...
modules/image/style_transfer/msgnet/module.py
浏览文件 @
6f153ddc
...
...
@@ -314,9 +314,6 @@ class MSGNet(nn.Layer):
else
:
checkpoint
=
os
.
path
.
join
(
self
.
directory
,
'style_paddle.pdparams'
)
if
not
os
.
path
.
exists
(
checkpoint
):
os
.
system
(
'wget https://bj.bcebos.com/paddlehub/model/image/image_editing/style_paddle.pdparams -O '
+
checkpoint
)
model_dict
=
paddle
.
load
(
checkpoint
)
model_dict_clone
=
model_dict
.
copy
()
for
key
,
value
in
model_dict_clone
.
items
():
...
...
paddlehub/module/cv_module.py
浏览文件 @
6f153ddc
...
...
@@ -164,7 +164,7 @@ class ImageColorizeModule(RunModule, ImageServing):
Returns:
results(list[dict]) : The prediction result of each input image
'''
self
.
eval
()
lab2rgb
=
T
.
LAB2RGB
()
process
=
T
.
ColorPostprocess
()
resize
=
T
.
Resize
((
256
,
256
))
...
...
@@ -239,16 +239,17 @@ class Yolov3Module(RunModule, ImageServing):
for
i
,
out
in
enumerate
(
outputs
):
anchor_mask
=
self
.
anchor_masks
[
i
]
loss
=
F
.
yolov3_loss
(
x
=
out
,
gt_box
=
gtbox
,
gt_label
=
gtlabel
,
gt_score
=
gtscore
,
anchors
=
self
.
anchors
,
anchor_mask
=
anchor_mask
,
class_num
=
self
.
class_num
,
ignore_thresh
=
self
.
ignore_thresh
,
downsample_ratio
=
32
,
use_label_smooth
=
False
)
loss
=
F
.
yolov3_loss
(
x
=
out
,
gt_box
=
gtbox
,
gt_label
=
gtlabel
,
gt_score
=
gtscore
,
anchors
=
self
.
anchors
,
anchor_mask
=
anchor_mask
,
class_num
=
self
.
class_num
,
ignore_thresh
=
self
.
ignore_thresh
,
downsample_ratio
=
32
,
use_label_smooth
=
False
)
losses
.
append
(
paddle
.
mean
(
loss
))
self
.
downsample
//=
2
...
...
@@ -269,6 +270,7 @@ class Yolov3Module(RunModule, ImageServing):
scores(np.ndarray): Predict score.
labels(np.ndarray): Predict labels.
'''
self
.
eval
()
boxes
=
[]
scores
=
[]
self
.
downsample
=
32
...
...
@@ -287,13 +289,14 @@ class Yolov3Module(RunModule, ImageServing):
mask_anchors
.
append
((
self
.
anchors
[
2
*
m
]))
mask_anchors
.
append
(
self
.
anchors
[
2
*
m
+
1
])
box
,
score
=
F
.
yolo_box
(
x
=
out
,
img_size
=
im_shape
,
anchors
=
mask_anchors
,
class_num
=
self
.
class_num
,
conf_thresh
=
self
.
valid_thresh
,
downsample_ratio
=
self
.
downsample
,
name
=
"yolo_box"
+
str
(
i
))
box
,
score
=
F
.
yolo_box
(
x
=
out
,
img_size
=
im_shape
,
anchors
=
mask_anchors
,
class_num
=
self
.
class_num
,
conf_thresh
=
self
.
valid_thresh
,
downsample_ratio
=
self
.
downsample
,
name
=
"yolo_box"
+
str
(
i
))
boxes
.
append
(
box
)
scores
.
append
(
paddle
.
transpose
(
score
,
perm
=
[
0
,
2
,
1
]))
...
...
@@ -302,13 +305,14 @@ class Yolov3Module(RunModule, ImageServing):
yolo_boxes
=
paddle
.
concat
(
boxes
,
axis
=
1
)
yolo_scores
=
paddle
.
concat
(
scores
,
axis
=
2
)
pred
=
F
.
multiclass_nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
,
score_threshold
=
self
.
valid_thresh
,
nms_top_k
=
self
.
nms_topk
,
keep_top_k
=
self
.
nms_posk
,
nms_threshold
=
self
.
nms_thresh
,
background_label
=-
1
)
pred
=
F
.
multiclass_nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
,
score_threshold
=
self
.
valid_thresh
,
nms_top_k
=
self
.
nms_topk
,
keep_top_k
=
self
.
nms_posk
,
nms_threshold
=
self
.
nms_thresh
,
background_label
=-
1
)
bboxes
=
pred
.
numpy
()
labels
=
bboxes
[:,
0
].
astype
(
'int32'
)
...
...
@@ -388,6 +392,7 @@ class StyleTransferModule(RunModule, ImageServing):
Returns:
output(np.ndarray) : The style transformed images with bgr mode.
'''
self
.
eval
()
content
=
paddle
.
to_tensor
(
self
.
transform
(
origin_path
))
style
=
paddle
.
to_tensor
(
self
.
transform
(
style_path
))
content
=
content
.
unsqueeze
(
0
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录