Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
02d7e551
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
02d7e551
编写于
10月 14, 2022
作者:
jm_12138
提交者:
GitHub
10月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update stgan_bald (#2022)
上级
2ce0e07b
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
163 addition
and
50 deletion
+163
-50
modules/image/Image_gan/gan/stgan_bald/README.md
modules/image/Image_gan/gan/stgan_bald/README.md
+6
-1
modules/image/Image_gan/gan/stgan_bald/README_en.md
modules/image/Image_gan/gan/stgan_bald/README_en.md
+6
-1
modules/image/Image_gan/gan/stgan_bald/data_feed.py
modules/image/Image_gan/gan/stgan_bald/data_feed.py
+13
-9
modules/image/Image_gan/gan/stgan_bald/module.py
modules/image/Image_gan/gan/stgan_bald/module.py
+43
-33
modules/image/Image_gan/gan/stgan_bald/module/__model__
modules/image/Image_gan/gan/stgan_bald/module/__model__
+0
-0
modules/image/Image_gan/gan/stgan_bald/processor.py
modules/image/Image_gan/gan/stgan_bald/processor.py
+10
-5
modules/image/Image_gan/gan/stgan_bald/requirements.txt
modules/image/Image_gan/gan/stgan_bald/requirements.txt
+0
-1
modules/image/Image_gan/gan/stgan_bald/test.py
modules/image/Image_gan/gan/stgan_bald/test.py
+85
-0
未找到文件。
modules/image/Image_gan/gan/stgan_bald/README.md
浏览文件 @
02d7e551
...
@@ -129,6 +129,11 @@
...
@@ -129,6 +129,11 @@
*
1.0.0
*
1.0.0
初始发布
初始发布
*
1.1.0
移除 Fluid API
-
```shell
-
```shell
$ hub install stgan_bald==1.
0
.0
$ hub install stgan_bald==1.
1
.0
```
```
modules/image/Image_gan/gan/stgan_bald/README_en.md
浏览文件 @
02d7e551
...
@@ -128,6 +128,11 @@
...
@@ -128,6 +128,11 @@
*
1.0.0
*
1.0.0
First release
First release
*
1.1.0
Remove Fluid API
-
```shell
-
```shell
$ hub install stgan_bald==1.
0
.0
$ hub install stgan_bald==1.
1
.0
```
```
modules/image/Image_gan/gan/stgan_bald/data_feed.py
浏览文件 @
02d7e551
...
@@ -3,10 +3,8 @@ import os
...
@@ -3,10 +3,8 @@ import os
import
time
import
time
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
PIL
import
Image
,
ImageOps
import
numpy
as
np
from
PIL
import
Image
import
cv2
import
cv2
import
numpy
as
np
__all__
=
[
'reader'
]
__all__
=
[
'reader'
]
...
@@ -26,27 +24,33 @@ def reader(images=None, paths=None, org_labels=None, target_labels=None):
...
@@ -26,27 +24,33 @@ def reader(images=None, paths=None, org_labels=None, target_labels=None):
if
paths
:
if
paths
:
for
i
,
im_path
in
enumerate
(
paths
):
for
i
,
im_path
in
enumerate
(
paths
):
each
=
OrderedDict
()
each
=
OrderedDict
()
assert
os
.
path
.
isfile
(
im_path
),
"The {} isn't a valid file path."
.
format
(
im_path
)
assert
os
.
path
.
isfile
(
im_path
),
"The {} isn't a valid file path."
.
format
(
im_path
)
im
=
cv2
.
imread
(
im_path
)
im
=
cv2
.
imread
(
im_path
)
each
[
'org_im'
]
=
im
each
[
'org_im'
]
=
im
each
[
'org_im_path'
]
=
im_path
each
[
'org_im_path'
]
=
im_path
each
[
'org_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
each
[
'org_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
if
not
target_labels
:
if
not
target_labels
:
each
[
'target_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
each
[
'target_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
else
:
else
:
each
[
'target_label'
]
=
np
.
array
(
target_labels
[
i
]).
astype
(
'float32'
)
each
[
'target_label'
]
=
np
.
array
(
target_labels
[
i
]).
astype
(
'float32'
)
component
.
append
(
each
)
component
.
append
(
each
)
if
images
is
not
None
:
if
images
is
not
None
:
assert
type
(
images
)
is
list
,
"images should be a list."
assert
type
(
images
)
is
list
,
"images should be a list."
for
i
,
im
in
enumerate
(
images
):
for
i
,
im
in
enumerate
(
images
):
each
=
OrderedDict
()
each
=
OrderedDict
()
each
[
'org_im'
]
=
im
each
[
'org_im'
]
=
im
each
[
'org_im_path'
]
=
'ndarray_time={}'
.
format
(
round
(
time
.
time
(),
6
)
*
1e6
)
each
[
'org_im_path'
]
=
'ndarray_time={}'
.
format
(
round
(
time
.
time
(),
6
)
*
1e6
)
each
[
'org_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
each
[
'org_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
if
not
target_labels
:
if
not
target_labels
:
each
[
'target_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
each
[
'target_label'
]
=
np
.
array
(
org_labels
[
i
]).
astype
(
'float32'
)
else
:
else
:
each
[
'target_label'
]
=
np
.
array
(
target_labels
[
i
]).
astype
(
'float32'
)
each
[
'target_label'
]
=
np
.
array
(
target_labels
[
i
]).
astype
(
'float32'
)
component
.
append
(
each
)
component
.
append
(
each
)
for
element
in
component
:
for
element
in
component
:
...
...
modules/image/Image_gan/gan/stgan_bald/module.py
浏览文件 @
02d7e551
...
@@ -13,17 +13,14 @@
...
@@ -13,17 +13,14 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
ast
import
os
import
os
import
argparse
import
copy
import
copy
import
paddle
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.inference
import
Config
,
create_predictor
import
paddlehub
as
hub
from
paddlehub.module.module
import
moduleinfo
,
serving
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
.data_feed
import
reader
from
paddlehub.module.module
import
moduleinfo
,
runnable
,
serving
from
.processor
import
postprocess
,
base64_to_cv2
,
cv2_to_base64
from
stgan_bald.data_feed
import
reader
from
stgan_bald.processor
import
postprocess
,
base64_to_cv2
,
cv2_to_base64
,
check_dir
def
check_attribute_conflict
(
label_batch
):
def
check_attribute_conflict
(
label_batch
):
...
@@ -45,40 +42,43 @@ def check_attribute_conflict(label_batch):
...
@@ -45,40 +42,43 @@ def check_attribute_conflict(label_batch):
@
moduleinfo
(
@
moduleinfo
(
name
=
"stgan_bald"
,
name
=
"stgan_bald"
,
version
=
"1.
0
.0"
,
version
=
"1.
1
.0"
,
summary
=
"Baldness generator"
,
summary
=
"Baldness generator"
,
author
=
"Arrow, 七年期限,Mr.郑先生_"
,
author
=
"Arrow, 七年期限,Mr.郑先生_"
,
author_email
=
"1084667371@qq.com,2733821739@qq.com"
,
author_email
=
"1084667371@qq.com,2733821739@qq.com"
,
type
=
"image/gan"
)
type
=
"image/gan"
)
class
StganBald
(
hub
.
Module
):
class
StganBald
:
def
_initialize
(
self
):
def
__init__
(
self
):
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"module"
)
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"module"
,
"model"
)
self
.
_set_config
()
self
.
_set_config
()
def
_set_config
(
self
):
def
_set_config
(
self
):
"""
"""
predictor config setting
predictor config setting
"""
"""
self
.
model_file_path
=
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
'__model__'
)
model
=
self
.
default_pretrained_model_path
+
'.pdmodel'
self
.
params_file_path
=
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
'__params__'
)
params
=
self
.
default_pretrained_model_path
+
'.pdiparams'
cpu_config
=
AnalysisConfig
(
self
.
model_file_path
,
self
.
params_file_path
)
cpu_config
=
Config
(
model
,
params
)
cpu_config
.
disable_glog_info
()
cpu_config
.
disable_glog_info
()
cpu_config
.
disable_gpu
()
cpu_config
.
disable_gpu
()
self
.
cpu_predictor
=
create_p
addle_p
redictor
(
cpu_config
)
self
.
cpu_predictor
=
create_predictor
(
cpu_config
)
try
:
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
int
(
_places
[
0
])
use_gpu
=
True
use_gpu
=
True
self
.
place
=
fluid
.
CUDAPlace
(
0
)
self
.
place
=
paddle
.
CUDAPlace
(
0
)
except
:
except
:
use_gpu
=
False
use_gpu
=
False
self
.
place
=
fluid
.
CPUPlace
()
self
.
place
=
paddle
.
CPUPlace
()
if
use_gpu
:
if
use_gpu
:
gpu_config
=
AnalysisConfig
(
self
.
model_file_path
,
self
.
params_file_path
)
gpu_config
=
Config
(
model
,
params
)
gpu_config
.
disable_glog_info
()
gpu_config
.
disable_glog_info
()
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
1000
,
device_id
=
0
)
gpu_config
.
enable_use_gpu
(
self
.
gpu_predictor
=
create_paddle_predictor
(
gpu_config
)
memory_pool_init_size_mb
=
1000
,
device_id
=
0
)
self
.
gpu_predictor
=
create_predictor
(
gpu_config
)
def
bald
(
self
,
def
bald
(
self
,
images
=
None
,
images
=
None
,
...
@@ -135,19 +135,29 @@ class StganBald(hub.Module):
...
@@ -135,19 +135,29 @@ class StganBald(hub.Module):
label_trg_tmp
=
copy
.
deepcopy
(
target_label_np
)
label_trg_tmp
=
copy
.
deepcopy
(
target_label_np
)
new_i
=
0
new_i
=
0
label_trg_tmp
[
0
][
new_i
]
=
1.0
-
label_trg_tmp
[
0
][
new_i
]
label_trg_tmp
[
0
][
new_i
]
=
1.0
-
label_trg_tmp
[
0
][
new_i
]
label_trg_tmp
=
check_attribute_conflict
(
label_trg_tmp
)
label_trg_tmp
=
check_attribute_conflict
(
label_trg_tmp
)
change_num
=
j
*
0.02
+
0.3
change_num
=
j
*
0.02
+
0.3
label_org_tmp
=
list
(
map
(
lambda
x
:
((
x
*
2
)
-
1
)
*
change_num
,
org_label_np
))
label_org_tmp
=
list
(
label_trg_tmp
=
list
(
map
(
lambda
x
:
((
x
*
2
)
-
1
)
*
change_num
,
label_trg_tmp
))
map
(
lambda
x
:
((
x
*
2
)
-
1
)
*
change_num
,
org_label_np
))
label_trg_tmp
=
list
(
image
=
PaddleTensor
(
image_np
.
copy
())
map
(
lambda
x
:
((
x
*
2
)
-
1
)
*
change_num
,
label_trg_tmp
))
org_label
=
PaddleTensor
(
np
.
array
(
label_org_tmp
).
astype
(
'float32'
))
target_label
=
PaddleTensor
(
np
.
array
(
label_trg_tmp
).
astype
(
'float32'
))
predictor
=
self
.
gpu_predictor
if
use_gpu
else
self
.
cpu_predictor
input_names
=
predictor
.
get_input_names
()
output
=
self
.
gpu_predictor
.
run
([
input_handle
=
predictor
.
get_input_handle
(
input_names
[
0
])
image
,
target_label
,
org_label
input_handle
.
copy_from_cpu
(
image_np
.
copy
())
])
if
use_gpu
else
self
.
cpu_predictor
.
run
([
image
,
org_label
,
target_label
])
input_handle
=
predictor
.
get_input_handle
(
input_names
[
1
])
outputs
.
append
(
output
)
input_handle
.
copy_from_cpu
(
np
.
array
(
label_org_tmp
).
astype
(
'float32'
))
input_handle
=
predictor
.
get_input_handle
(
input_names
[
2
])
input_handle
.
copy_from_cpu
(
np
.
array
(
label_trg_tmp
).
astype
(
'float32'
))
predictor
.
run
()
output_names
=
predictor
.
get_output_names
()
output_handle
=
predictor
.
get_output_handle
(
output_names
[
0
])
outputs
.
append
(
output_handle
)
out
=
postprocess
(
out
=
postprocess
(
data_out
=
outputs
,
data_out
=
outputs
,
...
...
modules/image/Image_gan/gan/stgan_bald/module/__model__
已删除
100644 → 0
浏览文件 @
2ce0e07b
文件已删除
modules/image/Image_gan/gan/stgan_bald/processor.py
浏览文件 @
02d7e551
# -*- coding:utf-8 -*-
# -*- coding:utf-8 -*-
import
os
import
os
import
time
import
base64
import
base64
import
cv2
import
cv2
from
PIL
import
Image
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
__all__
=
[
'cv2_to_base64'
,
'base64_to_cv2'
,
'postprocess'
]
__all__
=
[
'cv2_to_base64'
,
'base64_to_cv2'
,
'postprocess'
]
...
@@ -22,7 +21,12 @@ def base64_to_cv2(b64str):
...
@@ -22,7 +21,12 @@ def base64_to_cv2(b64str):
return
data
return
data
def
postprocess
(
data_out
,
org_im
,
org_im_path
,
output_dir
,
visualization
,
thresh
=
120
):
def
postprocess
(
data_out
,
org_im
,
org_im_path
,
output_dir
,
visualization
,
thresh
=
120
):
"""
"""
Postprocess output of network. one image at a time.
Postprocess output of network. one image at a time.
...
@@ -41,7 +45,7 @@ def postprocess(data_out, org_im, org_im_path, output_dir, visualization, thresh
...
@@ -41,7 +45,7 @@ def postprocess(data_out, org_im, org_im_path, output_dir, visualization, thresh
result
=
dict
()
result
=
dict
()
for
i
,
img
in
enumerate
(
data_out
):
for
i
,
img
in
enumerate
(
data_out
):
img
=
np
.
squeeze
(
img
[
0
].
as_ndarray
(),
0
).
transpose
((
1
,
2
,
0
))
img
=
np
.
squeeze
(
img
.
copy_to_cpu
(),
0
).
transpose
((
1
,
2
,
0
))
img
=
((
img
+
1
)
*
127.5
).
astype
(
np
.
uint8
)
img
=
((
img
+
1
)
*
127.5
).
astype
(
np
.
uint8
)
img
=
cv2
.
resize
(
img
,
(
256
,
341
),
cv2
.
INTER_CUBIC
)
img
=
cv2
.
resize
(
img
,
(
256
,
341
),
cv2
.
INTER_CUBIC
)
fake_image
=
Image
.
fromarray
(
img
)
fake_image
=
Image
.
fromarray
(
img
)
...
@@ -76,6 +80,7 @@ def get_save_image_name(org_im_path, output_dir, num):
...
@@ -76,6 +80,7 @@ def get_save_image_name(org_im_path, output_dir, num):
# save image path
# save image path
save_im_path
=
os
.
path
.
join
(
output_dir
,
im_prefix
+
ext
)
save_im_path
=
os
.
path
.
join
(
output_dir
,
im_prefix
+
ext
)
if
os
.
path
.
exists
(
save_im_path
):
if
os
.
path
.
exists
(
save_im_path
):
save_im_path
=
os
.
path
.
join
(
output_dir
,
im_prefix
+
str
(
num
)
+
ext
)
save_im_path
=
os
.
path
.
join
(
output_dir
,
im_prefix
+
str
(
num
)
+
ext
)
return
save_im_path
return
save_im_path
modules/image/Image_gan/gan/stgan_bald/requirements.txt
已删除
100644 → 0
浏览文件 @
2ce0e07b
paddlehub>=1.8.0
modules/image/Image_gan/gan/stgan_bald/test.py
0 → 100644
浏览文件 @
02d7e551
import
os
import
shutil
import
unittest
import
cv2
import
requests
import
numpy
as
np
import
paddlehub
as
hub
class
TestHubModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
)
->
None
:
img_url
=
'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
if
not
os
.
path
.
exists
(
'tests'
):
os
.
makedirs
(
'tests'
)
response
=
requests
.
get
(
img_url
)
assert
response
.
status_code
==
200
,
'Network Error.'
with
open
(
'tests/test.jpg'
,
'wb'
)
as
f
:
f
.
write
(
response
.
content
)
cls
.
module
=
hub
.
Module
(
name
=
"stgan_bald"
)
@
classmethod
def
tearDownClass
(
cls
)
->
None
:
shutil
.
rmtree
(
'tests'
)
shutil
.
rmtree
(
'inference'
)
shutil
.
rmtree
(
'bald_output'
)
def
test_bald1
(
self
):
results
=
self
.
module
.
bald
(
paths
=
[
'tests/test.jpg'
]
)
data_0
=
results
[
0
][
'data_0'
]
data_1
=
results
[
0
][
'data_1'
]
data_2
=
results
[
0
][
'data_2'
]
self
.
assertIsInstance
(
data_0
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_1
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_2
,
np
.
ndarray
)
def
test_bald2
(
self
):
results
=
self
.
module
.
bald
(
images
=
[
cv2
.
imread
(
'tests/test.jpg'
)]
)
data_0
=
results
[
0
][
'data_0'
]
data_1
=
results
[
0
][
'data_1'
]
data_2
=
results
[
0
][
'data_2'
]
self
.
assertIsInstance
(
data_0
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_1
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_2
,
np
.
ndarray
)
def
test_bald3
(
self
):
results
=
self
.
module
.
bald
(
images
=
[
cv2
.
imread
(
'tests/test.jpg'
)],
visualization
=
False
)
data_0
=
results
[
0
][
'data_0'
]
data_1
=
results
[
0
][
'data_1'
]
data_2
=
results
[
0
][
'data_2'
]
self
.
assertIsInstance
(
data_0
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_1
,
np
.
ndarray
)
self
.
assertIsInstance
(
data_2
,
np
.
ndarray
)
def
test_bald4
(
self
):
self
.
assertRaises
(
AssertionError
,
self
.
module
.
bald
,
paths
=
[
'no.jpg'
]
)
def
test_bald5
(
self
):
self
.
assertRaises
(
cv2
.
error
,
self
.
module
.
bald
,
images
=
[
'tests/test.jpg'
]
)
def
test_save_inference_model
(
self
):
self
.
module
.
save_inference_model
(
'./inference/model'
)
self
.
assertTrue
(
os
.
path
.
exists
(
'./inference/model.pdmodel'
))
self
.
assertTrue
(
os
.
path
.
exists
(
'./inference/model.pdiparams'
))
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录