Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
02d7e551
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
02d7e551
编写于
10月 14, 2022
作者:
jm_12138
提交者:
GitHub
10月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update stgan_bald (#2022)
上级
2ce0e07b
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
163 addition
and
50 deletion
+163
-50
modules/image/Image_gan/gan/stgan_bald/README.md
modules/image/Image_gan/gan/stgan_bald/README.md
+6
-1
modules/image/Image_gan/gan/stgan_bald/README_en.md
modules/image/Image_gan/gan/stgan_bald/README_en.md
+6
-1
modules/image/Image_gan/gan/stgan_bald/data_feed.py
modules/image/Image_gan/gan/stgan_bald/data_feed.py
+13
-9
modules/image/Image_gan/gan/stgan_bald/module.py
modules/image/Image_gan/gan/stgan_bald/module.py
+43
-33
modules/image/Image_gan/gan/stgan_bald/module/__model__
modules/image/Image_gan/gan/stgan_bald/module/__model__
+0
-0
modules/image/Image_gan/gan/stgan_bald/processor.py
modules/image/Image_gan/gan/stgan_bald/processor.py
+10
-5
modules/image/Image_gan/gan/stgan_bald/requirements.txt
modules/image/Image_gan/gan/stgan_bald/requirements.txt
+0
-1
modules/image/Image_gan/gan/stgan_bald/test.py
modules/image/Image_gan/gan/stgan_bald/test.py
+85
-0
未找到文件。
modules/image/Image_gan/gan/stgan_bald/README.md
浏览文件 @
02d7e551
...
...
@@ -129,6 +129,11 @@
*
1.0.0
初始发布
*
1.1.0
移除 Fluid API
-
```shell
$ hub install stgan_bald==1.
0
.0
$ hub install stgan_bald==1.
1
.0
```
modules/image/Image_gan/gan/stgan_bald/README_en.md
浏览文件 @
02d7e551
...
...
@@ -128,6 +128,11 @@
*
1.0.0
First release
*
1.1.0
Remove Fluid API
-
```shell
$ hub install stgan_bald==1.
0
.0
$ hub install stgan_bald==1.
1
.0
```
modules/image/Image_gan/gan/stgan_bald/data_feed.py
浏览文件 @
02d7e551
...
...
@@ -3,10 +3,8 @@ import os
import
time
from
collections
import
OrderedDict
from
PIL
import
Image
,
ImageOps
import
numpy
as
np
from
PIL
import
Image
import
cv2
import
numpy
as
np
__all__
=
[
'reader'
]
...
...
@@ -26,27 +24,33 @@ def reader(images=None, paths=None, org_labels=None, target_labels=None):
if
paths
:
for
i
,
im_path
in
enumerate
(
paths
):
each
=
OrderedDict
()
assert
os
.
path
.
isfile
(
im_path
),
"The {} isn't a valid file path."
.
format
(
im_path
)
assert
os
.
path
.
isfile
(
im_path
),
"The {} isn't a valid file path."
.
format
(
im_path
)
im
=
cv2
.
imread
(
im_path
)
each
[
'org_im'
]
=
im
each
[
'org_im_path'
]
=
im_path
each
[
'org_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
if
not
target_labels
:
each
[
'target_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
each
[
'target_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
else
:
each
[
'target_label'
]
=
np
.
array
(
target_labels
[
i
]).
astype
(
'float32'
)
each
[
'target_label'
]
=
np
.
array
(
target_labels
[
i
]).
astype
(
'float32'
)
component
.
append
(
each
)
if
images
is
not
None
:
assert
type
(
images
)
is
list
,
"images should be a list."
for
i
,
im
in
enumerate
(
images
):
each
=
OrderedDict
()
each
[
'org_im'
]
=
im
each
[
'org_im_path'
]
=
'ndarray_time={}'
.
format
(
round
(
time
.
time
(),
6
)
*
1e6
)
each
[
'org_im_path'
]
=
'ndarray_time={}'
.
format
(
round
(
time
.
time
(),
6
)
*
1e6
)
each
[
'org_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
if
not
target_labels
:
each
[
'target_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
each
[
'target_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
else
:
each
[
'target_label'
]
=
np
.
array
(
target_labels
[
i
]).
astype
(
'float32'
)
each
[
'target_label'
]
=
np
.
array
(
target_labels
[
i
]).
astype
(
'float32'
)
component
.
append
(
each
)
for
element
in
component
:
...
...
modules/image/Image_gan/gan/stgan_bald/module.py
浏览文件 @
02d7e551
...
...
@@ -13,17 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
ast
import
os
import
argparse
import
copy
import
paddle
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
paddlehub.module.module
import
moduleinfo
,
runnable
,
serving
from
stgan_bald.data_feed
import
reader
from
stgan_bald.processor
import
postprocess
,
base64_to_cv2
,
cv2_to_base64
,
check_dir
from
paddle.inference
import
Config
,
create_predictor
from
paddlehub.module.module
import
moduleinfo
,
serving
from
.data_feed
import
reader
from
.processor
import
postprocess
,
base64_to_cv2
,
cv2_to_base64
def
check_attribute_conflict
(
label_batch
):
...
...
@@ -45,40 +42,43 @@ def check_attribute_conflict(label_batch):
@
moduleinfo
(
name
=
"stgan_bald"
,
version
=
"1.
0
.0"
,
version
=
"1.
1
.0"
,
summary
=
"Baldness generator"
,
author
=
"Arrow, 七年期限,Mr.郑先生_"
,
author_email
=
"1084667371@qq.com,2733821739@qq.com"
,
type
=
"image/gan"
)
class
StganBald
(
hub
.
Module
):
def
_initialize
(
self
):
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"module"
)
class
StganBald
:
def
__init__
(
self
):
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"module"
,
"model"
)
self
.
_set_config
()
def
_set_config
(
self
):
"""
predictor config setting
"""
self
.
model_file_path
=
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
'__model__'
)
self
.
params_file_path
=
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
'__params__'
)
cpu_config
=
AnalysisConfig
(
self
.
model_file_path
,
self
.
params_file_path
)
model
=
self
.
default_pretrained_model_path
+
'.pdmodel'
params
=
self
.
default_pretrained_model_path
+
'.pdiparams'
cpu_config
=
Config
(
model
,
params
)
cpu_config
.
disable_glog_info
()
cpu_config
.
disable_gpu
()
self
.
cpu_predictor
=
create_p
addle_p
redictor
(
cpu_config
)
self
.
cpu_predictor
=
create_predictor
(
cpu_config
)
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
use_gpu
=
True
self
.
place
=
fluid
.
CUDAPlace
(
0
)
self
.
place
=
paddle
.
CUDAPlace
(
0
)
except
:
use_gpu
=
False
self
.
place
=
fluid
.
CPUPlace
()
self
.
place
=
paddle
.
CPUPlace
()
if
use_gpu
:
gpu_config
=
AnalysisConfig
(
self
.
model_file_path
,
self
.
params_file_path
)
gpu_config
=
Config
(
model
,
params
)
gpu_config
.
disable_glog_info
()
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
1000
,
device_id
=
0
)
self
.
gpu_predictor
=
create_paddle_predictor
(
gpu_config
)
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
1000
,
device_id
=
0
)
self
.
gpu_predictor
=
create_predictor
(
gpu_config
)
def
bald
(
self
,
images
=
None
,
...
...
@@ -135,19 +135,29 @@ class StganBald(hub.Module):
label_trg_tmp
=
copy
.
deepcopy
(
target_label_np
)
new_i
=
0
label_trg_tmp
[
0
][
new_i
]
=
1.0
-
label_trg_tmp
[
0
][
new_i
]
label_trg_tmp
=
check_attribute_conflict
(
label_trg_tmp
)
label_trg_tmp
=
check_attribute_conflict
(
label_trg_tmp
)
change_num
=
j
*
0.02
+
0.3
label_org_tmp
=
list
(
map
(
lambda
x
:
((
x
*
2
)
-
1
)
*
change_num
,
org_label_np
))
label_trg_tmp
=
list
(
map
(
lambda
x
:
((
x
*
2
)
-
1
)
*
change_num
,
label_trg_tmp
))
image
=
PaddleTensor
(
image_np
.
copy
())
org_label
=
PaddleTensor
(
np
.
array
(
label_org_tmp
).
astype
(
'float32'
))
target_label
=
PaddleTensor
(
np
.
array
(
label_trg_tmp
).
astype
(
'float32'
))
output
=
self
.
gpu_predictor
.
run
([
image
,
target_label
,
org_label
])
if
use_gpu
else
self
.
cpu_predictor
.
run
([
image
,
org_label
,
target_label
])
outputs
.
append
(
output
)
label_org_tmp
=
list
(
map
(
lambda
x
:
((
x
*
2
)
-
1
)
*
change_num
,
org_label_np
))
label_trg_tmp
=
list
(
map
(
lambda
x
:
((
x
*
2
)
-
1
)
*
change_num
,
label_trg_tmp
))
predictor
=
self
.
gpu_predictor
if
use_gpu
else
self
.
cpu_predictor
input_names
=
predictor
.
get_input_names
()
input_handle
=
predictor
.
get_input_handle
(
input_names
[
0
])
input_handle
.
copy_from_cpu
(
image_np
.
copy
())
input_handle
=
predictor
.
get_input_handle
(
input_names
[
1
])
input_handle
.
copy_from_cpu
(
np
.
array
(
label_org_tmp
).
astype
(
'float32'
))
input_handle
=
predictor
.
get_input_handle
(
input_names
[
2
])
input_handle
.
copy_from_cpu
(
np
.
array
(
label_trg_tmp
).
astype
(
'float32'
))
predictor
.
run
()
output_names
=
predictor
.
get_output_names
()
output_handle
=
predictor
.
get_output_handle
(
output_names
[
0
])
outputs
.
append
(
output_handle
)
out
=
postprocess
(
data_out
=
outputs
,
...
...
modules/image/Image_gan/gan/stgan_bald/module/__model__
已删除
100644 → 0
浏览文件 @
2ce0e07b
文件已删除
modules/image/Image_gan/gan/stgan_bald/processor.py
浏览文件 @
02d7e551
# -*- coding:utf-8 -*-
import
os
import
time
import
base64
import
cv2
from
PIL
import
Image
import
numpy
as
np
from
PIL
import
Image
__all__
=
[
'cv2_to_base64'
,
'base64_to_cv2'
,
'postprocess'
]
...
...
@@ -22,7 +21,12 @@ def base64_to_cv2(b64str):
return
data
def
postprocess
(
data_out
,
org_im
,
org_im_path
,
output_dir
,
visualization
,
thresh
=
120
):
def
postprocess
(
data_out
,
org_im
,
org_im_path
,
output_dir
,
visualization
,
thresh
=
120
):
"""
Postprocess output of network. one image at a time.
...
...
@@ -41,7 +45,7 @@ def postprocess(data_out, org_im, org_im_path, output_dir, visualization, thresh
result
=
dict
()
for
i
,
img
in
enumerate
(
data_out
):
img
=
np
.
squeeze
(
img
[
0
].
as_ndarray
(),
0
).
transpose
((
1
,
2
,
0
))
img
=
np
.
squeeze
(
img
.
copy_to_cpu
(),
0
).
transpose
((
1
,
2
,
0
))
img
=
((
img
+
1
)
*
127.5
).
astype
(
np
.
uint8
)
img
=
cv2
.
resize
(
img
,
(
256
,
341
),
cv2
.
INTER_CUBIC
)
fake_image
=
Image
.
fromarray
(
img
)
...
...
@@ -76,6 +80,7 @@ def get_save_image_name(org_im_path, output_dir, num):
# save image path
save_im_path
=
os
.
path
.
join
(
output_dir
,
im_prefix
+
ext
)
if
os
.
path
.
exists
(
save_im_path
):
save_im_path
=
os
.
path
.
join
(
output_dir
,
im_prefix
+
str
(
num
)
+
ext
)
save_im_path
=
os
.
path
.
join
(
output_dir
,
im_prefix
+
str
(
num
)
+
ext
)
return
save_im_path
modules/image/Image_gan/gan/stgan_bald/requirements.txt
已删除
100644 → 0
浏览文件 @
2ce0e07b
paddlehub>=1.8.0
modules/image/Image_gan/gan/stgan_bald/test.py
0 → 100644
浏览文件 @
02d7e551
import
os
import
shutil
import
unittest
import
cv2
import
requests
import
numpy
as
np
import
paddlehub
as
hub
class
TestHubModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
)
->
None
:
img_url
=
'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
if
not
os
.
path
.
exists
(
'tests'
):
os
.
makedirs
(
'tests'
)
response
=
requests
.
get
(
img_url
)
assert
response
.
status_code
==
200
,
'Network Error.'
with
open
(
'tests/test.jpg'
,
'wb'
)
as
f
:
f
.
write
(
response
.
content
)
cls
.
module
=
hub
.
Module
(
name
=
"stgan_bald"
)
@
classmethod
def
tearDownClass
(
cls
)
->
None
:
shutil
.
rmtree
(
'tests'
)
shutil
.
rmtree
(
'inference'
)
shutil
.
rmtree
(
'bald_output'
)
def
test_bald1
(
self
):
results
=
self
.
module
.
bald
(
paths
=
[
'tests/test.jpg'
]
)
data_0
=
results
[
0
][
'data_0'
]
data_1
=
results
[
0
][
'data_1'
]
data_2
=
results
[
0
][
'data_2'
]
self
.
assertIsInstance
(
data_0
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_1
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_2
,
np
.
ndarray
)
def
test_bald2
(
self
):
results
=
self
.
module
.
bald
(
images
=
[
cv2
.
imread
(
'tests/test.jpg'
)]
)
data_0
=
results
[
0
][
'data_0'
]
data_1
=
results
[
0
][
'data_1'
]
data_2
=
results
[
0
][
'data_2'
]
self
.
assertIsInstance
(
data_0
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_1
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_2
,
np
.
ndarray
)
def
test_bald3
(
self
):
results
=
self
.
module
.
bald
(
images
=
[
cv2
.
imread
(
'tests/test.jpg'
)],
visualization
=
False
)
data_0
=
results
[
0
][
'data_0'
]
data_1
=
results
[
0
][
'data_1'
]
data_2
=
results
[
0
][
'data_2'
]
self
.
assertIsInstance
(
data_0
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_1
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_2
,
np
.
ndarray
)
def
test_bald4
(
self
):
self
.
assertRaises
(
AssertionError
,
self
.
module
.
bald
,
paths
=
[
'no.jpg'
]
)
def
test_bald5
(
self
):
self
.
assertRaises
(
cv2
.
error
,
self
.
module
.
bald
,
images
=
[
'tests/test.jpg'
]
)
def
test_save_inference_model
(
self
):
self
.
module
.
save_inference_model
(
'./inference/model'
)
self
.
assertTrue
(
os
.
path
.
exists
(
'./inference/model.pdmodel'
))
self
.
assertTrue
(
os
.
path
.
exists
(
'./inference/model.pdiparams'
))
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录