未验证 提交 02d7e551 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update stgan_bald (#2022)

上级 2ce0e07b
...@@ -129,6 +129,11 @@ ...@@ -129,6 +129,11 @@
* 1.0.0 * 1.0.0
初始发布 初始发布
* 1.1.0
移除 Fluid API
- ```shell - ```shell
$ hub install stgan_bald==1.0.0 $ hub install stgan_bald==1.1.0
``` ```
...@@ -128,6 +128,11 @@ ...@@ -128,6 +128,11 @@
* 1.0.0 * 1.0.0
First release First release
* 1.1.0
Remove Fluid API
- ```shell - ```shell
$ hub install stgan_bald==1.0.0 $ hub install stgan_bald==1.1.0
``` ```
...@@ -3,10 +3,8 @@ import os ...@@ -3,10 +3,8 @@ import os
import time import time
from collections import OrderedDict from collections import OrderedDict
from PIL import Image, ImageOps
import numpy as np
from PIL import Image
import cv2 import cv2
import numpy as np
__all__ = ['reader'] __all__ = ['reader']
...@@ -26,27 +24,33 @@ def reader(images=None, paths=None, org_labels=None, target_labels=None): ...@@ -26,27 +24,33 @@ def reader(images=None, paths=None, org_labels=None, target_labels=None):
if paths: if paths:
for i, im_path in enumerate(paths): for i, im_path in enumerate(paths):
each = OrderedDict() each = OrderedDict()
assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path) assert os.path.isfile(
im_path), "The {} isn't a valid file path.".format(im_path)
im = cv2.imread(im_path) im = cv2.imread(im_path)
each['org_im'] = im each['org_im'] = im
each['org_im_path'] = im_path each['org_im_path'] = im_path
each['org_label'] = np.array(org_labels[i]).astype('float32') each['org_label'] = np.array(org_labels[i]).astype('float32')
if not target_labels: if not target_labels:
each['target_label'] = np.array(org_labels[i]).astype('float32') each['target_label'] = np.array(
org_labels[i]).astype('float32')
else: else:
each['target_label'] = np.array(target_labels[i]).astype('float32') each['target_label'] = np.array(
target_labels[i]).astype('float32')
component.append(each) component.append(each)
if images is not None: if images is not None:
assert type(images) is list, "images should be a list." assert type(images) is list, "images should be a list."
for i, im in enumerate(images): for i, im in enumerate(images):
each = OrderedDict() each = OrderedDict()
each['org_im'] = im each['org_im'] = im
each['org_im_path'] = 'ndarray_time={}'.format(round(time.time(), 6) * 1e6) each['org_im_path'] = 'ndarray_time={}'.format(
round(time.time(), 6) * 1e6)
each['org_label'] = np.array(org_labels[i]).astype('float32') each['org_label'] = np.array(org_labels[i]).astype('float32')
if not target_labels: if not target_labels:
each['target_label'] = np.array(org_labels[i]).astype('float32') each['target_label'] = np.array(
org_labels[i]).astype('float32')
else: else:
each['target_label'] = np.array(target_labels[i]).astype('float32') each['target_label'] = np.array(
target_labels[i]).astype('float32')
component.append(each) component.append(each)
for element in component: for element in component:
......
...@@ -13,17 +13,14 @@ ...@@ -13,17 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import ast
import os import os
import argparse
import copy import copy
import paddle
import numpy as np import numpy as np
import paddle.fluid as fluid from paddle.inference import Config, create_predictor
import paddlehub as hub from paddlehub.module.module import moduleinfo, serving
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor from .data_feed import reader
from paddlehub.module.module import moduleinfo, runnable, serving from .processor import postprocess, base64_to_cv2, cv2_to_base64
from stgan_bald.data_feed import reader
from stgan_bald.processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir
def check_attribute_conflict(label_batch): def check_attribute_conflict(label_batch):
...@@ -45,40 +42,43 @@ def check_attribute_conflict(label_batch): ...@@ -45,40 +42,43 @@ def check_attribute_conflict(label_batch):
@moduleinfo( @moduleinfo(
name="stgan_bald", name="stgan_bald",
version="1.0.0", version="1.1.0",
summary="Baldness generator", summary="Baldness generator",
author="Arrow, 七年期限,Mr.郑先生_", author="Arrow, 七年期限,Mr.郑先生_",
author_email="1084667371@qq.com,2733821739@qq.com", author_email="1084667371@qq.com,2733821739@qq.com",
type="image/gan") type="image/gan")
class StganBald(hub.Module): class StganBald:
def _initialize(self): def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "module") self.default_pretrained_model_path = os.path.join(
self.directory, "module", "model")
self._set_config() self._set_config()
def _set_config(self): def _set_config(self):
""" """
predictor config setting predictor config setting
""" """
self.model_file_path = os.path.join(self.default_pretrained_model_path, '__model__') model = self.default_pretrained_model_path+'.pdmodel'
self.params_file_path = os.path.join(self.default_pretrained_model_path, '__params__') params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = AnalysisConfig(self.model_file_path, self.params_file_path) cpu_config = Config(model, params)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
try: try:
_places = os.environ["CUDA_VISIBLE_DEVICES"] _places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0]) int(_places[0])
use_gpu = True use_gpu = True
self.place = fluid.CUDAPlace(0) self.place = paddle.CUDAPlace(0)
except: except:
use_gpu = False use_gpu = False
self.place = fluid.CPUPlace() self.place = paddle.CPUPlace()
if use_gpu: if use_gpu:
gpu_config = AnalysisConfig(self.model_file_path, self.params_file_path) gpu_config = Config(model, params)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0) gpu_config.enable_use_gpu(
self.gpu_predictor = create_paddle_predictor(gpu_config) memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_predictor(gpu_config)
def bald(self, def bald(self,
images=None, images=None,
...@@ -135,19 +135,29 @@ class StganBald(hub.Module): ...@@ -135,19 +135,29 @@ class StganBald(hub.Module):
label_trg_tmp = copy.deepcopy(target_label_np) label_trg_tmp = copy.deepcopy(target_label_np)
new_i = 0 new_i = 0
label_trg_tmp[0][new_i] = 1.0 - label_trg_tmp[0][new_i] label_trg_tmp[0][new_i] = 1.0 - label_trg_tmp[0][new_i]
label_trg_tmp = check_attribute_conflict(label_trg_tmp) label_trg_tmp = check_attribute_conflict(
label_trg_tmp)
change_num = j * 0.02 + 0.3 change_num = j * 0.02 + 0.3
label_org_tmp = list(map(lambda x: ((x * 2) - 1) * change_num, org_label_np)) label_org_tmp = list(
label_trg_tmp = list(map(lambda x: ((x * 2) - 1) * change_num, label_trg_tmp)) map(lambda x: ((x * 2) - 1) * change_num, org_label_np))
label_trg_tmp = list(
image = PaddleTensor(image_np.copy()) map(lambda x: ((x * 2) - 1) * change_num, label_trg_tmp))
org_label = PaddleTensor(np.array(label_org_tmp).astype('float32'))
target_label = PaddleTensor(np.array(label_trg_tmp).astype('float32')) predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
output = self.gpu_predictor.run([ input_handle = predictor.get_input_handle(input_names[0])
image, target_label, org_label input_handle.copy_from_cpu(image_np.copy())
]) if use_gpu else self.cpu_predictor.run([image, org_label, target_label]) input_handle = predictor.get_input_handle(input_names[1])
outputs.append(output) input_handle.copy_from_cpu(
np.array(label_org_tmp).astype('float32'))
input_handle = predictor.get_input_handle(input_names[2])
input_handle.copy_from_cpu(
np.array(label_trg_tmp).astype('float32'))
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(
output_names[0])
outputs.append(output_handle)
out = postprocess( out = postprocess(
data_out=outputs, data_out=outputs,
......
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import os import os
import time
import base64 import base64
import cv2 import cv2
from PIL import Image
import numpy as np import numpy as np
from PIL import Image
__all__ = ['cv2_to_base64', 'base64_to_cv2', 'postprocess'] __all__ = ['cv2_to_base64', 'base64_to_cv2', 'postprocess']
...@@ -22,7 +21,12 @@ def base64_to_cv2(b64str): ...@@ -22,7 +21,12 @@ def base64_to_cv2(b64str):
return data return data
def postprocess(data_out, org_im, org_im_path, output_dir, visualization, thresh=120): def postprocess(data_out,
org_im,
org_im_path,
output_dir,
visualization,
thresh=120):
""" """
Postprocess output of network. one image at a time. Postprocess output of network. one image at a time.
...@@ -41,7 +45,7 @@ def postprocess(data_out, org_im, org_im_path, output_dir, visualization, thresh ...@@ -41,7 +45,7 @@ def postprocess(data_out, org_im, org_im_path, output_dir, visualization, thresh
result = dict() result = dict()
for i, img in enumerate(data_out): for i, img in enumerate(data_out):
img = np.squeeze(img[0].as_ndarray(), 0).transpose((1, 2, 0)) img = np.squeeze(img.copy_to_cpu(), 0).transpose((1, 2, 0))
img = ((img + 1) * 127.5).astype(np.uint8) img = ((img + 1) * 127.5).astype(np.uint8)
img = cv2.resize(img, (256, 341), cv2.INTER_CUBIC) img = cv2.resize(img, (256, 341), cv2.INTER_CUBIC)
fake_image = Image.fromarray(img) fake_image = Image.fromarray(img)
...@@ -76,6 +80,7 @@ def get_save_image_name(org_im_path, output_dir, num): ...@@ -76,6 +80,7 @@ def get_save_image_name(org_im_path, output_dir, num):
# save image path # save image path
save_im_path = os.path.join(output_dir, im_prefix + ext) save_im_path = os.path.join(output_dir, im_prefix + ext)
if os.path.exists(save_im_path): if os.path.exists(save_im_path):
save_im_path = os.path.join(output_dir, im_prefix + str(num) + ext) save_im_path = os.path.join(
output_dir, im_prefix + str(num) + ext)
return save_im_path return save_im_path
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="stgan_bald")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('bald_output')
def test_bald1(self):
results = self.module.bald(
paths=['tests/test.jpg']
)
data_0 = results[0]['data_0']
data_1 = results[0]['data_1']
data_2 = results[0]['data_2']
self.assertIsInstance(data_0, np.ndarray)
self.assertIsInstance(data_1, np.ndarray)
self.assertIsInstance(data_2, np.ndarray)
def test_bald2(self):
results = self.module.bald(
images=[cv2.imread('tests/test.jpg')]
)
data_0 = results[0]['data_0']
data_1 = results[0]['data_1']
data_2 = results[0]['data_2']
self.assertIsInstance(data_0, np.ndarray)
self.assertIsInstance(data_1, np.ndarray)
self.assertIsInstance(data_2, np.ndarray)
def test_bald3(self):
results = self.module.bald(
images=[cv2.imread('tests/test.jpg')],
visualization=False
)
data_0 = results[0]['data_0']
data_1 = results[0]['data_1']
data_2 = results[0]['data_2']
self.assertIsInstance(data_0, np.ndarray)
self.assertIsInstance(data_1, np.ndarray)
self.assertIsInstance(data_2, np.ndarray)
def test_bald4(self):
self.assertRaises(
AssertionError,
self.module.bald,
paths=['no.jpg']
)
def test_bald5(self):
self.assertRaises(
cv2.error,
self.module.bald,
images=['tests/test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册