Unverified commit 02d7e551, authored by jm_12138, committed by GitHub

update stgan_bald (#2022)

Parent 2ce0e07b
......@@ -129,6 +129,11 @@
* 1.0.0
初始发布
* 1.1.0
移除 Fluid API
- ```shell
$ hub install stgan_bald==1.0.0
$ hub install stgan_bald==1.1.0
```
......@@ -128,6 +128,11 @@
* 1.0.0
First release
* 1.1.0
Remove Fluid API
- ```shell
$ hub install stgan_bald==1.0.0
$ hub install stgan_bald==1.1.0
```
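For reference, a minimal usage sketch of the upgraded module after `hub install stgan_bald==1.1.0`. The image path `face.jpg` is a hypothetical placeholder; the `data_0`/`data_1`/`data_2` result keys match the unit tests further down in this diff.

```python
# Minimal usage sketch, assuming PaddleHub and the module are installed and
# a face photo exists at the hypothetical path 'face.jpg'.
import cv2
import paddlehub as hub

module = hub.Module(name="stgan_bald")
results = module.bald(images=[cv2.imread('face.jpg')], visualization=True)

# Each result carries three progressively balder renderings as numpy arrays.
for key in ('data_0', 'data_1', 'data_2'):
    print(key, results[0][key].shape)
```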
......@@ -3,10 +3,8 @@ import os
import time
from collections import OrderedDict
from PIL import Image, ImageOps
import numpy as np
from PIL import Image
import cv2
import numpy as np
__all__ = ['reader']
......@@ -26,27 +24,33 @@ def reader(images=None, paths=None, org_labels=None, target_labels=None):
if paths:
for i, im_path in enumerate(paths):
each = OrderedDict()
assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path)
assert os.path.isfile(
im_path), "The {} isn't a valid file path.".format(im_path)
im = cv2.imread(im_path)
each['org_im'] = im
each['org_im_path'] = im_path
each['org_label'] = np.array(org_labels[i]).astype('float32')
if not target_labels:
each['target_label'] = np.array(org_labels[i]).astype('float32')
each['target_label'] = np.array(
org_labels[i]).astype('float32')
else:
each['target_label'] = np.array(target_labels[i]).astype('float32')
each['target_label'] = np.array(
target_labels[i]).astype('float32')
component.append(each)
if images is not None:
assert type(images) is list, "images should be a list."
for i, im in enumerate(images):
each = OrderedDict()
each['org_im'] = im
each['org_im_path'] = 'ndarray_time={}'.format(round(time.time(), 6) * 1e6)
each['org_im_path'] = 'ndarray_time={}'.format(
round(time.time(), 6) * 1e6)
each['org_label'] = np.array(org_labels[i]).astype('float32')
if not target_labels:
each['target_label'] = np.array(org_labels[i]).astype('float32')
each['target_label'] = np.array(
org_labels[i]).astype('float32')
else:
each['target_label'] = np.array(target_labels[i]).astype('float32')
each['target_label'] = np.array(
target_labels[i]).astype('float32')
component.append(each)
for element in component:
......
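For clarity, here is a sketch of the per-image record that `reader()` assembles for an ndarray input. The field names come from the code above; the attribute vector values are purely illustrative.

```python
# Sketch of the per-image record built by reader() for an ndarray input.
# Field names follow the code above; the label values are illustrative only.
import time
from collections import OrderedDict

import cv2
import numpy as np

im = cv2.imread('face.jpg')                        # hypothetical input image
org_label = [0., 1., 0., 0., 0., 0., 1., 0., 1.]   # illustrative attribute vector

each = OrderedDict()
each['org_im'] = im
each['org_im_path'] = 'ndarray_time={}'.format(round(time.time(), 6) * 1e6)
each['org_label'] = np.array(org_label).astype('float32')
# With no target_labels supplied, the original labels are reused as the target.
each['target_label'] = np.array(org_label).astype('float32')
```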
......@@ -13,17 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import os
import argparse
import copy
import paddle
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from stgan_bald.data_feed import reader
from stgan_bald.processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir
from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, serving
from .data_feed import reader
from .processor import postprocess, base64_to_cv2, cv2_to_base64
def check_attribute_conflict(label_batch):
......@@ -45,40 +42,43 @@ def check_attribute_conflict(label_batch):
@moduleinfo(
name="stgan_bald",
version="1.0.0",
version="1.1.0",
summary="Baldness generator",
author="Arrow, 七年期限,Mr.郑先生_",
author_email="1084667371@qq.com,2733821739@qq.com",
type="image/gan")
class StganBald(hub.Module):
def _initialize(self):
self.default_pretrained_model_path = os.path.join(self.directory, "module")
class StganBald:
def __init__(self):
self.default_pretrained_model_path = os.path.join(
self.directory, "module", "model")
self._set_config()
def _set_config(self):
"""
predictor config setting
"""
self.model_file_path = os.path.join(self.default_pretrained_model_path, '__model__')
self.params_file_path = os.path.join(self.default_pretrained_model_path, '__params__')
cpu_config = AnalysisConfig(self.model_file_path, self.params_file_path)
model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
self.cpu_predictor = create_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
self.place = fluid.CUDAPlace(0)
self.place = paddle.CUDAPlace(0)
except:
use_gpu = False
self.place = fluid.CPUPlace()
self.place = paddle.CPUPlace()
if use_gpu:
gpu_config = AnalysisConfig(self.model_file_path, self.params_file_path)
gpu_config = Config(model, params)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
gpu_config.enable_use_gpu(
memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_predictor(gpu_config)
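The predictor setup above replaces `paddle.fluid.core.AnalysisConfig`/`create_paddle_predictor` with the `paddle.inference` API. A standalone sketch of the same pattern follows; the model prefix is a hypothetical placeholder pointing at a `.pdmodel`/`.pdiparams` pair.

```python
# Standalone sketch of the paddle.inference setup used in _set_config().
# 'model_prefix' is a hypothetical placeholder for <dir>/model(.pdmodel/.pdiparams).
import os

from paddle.inference import Config, create_predictor

model_prefix = '/path/to/stgan_bald/module/model'

cpu_config = Config(model_prefix + '.pdmodel', model_prefix + '.pdiparams')
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
cpu_predictor = create_predictor(cpu_config)

# Build a GPU predictor only when CUDA_VISIBLE_DEVICES names a device.
if os.environ.get('CUDA_VISIBLE_DEVICES', '').strip():
    gpu_config = Config(model_prefix + '.pdmodel', model_prefix + '.pdiparams')
    gpu_config.disable_glog_info()
    gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
    gpu_predictor = create_predictor(gpu_config)
```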
def bald(self,
images=None,
......@@ -135,19 +135,29 @@ class StganBald(hub.Module):
label_trg_tmp = copy.deepcopy(target_label_np)
new_i = 0
label_trg_tmp[0][new_i] = 1.0 - label_trg_tmp[0][new_i]
label_trg_tmp = check_attribute_conflict(label_trg_tmp)
label_trg_tmp = check_attribute_conflict(
label_trg_tmp)
change_num = j * 0.02 + 0.3
label_org_tmp = list(map(lambda x: ((x * 2) - 1) * change_num, org_label_np))
label_trg_tmp = list(map(lambda x: ((x * 2) - 1) * change_num, label_trg_tmp))
image = PaddleTensor(image_np.copy())
org_label = PaddleTensor(np.array(label_org_tmp).astype('float32'))
target_label = PaddleTensor(np.array(label_trg_tmp).astype('float32'))
output = self.gpu_predictor.run([
image, target_label, org_label
]) if use_gpu else self.cpu_predictor.run([image, org_label, target_label])
outputs.append(output)
label_org_tmp = list(
map(lambda x: ((x * 2) - 1) * change_num, org_label_np))
label_trg_tmp = list(
map(lambda x: ((x * 2) - 1) * change_num, label_trg_tmp))
predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(image_np.copy())
input_handle = predictor.get_input_handle(input_names[1])
input_handle.copy_from_cpu(
np.array(label_org_tmp).astype('float32'))
input_handle = predictor.get_input_handle(input_names[2])
input_handle.copy_from_cpu(
np.array(label_trg_tmp).astype('float32'))
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(
output_names[0])
outputs.append(output_handle)
out = postprocess(
data_out=outputs,
......
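The run loop now feeds numpy arrays through input handles and reads results back from output handles, instead of passing `PaddleTensor`s to `predictor.run()`. Below is a generic sketch of that pattern; the input order (image, org_label, target_label) is taken from the code above, and unlike the module this sketch copies the output straight back to CPU rather than keeping the handle.

```python
# Generic sketch of the paddle.inference run pattern adopted above.
import numpy as np

def run_predictor(predictor, image_np, org_label_np, target_label_np):
    """Feed image, original labels and target labels, then return the first output."""
    feeds = [
        image_np.astype('float32'),
        np.array(org_label_np).astype('float32'),
        np.array(target_label_np).astype('float32'),
    ]
    # Input order (image, org_label, target_label) mirrors the code above.
    for name, array in zip(predictor.get_input_names(), feeds):
        predictor.get_input_handle(name).copy_from_cpu(array)
    predictor.run()
    first_output = predictor.get_output_names()[0]
    return predictor.get_output_handle(first_output).copy_to_cpu()
```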
# -*- coding:utf-8 -*-
import os
import time
import base64
import cv2
from PIL import Image
import numpy as np
from PIL import Image
__all__ = ['cv2_to_base64', 'base64_to_cv2', 'postprocess']
......@@ -22,7 +21,12 @@ def base64_to_cv2(b64str):
return data
def postprocess(data_out, org_im, org_im_path, output_dir, visualization, thresh=120):
def postprocess(data_out,
org_im,
org_im_path,
output_dir,
visualization,
thresh=120):
"""
Postprocess output of network. one image at a time.
......@@ -41,7 +45,7 @@ def postprocess(data_out, org_im, org_im_path, output_dir, visualization, thresh
result = dict()
for i, img in enumerate(data_out):
img = np.squeeze(img[0].as_ndarray(), 0).transpose((1, 2, 0))
img = np.squeeze(img.copy_to_cpu(), 0).transpose((1, 2, 0))
img = ((img + 1) * 127.5).astype(np.uint8)
img = cv2.resize(img, (256, 341), interpolation=cv2.INTER_CUBIC)
fake_image = Image.fromarray(img)
......@@ -76,6 +80,7 @@ def get_save_image_name(org_im_path, output_dir, num):
# save image path
save_im_path = os.path.join(output_dir, im_prefix + ext)
if os.path.exists(save_im_path):
save_im_path = os.path.join(output_dir, im_prefix + str(num) + ext)
save_im_path = os.path.join(
output_dir, im_prefix + str(num) + ext)
return save_im_path
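`postprocess()` now reads the raw array from the output handle via `copy_to_cpu()` and converts it to an image. A self-contained sketch of that conversion, with the interpolation flag passed by keyword and random data standing in for a real network output:

```python
# Sketch of the tensor-to-image conversion performed in postprocess():
# drop the batch dim, move channels last, map [-1, 1] to uint8, then resize.
import cv2
import numpy as np

def tensor_to_image(chw_batch):
    """chw_batch: float array of shape (1, 3, H, W) with values in [-1, 1]."""
    img = np.squeeze(chw_batch, 0).transpose((1, 2, 0))
    img = ((img + 1) * 127.5).astype(np.uint8)
    return cv2.resize(img, (256, 341), interpolation=cv2.INTER_CUBIC)

# Illustrative call with random data standing in for a network output.
fake = tensor_to_image(np.random.uniform(-1, 1, (1, 3, 64, 64)).astype('float32'))
print(fake.shape)  # (341, 256, 3)
```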
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="stgan_bald")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('bald_output')
def test_bald1(self):
results = self.module.bald(
paths=['tests/test.jpg']
)
data_0 = results[0]['data_0']
data_1 = results[0]['data_1']
data_2 = results[0]['data_2']
self.assertIsInstance(data_0, np.ndarray)
self.assertIsInstance(data_1, np.ndarray)
self.assertIsInstance(data_2, np.ndarray)
def test_bald2(self):
results = self.module.bald(
images=[cv2.imread('tests/test.jpg')]
)
data_0 = results[0]['data_0']
data_1 = results[0]['data_1']
data_2 = results[0]['data_2']
self.assertIsInstance(data_0, np.ndarray)
self.assertIsInstance(data_1, np.ndarray)
self.assertIsInstance(data_2, np.ndarray)
def test_bald3(self):
results = self.module.bald(
images=[cv2.imread('tests/test.jpg')],
visualization=False
)
data_0 = results[0]['data_0']
data_1 = results[0]['data_1']
data_2 = results[0]['data_2']
self.assertIsInstance(data_0, np.ndarray)
self.assertIsInstance(data_1, np.ndarray)
self.assertIsInstance(data_2, np.ndarray)
def test_bald4(self):
self.assertRaises(
AssertionError,
self.module.bald,
paths=['no.jpg']
)
def test_bald5(self):
self.assertRaises(
cv2.error,
self.module.bald,
images=['tests/test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
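Since the module still exports a `@serving` handler (the decorator is imported above), it can also be queried over HTTP once `hub serving start -m stgan_bald` is running. A hedged sketch, assuming the usual PaddleHub convention of posting base64-encoded images under an `images` key:

```python
# Hedged sketch of an HTTP request to the served module; the endpoint and the
# 'images' payload key follow the common PaddleHub serving convention and are
# assumptions, not taken from this diff.
import base64
import json

import cv2
import requests

def cv2_to_base64(image):
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')

org_im = cv2.imread('face.jpg')  # hypothetical input image
payload = {'images': [cv2_to_base64(org_im)]}
headers = {'Content-type': 'application/json'}
url = 'http://127.0.0.1:8866/predict/stgan_bald'
r = requests.post(url=url, headers=headers, data=json.dumps(payload))
print(r.json().keys())
```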