# test_stylepro_artistic.py (4.0 KB)
# Committed by wuzewu.
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import unittest

import cv2
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub

# Root of the style-transfer test images. 'tranfer' (sic) presumably matches
# the on-disk directory name -- verify before correcting the spelling.
_DATASET_ROOT = '../image_dataset/style_tranfer/'

content_dir = _DATASET_ROOT + 'content/'  # images to be restyled
style_dir = _DATASET_ROOT + 'style/'  # images whose style is applied


class TestStyleProjection(unittest.TestCase):
    """Smoke tests for the PaddleHub ``stylepro_artistic`` module.

    Images are read from the module-level ``content_dir`` / ``style_dir``
    directories and passed to ``style_transfer``; ``save_inference_model``
    is exercised separately. Outputs are printed, not compared against
    reference results.
    """

    # Value passed as ``use_gpu`` to every ``style_transfer`` call. Override
    # (e.g. in a subclass) to run the tests on a CPU-only machine.
    use_gpu = True

    @classmethod
    def setUpClass(cls):
        """Load the hub module once before any test in this class runs."""
        cls.style_projection = hub.Module(name="stylepro_artistic")

    @classmethod
    def tearDownClass(cls):
        """Drop the class-level module reference after all tests have run."""
        cls.style_projection = None

    def setUp(self):
        """Create a fresh fluid Program for the test to build into."""
        self.test_prog = fluid.Program()

    def tearDown(self):
        """Release the per-test Program."""
        self.test_prog = None

    @staticmethod
    def _list_files(directory):
        """Return the sorted full paths of the regular files in *directory*.

        ``os.listdir`` returns entries in arbitrary order; sorting makes the
        "first content image" and "adjacent style pair" choices below
        reproducible across machines. Sub-directories are skipped.
        """
        candidates = (os.path.join(directory, name)
                      for name in os.listdir(directory))
        return sorted(path for path in candidates if os.path.isfile(path))

    def _read_rgb(self, path):
        """Read the image at *path* with OpenCV and convert BGR -> RGB.

        ``cv2.imread`` returns ``None`` instead of raising when the file is
        missing or not a decodable image; fail with the offending path rather
        than letting ``cv2.cvtColor`` raise an opaque error.
        """
        bgr = cv2.imread(path)
        self.assertIsNotNone(bgr, 'Could not read image: {}'.format(path))
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def test_single_style(self):
        """Apply each style image, alone, to the first content image."""
        with fluid.program_guard(self.test_prog):
            content_paths = self._list_files(content_dir)
            style_paths = self._list_files(style_dir)
            self.assertTrue(content_paths,
                            'No content images in {}'.format(content_dir))
            self.assertTrue(style_paths,
                            'No style images in {}'.format(style_dir))
            for style_path in style_paths:
                t1 = time.time()
                self.style_projection.style_transfer(
                    paths=[{
                        'content': content_paths[0],
                        'styles': [style_path]
                    }],
                    alpha=0.8,
                    use_gpu=self.use_gpu)
                t2 = time.time()
                print('\nCost time: {}'.format(t2 - t1))

    def test_multiple_styles(self):
        """Blend each adjacent pair of style images (weights 1:2) onto chicago.jpg."""
        with fluid.program_guard(self.test_prog):
            content_path = os.path.join(content_dir, 'chicago.jpg')
            self.assertTrue(os.path.isfile(content_path),
                            'Missing content image: {}'.format(content_path))
            style_paths = self._list_files(style_dir)
            # At least two styles are needed to form a pair; otherwise the
            # loop below would run zero times and the test would pass vacuously.
            self.assertGreaterEqual(
                len(style_paths), 2,
                'Need at least two style images in {}'.format(style_dir))
            for first, second in zip(style_paths, style_paths[1:]):
                res = self.style_projection.style_transfer(
                    paths=[{
                        'content': content_path,
                        'styles': [first, second],
                        'weights': [1, 2]
                    }],
                    alpha=0.8,
                    use_gpu=self.use_gpu,
                    visualization=True)
                print('#' * 100)
                print(res)
                print('#' * 100)

    def test_input_ndarray(self):
        """Run style transfer on in-memory RGB arrays, one adjacent style pair at a time."""
        with fluid.program_guard(self.test_prog):
            content_arr = self._read_rgb(
                os.path.join(content_dir, 'chicago.jpg'))
            style_arrs_list = [
                self._read_rgb(path) for path in self._list_files(style_dir)
            ]
            self.assertGreaterEqual(
                len(style_arrs_list), 2,
                'Need at least two style images in {}'.format(style_dir))
            for first, second in zip(style_arrs_list, style_arrs_list[1:]):
                self.style_projection.style_transfer(
                    images=[{
                        'content': content_arr,
                        'styles': [first, second]
                    }],
                    alpha=0.8,
                    use_gpu=self.use_gpu,
                    output_dir='transfer_out',
                    visualization=True)

    def test_save_inference_model(self):
        """Export the module as a combined inference model.

        Output is written to ``stylepro_artistic`` relative to the current
        working directory and is not cleaned up afterwards.
        """
        with fluid.program_guard(self.test_prog):
            self.style_projection.save_inference_model(
                dirname='stylepro_artistic',
                model_filename='model',
                combined=True)


if __name__ == "__main__":
    import sys

    # Build the suite explicitly so the tests run in this order; the default
    # unittest loader would sort them alphabetically instead.
    suite = unittest.TestSuite()
    for test_name in ('test_single_style', 'test_multiple_styles',
                      'test_input_ndarray', 'test_save_inference_model'):
        suite.addTest(TestStyleProjection(test_name))
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    # Reflect failures in the exit status so shell/CI callers can detect them;
    # previously the script always exited 0.
    sys.exit(0 if result.wasSuccessful() else 1)