diff --git a/modules/image/Image_gan/style_transfer/painttransformer/README.md b/modules/image/Image_gan/style_transfer/painttransformer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1cfd283ac52d06c02611b5f77754a7222333d5bc
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/painttransformer/README.md
@@ -0,0 +1,134 @@
+# painttransformer
+
+|Module Name|painttransformer|
+| :--- | :---: |
+|Category|Image - Style Transfer|
+|Network|Paint Transformer|
+|Dataset|-|
+|Fine-tuning supported or not|No|
+|Module Size|77MB|
+|Latest update date|2021-12-07|
+|Data indicators|-|
+
+
+## I. Basic Information
+
+- ### Application Effect Display
+  - Sample results:
+
+    Input image
+
+    Output image
+
+- ### Module Introduction
+
+  - This module transfers an input image into the style of an oil painting.
+  - For more details, please refer to: [Paint Transformer: Feed Forward Neural Painting with Stroke Prediction](https://github.com/wzmsltw/PaintTransformer)
+
+
+
+## II. Installation
+
+- ### 1. Environmental Dependence
+  - ppgan
+
+- ### 2. Installation
+
+  - ```shell
+    $ hub install painttransformer
+    ```
+  - In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+    | [Linux_Quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+## III. Module API Prediction
+
+- ### 1. Command line Prediction
+
+  - ```shell
+    # Read from a file
+    $ hub run painttransformer --input_path "/PATH/TO/IMAGE"
+    ```
+  - Run the style transfer module from the command line. For more information, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
+
+- ### 2. Prediction Code Example
+
+ - ```python
+ import paddlehub as hub
+
+ module = hub.Module(name="painttransformer")
+ input_path = ["/PATH/TO/IMAGE"]
+ # Read from a file
+ module.style_transfer(paths=input_path, output_dir='./transfer_result/', use_gpu=True)
+ ```
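+  - `style_transfer` also accepts in-memory images. A minimal sketch of the equivalent call with an ndarray (the array must be BGR, as returned by `cv2.imread`):
+
+  - ```python
+    import cv2
+    import paddlehub as hub
+
+    module = hub.Module(name="painttransformer")
+    # Pass decoded image data (BGR ndarray) instead of a file path
+    image = cv2.imread("/PATH/TO/IMAGE")
+    module.style_transfer(images=[image], output_dir='./transfer_result/', use_gpu=True)
+    ```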
+
+- ### 3. API
+
+  - ```python
+    style_transfer(images=None, paths=None, output_dir='./transfer_result/', use_gpu=False, need_animation=False, visualization=True)
+    ```
+  - Oil painting style transfer API.
+
+  - **Parameters**
+
+    - images (list\[numpy.ndarray\]): image data, with each ndarray.shape being \[H, W, C\], color space BGR;
+    - paths (list\[str\]): paths to the images;
+    - output\_dir (str): directory where results are saved;
+    - use\_gpu (bool): whether to use GPU;
+    - need\_animation (bool): whether to save intermediate frames so the painting process can be assembled into an animation;
+    - visualization (bool): whether to save the results to output\_dir.
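+
+  - **Returns**
+
+    - res (list\[list\[numpy.ndarray\]\]): rendered frames (BGR) for each input image; the last frame of each inner list is the final painting.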
+
+
+## IV. Server Deployment
+
+- PaddleHub Serving can deploy an online service of oil painting style transfer.
+
+- ### Step 1: Start PaddleHub Serving
+
+  - Run the startup command:
+  - ```shell
+    $ hub serving start -m painttransformer
+    ```
+
+  - This completes the deployment of the online oil painting style transfer service API, with the default port number 8866.
+
+  - **NOTE:** If GPU is used for prediction, set the CUDA\_VISIBLE\_DEVICES environment variable (e.g. `export CUDA_VISIBLE_DEVICES=0`) before starting the service; otherwise it need not be set.
+
+- ### Step 2: Send a predictive request
+
+  - With a configured server, use the following lines of code to send the prediction request and obtain the result:
+
+  - ```python
+    import requests
+    import json
+    import cv2
+    import base64
+
+
+    def cv2_to_base64(image):
+        data = cv2.imencode('.jpg', image)[1]
+        return base64.b64encode(data.tobytes()).decode('utf8')
+
+    # Send an HTTP request
+    data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
+    headers = {"Content-type": "application/json"}
+    url = "http://127.0.0.1:8866/predict/painttransformer"
+    r = requests.post(url=url, headers=headers, data=json.dumps(data))
+
+    # Print the prediction results
+    print(r.json()["results"])
+    ```
+
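+  - Each entry of `results` encodes one rendered frame as a nested list (see `serving_method` in module.py). A minimal sketch for turning a returned frame back into a saved image, assuming BGR uint8 frames:
+
+  - ```python
+    import numpy as np
+    import cv2
+
+    # Convert the first returned frame back to a BGR uint8 image and save it
+    frame = np.array(r.json()["results"][0], dtype=np.uint8)
+    cv2.imwrite("result.png", frame)
+    ```
+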
+## V. Release Note
+
+* 1.0.0
+
+  First release
+
+ - ```shell
+ $ hub install painttransformer==1.0.0
+ ```
diff --git a/modules/image/Image_gan/style_transfer/painttransformer/brush/brush_large_horizontal.png b/modules/image/Image_gan/style_transfer/painttransformer/brush/brush_large_horizontal.png
new file mode 100644
index 0000000000000000000000000000000000000000..e2f746dd71f932c81a61c33512b32a929efaf2f4
Binary files /dev/null and b/modules/image/Image_gan/style_transfer/painttransformer/brush/brush_large_horizontal.png differ
diff --git a/modules/image/Image_gan/style_transfer/painttransformer/brush/brush_large_vertical.png b/modules/image/Image_gan/style_transfer/painttransformer/brush/brush_large_vertical.png
new file mode 100644
index 0000000000000000000000000000000000000000..5238813f2666df103d1cf5ff3aa2e7a2badc4f2a
Binary files /dev/null and b/modules/image/Image_gan/style_transfer/painttransformer/brush/brush_large_vertical.png differ
diff --git a/modules/image/Image_gan/style_transfer/painttransformer/inference.py b/modules/image/Image_gan/style_transfer/painttransformer/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bd2c1113549ceb7c74ab1445c0d39a92a475842
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/painttransformer/inference.py
@@ -0,0 +1,72 @@
+import os
+import time
+
+import cv2
+import paddle
+
+import model as network  # the Painter network is defined in model.py
+import render_utils
+import render_parallel
+import render_serial
+
+
+def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
+ if not os.path.exists(output_dir):
+ os.mkdir(output_dir)
+ input_name = os.path.basename(input_path)
+ output_path = os.path.join(output_dir, input_name)
+ frame_dir = None
+ if need_animation:
+ if not serial:
+            print('Animation output requires serial mode; setting serial to True.')
+ serial = True
+ frame_dir = os.path.join(output_dir, input_name[:input_name.find('.')])
+ if not os.path.exists(frame_dir):
+ os.mkdir(frame_dir)
+ stroke_num = 8
+
+ #* ----- load model ----- *#
+ paddle.set_device('gpu')
+ net_g = network.Painter(5, stroke_num, 256, 8, 3, 3)
+ net_g.set_state_dict(paddle.load(model_path))
+ net_g.eval()
+ for param in net_g.parameters():
+ param.stop_gradient = True
+
+ #* ----- load brush ----- *#
+ brush_large_vertical = render_utils.read_img('brush/brush_large_vertical.png', 'L')
+ brush_large_horizontal = render_utils.read_img('brush/brush_large_horizontal.png', 'L')
+ meta_brushes = paddle.concat([brush_large_vertical, brush_large_horizontal], axis=0)
+
+    t0 = time.time()
+
+ original_img = render_utils.read_img(input_path, 'RGB', resize_h, resize_w)
+ if serial:
+ final_result_list = render_serial.render_serial(original_img, net_g, meta_brushes)
+        if need_animation:
+            print("total frames:", len(final_result_list))
+ for idx, frame in enumerate(final_result_list):
+ cv2.imwrite(os.path.join(frame_dir, '%03d.png' % idx), frame)
+ else:
+ cv2.imwrite(output_path, final_result_list[-1])
+ else:
+ final_result = render_parallel.render_parallel(original_img, net_g, meta_brushes)
+ cv2.imwrite(output_path, final_result)
+
+ print("total infer time:", time.time() - t0)
+
+
+if __name__ == '__main__':
+
+ main(
+ input_path='input/chicago.jpg',
+ model_path='paint_best.pdparams',
+ output_dir='output/',
+ need_animation=True, # whether need intermediate results for animation.
+ resize_h=512, # resize original input to this size. None means do not resize.
+ resize_w=512, # resize original input to this size. None means do not resize.
+ serial=True) # if need animation, serial must be True.
diff --git a/modules/image/Image_gan/style_transfer/painttransformer/model.py b/modules/image/Image_gan/style_transfer/painttransformer/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9f40a3ec0210a961fd90191e228f83712fd5781
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/painttransformer/model.py
@@ -0,0 +1,68 @@
+import paddle
+import paddle.nn as nn
+
+
+class Painter(nn.Layer):
+ """
+ network architecture written in paddle.
+ """
+
+ def __init__(self, param_per_stroke, total_strokes, hidden_dim, n_heads=8, n_enc_layers=3, n_dec_layers=3):
+ super().__init__()
+ self.enc_img = nn.Sequential(
+ nn.Pad2D([1, 1, 1, 1], 'reflect'),
+ nn.Conv2D(3, 32, 3, 1),
+ nn.BatchNorm2D(32),
+ nn.ReLU(), # maybe replace with the inplace version
+ nn.Pad2D([1, 1, 1, 1], 'reflect'),
+ nn.Conv2D(32, 64, 3, 2),
+ nn.BatchNorm2D(64),
+ nn.ReLU(),
+ nn.Pad2D([1, 1, 1, 1], 'reflect'),
+ nn.Conv2D(64, 128, 3, 2),
+ nn.BatchNorm2D(128),
+ nn.ReLU())
+ self.enc_canvas = nn.Sequential(
+ nn.Pad2D([1, 1, 1, 1], 'reflect'), nn.Conv2D(3, 32, 3, 1), nn.BatchNorm2D(32), nn.ReLU(),
+ nn.Pad2D([1, 1, 1, 1], 'reflect'), nn.Conv2D(32, 64, 3, 2), nn.BatchNorm2D(64), nn.ReLU(),
+ nn.Pad2D([1, 1, 1, 1], 'reflect'), nn.Conv2D(64, 128, 3, 2), nn.BatchNorm2D(128), nn.ReLU())
+ self.conv = nn.Conv2D(128 * 2, hidden_dim, 1)
+ self.transformer = nn.Transformer(hidden_dim, n_heads, n_enc_layers, n_dec_layers)
+ self.linear_param = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
+ nn.Linear(hidden_dim, param_per_stroke))
+ self.linear_decider = nn.Linear(hidden_dim, 1)
+ self.query_pos = paddle.static.create_parameter([total_strokes, hidden_dim],
+ dtype='float32',
+ default_initializer=nn.initializer.Uniform(0, 1))
+ self.row_embed = paddle.static.create_parameter([8, hidden_dim // 2],
+ dtype='float32',
+ default_initializer=nn.initializer.Uniform(0, 1))
+ self.col_embed = paddle.static.create_parameter([8, hidden_dim // 2],
+ dtype='float32',
+ default_initializer=nn.initializer.Uniform(0, 1))
+
+ def forward(self, img, canvas):
+ """
+ prediction
+ """
+ b, _, H, W = img.shape
+ img_feat = self.enc_img(img)
+ canvas_feat = self.enc_canvas(canvas)
+ h, w = img_feat.shape[-2:]
+ feat = paddle.concat([img_feat, canvas_feat], axis=1)
+ feat_conv = self.conv(feat)
+
+ pos_embed = paddle.concat([
+ self.col_embed[:w].unsqueeze(0).tile([h, 1, 1]),
+ self.row_embed[:h].unsqueeze(1).tile([1, w, 1]),
+ ],
+ axis=-1).flatten(0, 1).unsqueeze(1)
+
+ hidden_state = self.transformer((pos_embed + feat_conv.flatten(2).transpose([2, 0, 1])).transpose([1, 0, 2]),
+ self.query_pos.unsqueeze(1).tile([1, b, 1]).transpose([1, 0, 2]))
+
+ param = self.linear_param(hidden_state)
+ decision = self.linear_decider(hidden_state)
+ return param, decision
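+
+
+if __name__ == '__main__':
+    # Minimal smoke test (a sketch): feed a random image and a blank canvas
+    # through the network and check the predicted shapes.
+    net = Painter(5, 8, 256, 8, 3, 3)
+    img = paddle.rand([2, 3, 32, 32])
+    canvas = paddle.zeros([2, 3, 32, 32])
+    param, decision = net(img, canvas)
+    print(param.shape, decision.shape)  # [2, 8, 5], [2, 8, 1]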
diff --git a/modules/image/Image_gan/style_transfer/painttransformer/module.py b/modules/image/Image_gan/style_transfer/painttransformer/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..cea886fa175b6d216329f75f4526a37e3f4f7c94
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/painttransformer/module.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import ast
+import argparse
+
+import cv2
+import paddle
+from paddlehub.module.module import moduleinfo, runnable, serving
+
+from .model import Painter
+from .render_utils import totensor, read_img
+from .render_serial import render_serial
+from .util import base64_to_cv2
+
+
+@moduleinfo(
+ name="painttransformer",
+ type="CV/style_transfer",
+ author="paddlepaddle",
+ author_email="",
+ summary="",
+ version="1.0.0")
+class painttransformer:
+ def __init__(self):
+ self.pretrained_model = os.path.join(self.directory, "paint_best.pdparams")
+
+ self.network = Painter(5, 8, 256, 8, 3, 3)
+ self.network.set_state_dict(paddle.load(self.pretrained_model))
+ self.network.eval()
+ for param in self.network.parameters():
+ param.stop_gradient = True
+ #* ----- load brush ----- *#
+ brush_large_vertical = read_img(os.path.join(self.directory, 'brush/brush_large_vertical.png'), 'L')
+ brush_large_horizontal = read_img(os.path.join(self.directory, 'brush/brush_large_horizontal.png'), 'L')
+ self.meta_brushes = paddle.concat([brush_large_vertical, brush_large_horizontal], axis=0)
+
+ def style_transfer(self,
+ images=None,
+ paths=None,
+ output_dir='./transfer_result/',
+ use_gpu=False,
+ need_animation=False,
+ visualization=True):
+        '''
+        Oil painting style transfer API.
+
+        Args:
+            images (list[numpy.ndarray]): data of images, shape of each is [H, W, C], color space must be BGR (as read by cv2).
+            paths (list[str]): paths to images.
+            output_dir (str): the dir to save the results.
+            use_gpu (bool): if True, use gpu to perform the computation, otherwise cpu.
+            need_animation (bool): if True, keep every intermediate frame so the painting process can be saved as an animation.
+            visualization (bool): if True, save results in output_dir.
+
+        Returns:
+            results (list[list[numpy.ndarray]]): rendered frames (BGR) for each input; the last frame of each list is the final painting.
+        '''
+ results = []
+ paddle.disable_static()
+ place = 'gpu:0' if use_gpu else 'cpu'
+ place = paddle.set_device(place)
+        if images is None and paths is None:
+            print('No image provided. Please input an image or an image path.')
+            return results
+
+        if images is not None:
+            for image in images:
+                image = image[:, :, ::-1]  # BGR -> RGB
+                image = totensor(image)
+                final_result_list = render_serial(image, self.network, self.meta_brushes)
+                results.append(final_result_list)
+
+        if paths is not None:
+            for path in paths:
+                image = cv2.imread(path)[:, :, ::-1]  # BGR -> RGB
+                image = totensor(image)
+                final_result_list = render_serial(image, self.network, self.meta_brushes)
+                results.append(final_result_list)
+
+        if visualization:
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+ for i, out in enumerate(results):
+ if out:
+ if need_animation:
+ curoutputdir = os.path.join(output_dir, 'output_{}'.format(i))
+ if not os.path.exists(curoutputdir):
+ os.makedirs(curoutputdir, exist_ok=True)
+ for j, outimg in enumerate(out):
+ cv2.imwrite(os.path.join(curoutputdir, 'frame_{}.png'.format(j)), outimg)
+ else:
+ cv2.imwrite(os.path.join(output_dir, 'output_{}.png'.format(i)), out[-1])
+
+ return results
+
+ @runnable
+ def run_cmd(self, argvs: list):
+ """
+ Run as a command.
+ """
+ self.parser = argparse.ArgumentParser(
+ description="Run the {} module.".format(self.name),
+ prog='hub run {}'.format(self.name),
+ usage='%(prog)s',
+ add_help=True)
+
+ self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
+ self.arg_config_group = self.parser.add_argument_group(
+ title="Config options", description="Run configuration for controlling module behavior, not required.")
+ self.add_module_config_arg()
+ self.add_module_input_arg()
+ self.args = self.parser.parse_args(argvs)
+ results = self.style_transfer(
+ paths=[self.args.input_path],
+ output_dir=self.args.output_dir,
+ use_gpu=self.args.use_gpu,
+ need_animation=self.args.need_animation,
+ visualization=self.args.visualization)
+ return results
+
+ @serving
+ def serving_method(self, images, **kwargs):
+ """
+ Run as a service.
+ """
+ images_decode = [base64_to_cv2(image) for image in images]
+ results = self.style_transfer(images=images_decode, **kwargs)
+        # each result is a list of frames; return the final painting for each input
+        tolist = [result[-1].tolist() for result in results]
+ return tolist
+
+ def add_module_config_arg(self):
+ """
+ Add the command config options.
+ """
+ self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
+
+ self.arg_config_group.add_argument(
+ '--output_dir', type=str, default='transfer_result', help='output directory for saving result.')
+        # argparse's type=bool treats any non-empty string as True, so parse booleans explicitly
+        self.arg_config_group.add_argument(
+            '--visualization', type=ast.literal_eval, default=True, help='whether to save results. Default: True.')
+        self.arg_config_group.add_argument(
+            '--need_animation', type=ast.literal_eval, default=False, help='whether to save intermediate frames. Default: False.')
+
+ def add_module_input_arg(self):
+ """
+ Add the command input options.
+ """
+ self.arg_input_group.add_argument('--input_path', type=str, help="path to input image.")
diff --git a/modules/image/Image_gan/style_transfer/painttransformer/render_parallel.py b/modules/image/Image_gan/style_transfer/painttransformer/render_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..a58ebec4bdae82881c8339dd6cae81ddc11407c2
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/painttransformer/render_parallel.py
@@ -0,0 +1,247 @@
+import render_utils
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+import math
+
+
+def crop(img, h, w):
+ H, W = img.shape[-2:]
+ pad_h = (H - h) // 2
+ pad_w = (W - w) // 2
+ remainder_h = (H - h) % 2
+ remainder_w = (W - w) % 2
+ img = img[:, :, pad_h:H - pad_h - remainder_h, pad_w:W - pad_w - remainder_w]
+ return img
+
+
+def stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num):
+ """
+ stroke_net_predict
+ """
+ img_patch = img_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size])
+ result_patch = result_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size])
+ #*----- Stroke Predictor -----*#
+ shape_param, stroke_decision = net_g(img_patch, result_patch)
+ stroke_decision = (stroke_decision > 0).astype('float32')
+ #*----- sampling color -----*#
+ grid = shape_param[:, :, :2].reshape([img_patch.shape[0] * stroke_num, 1, 1, 2])
+ img_temp = img_patch.unsqueeze(1).tile([1, stroke_num, 1, 1,
+ 1]).reshape([img_patch.shape[0] * stroke_num, 3, patch_size, patch_size])
+ color = nn.functional.grid_sample(
+ img_temp, 2 * grid - 1, align_corners=False).reshape([img_patch.shape[0], stroke_num, 3])
+ param = paddle.concat([shape_param, color], axis=-1)
+
+ param = param.reshape([-1, 8])
+ param[:, :2] = param[:, :2] / 2 + 0.25
+ param[:, 2:4] = param[:, 2:4] / 2
+ param = param.reshape([1, patch_num, patch_num, stroke_num, 8])
+    decision = stroke_decision.reshape([1, patch_num, patch_num, stroke_num])
+ return param, decision
+
+
+def param2img_parallel(param, decision, meta_brushes, cur_canvas, stroke_num=8):
+ """
+ Input stroke parameters and decisions for each patch, meta brushes, current canvas, frame directory,
+ and whether there is a border (if intermediate painting results are required).
+ Output the painting results of adding the corresponding strokes on the current canvas.
+ Args:
+ param: a tensor with shape batch size x patch along height dimension x patch along width dimension
+ x n_stroke_per_patch x n_param_per_stroke
+ decision: a 01 tensor with shape batch size x patch along height dimension x patch along width dimension
+ x n_stroke_per_patch
+ meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width.
+ The first slice on the batch dimension denotes vertical brush and the second one denotes horizontal brush.
+ cur_canvas: a tensor with shape batch size x 3 x H x W,
+ where H and W denote height and width of padded results of original images.
+
+ Returns:
+ cur_canvas: a tensor with shape batch size x 3 x H x W, denoting painting results.
+ """
+ # param: b, h, w, stroke_per_patch, param_per_stroke
+ # decision: b, h, w, stroke_per_patch
+ b, h, w, s, p = param.shape
+ h, w = int(h), int(w)
+ param = param.reshape([-1, 8])
+ decision = decision.reshape([-1, 8])
+
+ H, W = cur_canvas.shape[-2:]
+ is_odd_y = h % 2 == 1
+ is_odd_x = w % 2 == 1
+ render_size_y = 2 * H // h
+ render_size_x = 2 * W // w
+
+ even_idx_y = paddle.arange(0, h, 2)
+ even_idx_x = paddle.arange(0, w, 2)
+ if h > 1:
+ odd_idx_y = paddle.arange(1, h, 2)
+ if w > 1:
+ odd_idx_x = paddle.arange(1, w, 2)
+
+ cur_canvas = F.pad(cur_canvas, [render_size_x // 4, render_size_x // 4, render_size_y // 4, render_size_y // 4])
+
+ valid_foregrounds = render_utils.param2stroke(param, render_size_y, render_size_x, meta_brushes)
+
+ #* ----- load dilation/erosion ---- *#
+ dilation = render_utils.Dilation2d(m=1)
+ erosion = render_utils.Erosion2d(m=1)
+
+ #* ----- generate alphas ----- *#
+ valid_alphas = (valid_foregrounds > 0).astype('float32')
+ valid_foregrounds = valid_foregrounds.reshape([-1, stroke_num, 1, render_size_y, render_size_x])
+ valid_alphas = valid_alphas.reshape([-1, stroke_num, 1, render_size_y, render_size_x])
+
+ temp = [dilation(valid_foregrounds[:, i, :, :, :]) for i in range(stroke_num)]
+ valid_foregrounds = paddle.stack(temp, axis=1)
+ valid_foregrounds = valid_foregrounds.reshape([-1, 1, render_size_y, render_size_x])
+
+ temp = [erosion(valid_alphas[:, i, :, :, :]) for i in range(stroke_num)]
+ valid_alphas = paddle.stack(temp, axis=1)
+ valid_alphas = valid_alphas.reshape([-1, 1, render_size_y, render_size_x])
+
+ foregrounds = valid_foregrounds.reshape([-1, h, w, stroke_num, 1, render_size_y, render_size_x])
+ alphas = valid_alphas.reshape([-1, h, w, stroke_num, 1, render_size_y, render_size_x])
+ decision = decision.reshape([-1, h, w, stroke_num, 1, 1, 1])
+ param = param.reshape([-1, h, w, stroke_num, 8])
+
+ def partial_render(this_canvas, patch_coord_y, patch_coord_x):
+ canvas_patch = F.unfold(
+ this_canvas, [render_size_y, render_size_x], strides=[render_size_y // 2, render_size_x // 2])
+ # canvas_patch: b, 3 * py * px, h * w
+ canvas_patch = canvas_patch.reshape([b, 3, render_size_y, render_size_x, h, w])
+ canvas_patch = canvas_patch.transpose([0, 4, 5, 1, 2, 3])
+ selected_canvas_patch = paddle.gather(canvas_patch, patch_coord_y, 1)
+ selected_canvas_patch = paddle.gather(selected_canvas_patch, patch_coord_x, 2)
+ selected_canvas_patch = selected_canvas_patch.reshape([0, 0, 0, 1, 3, render_size_y, render_size_x])
+ selected_foregrounds = paddle.gather(foregrounds, patch_coord_y, 1)
+ selected_foregrounds = paddle.gather(selected_foregrounds, patch_coord_x, 2)
+ selected_alphas = paddle.gather(alphas, patch_coord_y, 1)
+ selected_alphas = paddle.gather(selected_alphas, patch_coord_x, 2)
+ selected_decisions = paddle.gather(decision, patch_coord_y, 1)
+ selected_decisions = paddle.gather(selected_decisions, patch_coord_x, 2)
+ selected_color = paddle.gather(param, patch_coord_y, 1)
+ selected_color = paddle.gather(selected_color, patch_coord_x, 2)
+ selected_color = paddle.gather(selected_color, paddle.to_tensor([5, 6, 7]), 4)
+ selected_color = selected_color.reshape([0, 0, 0, stroke_num, 3, 1, 1])
+
+ for i in range(stroke_num):
+ i = paddle.to_tensor(i)
+
+ cur_foreground = paddle.gather(selected_foregrounds, i, 3)
+ cur_alpha = paddle.gather(selected_alphas, i, 3)
+ cur_decision = paddle.gather(selected_decisions, i, 3)
+ cur_color = paddle.gather(selected_color, i, 3)
+ cur_foreground = cur_foreground * cur_color
+ selected_canvas_patch = cur_foreground * cur_alpha * cur_decision + selected_canvas_patch * (
+ 1 - cur_alpha * cur_decision)
+
+ selected_canvas_patch = selected_canvas_patch.reshape([0, 0, 0, 3, render_size_y, render_size_x])
+ this_canvas = selected_canvas_patch.transpose([0, 3, 1, 4, 2, 5])
+
+ # this_canvas: b, 3, h_half, py, w_half, px
+ h_half = this_canvas.shape[2]
+ w_half = this_canvas.shape[4]
+ this_canvas = this_canvas.reshape([b, 3, h_half * render_size_y, w_half * render_size_x])
+ # this_canvas: b, 3, h_half * py, w_half * px
+ return this_canvas
+
+ # even - even area
+ # 1 | 0
+ # 0 | 0
+ canvas = partial_render(cur_canvas, even_idx_y, even_idx_x)
+ if not is_odd_y:
+ canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2)
+ if not is_odd_x:
+ canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
+ cur_canvas = canvas
+
+ # odd - odd area
+ # 0 | 0
+ # 0 | 1
+ if h > 1 and w > 1:
+ canvas = partial_render(cur_canvas, odd_idx_y, odd_idx_x)
+ canvas = paddle.concat([cur_canvas[:, :, :render_size_y // 2, -canvas.shape[3]:], canvas], axis=2)
+ canvas = paddle.concat([cur_canvas[:, :, -canvas.shape[2]:, :render_size_x // 2], canvas], axis=3)
+ if is_odd_y:
+ canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2)
+ if is_odd_x:
+ canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
+ cur_canvas = canvas
+
+ # odd - even area
+ # 0 | 0
+ # 1 | 0
+ if h > 1:
+ canvas = partial_render(cur_canvas, odd_idx_y, even_idx_x)
+ canvas = paddle.concat([cur_canvas[:, :, :render_size_y // 2, :canvas.shape[3]], canvas], axis=2)
+ if is_odd_y:
+ canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2)
+ if not is_odd_x:
+ canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
+ cur_canvas = canvas
+
+    # even - odd area
+ # 0 | 1
+ # 0 | 0
+ if w > 1:
+ canvas = partial_render(cur_canvas, even_idx_y, odd_idx_x)
+ canvas = paddle.concat([cur_canvas[:, :, :canvas.shape[2], :render_size_x // 2], canvas], axis=3)
+ if not is_odd_y:
+ canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, -canvas.shape[3]:]], axis=2)
+ if is_odd_x:
+ canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
+ cur_canvas = canvas
+
+ cur_canvas = cur_canvas[:, :, render_size_y // 4:-render_size_y // 4, render_size_x // 4:-render_size_x // 4]
+
+ return cur_canvas
+
+
+def render_parallel(original_img, net_g, meta_brushes):
+
+ patch_size = 32
+ stroke_num = 8
+
+ with paddle.no_grad():
+
+ original_h, original_w = original_img.shape[-2:]
+ K = max(math.ceil(math.log2(max(original_h, original_w) / patch_size)), 0)
+ original_img_pad_size = patch_size * (2**K)
+ original_img_pad = render_utils.pad(original_img, original_img_pad_size, original_img_pad_size)
+ final_result = paddle.zeros_like(original_img)
+
+ for layer in range(0, K + 1):
+ layer_size = patch_size * (2**layer)
+
+ img = F.interpolate(original_img_pad, (layer_size, layer_size))
+ result = F.interpolate(final_result, (layer_size, layer_size))
+ img_patch = F.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size])
+ result_patch = F.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size])
+
+ # There are patch_num * patch_num patches in total
+ patch_num = (layer_size - patch_size) // patch_size + 1
+ param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num)
+
+ final_result = param2img_parallel(param, decision, meta_brushes, final_result)
+
+ # paint another time for last layer
+ border_size = original_img_pad_size // (2 * patch_num)
+ img = F.interpolate(original_img_pad, (layer_size, layer_size))
+ result = F.interpolate(final_result, (layer_size, layer_size))
+ img = F.pad(img, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2])
+ result = F.pad(result, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2])
+ img_patch = F.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size])
+ result_patch = F.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size])
+ final_result = F.pad(final_result, [border_size, border_size, border_size, border_size])
+        patch_num = (img.shape[2] - patch_size) // patch_size + 1
+
+ param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num)
+
+ final_result = param2img_parallel(param, decision, meta_brushes, final_result)
+
+ final_result = final_result[:, :, border_size:-border_size, border_size:-border_size]
+ final_result = (final_result.numpy().squeeze().transpose([1, 2, 0])[:, :, ::-1] * 255).astype(np.uint8)
+ return final_result
diff --git a/modules/image/Image_gan/style_transfer/painttransformer/render_serial.py b/modules/image/Image_gan/style_transfer/painttransformer/render_serial.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3170a29a174bc03593f44ec5d248299724c253f
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/painttransformer/render_serial.py
@@ -0,0 +1,280 @@
+#!/usr/bin/env python3
+"""
+Serial renderer for oil painting style transfer.
+"""
+import math
+import time
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+
+from .render_utils import param2stroke, Dilation2d, Erosion2d
+
+
+def get_single_layer_lists(param, decision, ori_img, render_size_x, render_size_y, h, w, meta_brushes, dilation,
+ erosion, stroke_num):
+ """
+ get_single_layer_lists
+ """
+ valid_foregrounds = param2stroke(param[:, :], render_size_y, render_size_x, meta_brushes)
+
+ valid_alphas = (valid_foregrounds > 0).astype('float32')
+ valid_foregrounds = valid_foregrounds.reshape([-1, stroke_num, 1, render_size_y, render_size_x])
+ valid_alphas = valid_alphas.reshape([-1, stroke_num, 1, render_size_y, render_size_x])
+
+ temp = [dilation(valid_foregrounds[:, i, :, :, :]) for i in range(stroke_num)]
+ valid_foregrounds = paddle.stack(temp, axis=1)
+ valid_foregrounds = valid_foregrounds.reshape([-1, 1, render_size_y, render_size_x])
+
+ temp = [erosion(valid_alphas[:, i, :, :, :]) for i in range(stroke_num)]
+ valid_alphas = paddle.stack(temp, axis=1)
+ valid_alphas = valid_alphas.reshape([-1, 1, render_size_y, render_size_x])
+
+ patch_y = 4 * render_size_y // 5
+ patch_x = 4 * render_size_x // 5
+
+ img_patch = ori_img.reshape([1, 3, h, ori_img.shape[2] // h, w, ori_img.shape[3] // w])
+ img_patch = img_patch.transpose([0, 2, 4, 1, 3, 5])[0]
+
+ xid_list = []
+ yid_list = []
+ error_list = []
+
+ for flag_idx, flag in enumerate(decision.cpu().numpy()):
+ if flag:
+ flag_idx = flag_idx // stroke_num
+ x_id = flag_idx % w
+ flag_idx = flag_idx // w
+ y_id = flag_idx % h
+ xid_list.append(x_id)
+ yid_list.append(y_id)
+
+ inner_fores = valid_foregrounds[:, :, render_size_y // 10:9 * render_size_y // 10, render_size_x // 10:9 *
+ render_size_x // 10]
+ inner_alpha = valid_alphas[:, :, render_size_y // 10:9 * render_size_y // 10, render_size_x // 10:9 *
+ render_size_x // 10]
+ inner_fores = inner_fores.reshape([h * w, stroke_num, 1, patch_y, patch_x])
+ inner_alpha = inner_alpha.reshape([h * w, stroke_num, 1, patch_y, patch_x])
+ inner_real = img_patch.reshape([h * w, 3, patch_y, patch_x]).unsqueeze(1)
+
+ R = param[:, 5]
+ G = param[:, 6]
+    B = param[:, 7]
+ R = R.reshape([-1, stroke_num]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ G = G.reshape([-1, stroke_num]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ B = B.reshape([-1, stroke_num]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ error_R = R * inner_fores - inner_real[:, :, 0:1, :, :]
+ error_G = G * inner_fores - inner_real[:, :, 1:2, :, :]
+ error_B = B * inner_fores - inner_real[:, :, 2:3, :, :]
+ error = paddle.abs(error_R) + paddle.abs(error_G) + paddle.abs(error_B)
+
+ error = error * inner_alpha
+ error = paddle.sum(error, axis=(2, 3, 4)) / paddle.sum(inner_alpha, axis=(2, 3, 4))
+ error_list = error.reshape([-1]).numpy()[decision.numpy()]
+ error_list = list(error_list)
+
+ valid_foregrounds = paddle.to_tensor(valid_foregrounds.numpy()[decision.numpy()])
+ valid_alphas = paddle.to_tensor(valid_alphas.numpy()[decision.numpy()])
+
+ selected_param = paddle.to_tensor(param.numpy()[decision.numpy()])
+ return xid_list, yid_list, valid_foregrounds, valid_alphas, error_list, selected_param
+
+
+def get_single_stroke_on_full_image_A(x_id, y_id, valid_foregrounds, valid_alphas, param, original_img, render_size_x,
+ render_size_y, patch_x, patch_y):
+ """
+ get_single_stroke_on_full_image_A
+ """
+ tmp_foreground = paddle.zeros_like(original_img)
+
+ patch_y_num = original_img.shape[2] // patch_y
+ patch_x_num = original_img.shape[3] // patch_x
+
+ brush = valid_foregrounds.unsqueeze(0)
+ color_map = param[5:]
+ brush = brush.tile([1, 3, 1, 1])
+ color_map = color_map.unsqueeze(-1).unsqueeze(-1).unsqueeze(0) #.repeat(1, 1, H, W)
+ brush = brush * color_map
+
+ pad_l = x_id * patch_x
+ pad_r = (patch_x_num - x_id - 1) * patch_x
+ pad_t = y_id * patch_y
+ pad_b = (patch_y_num - y_id - 1) * patch_y
+ tmp_foreground = nn.functional.pad(brush, [pad_l, pad_r, pad_t, pad_b])
+ tmp_foreground = tmp_foreground[:, :, render_size_y // 10:-render_size_y // 10, render_size_x //
+ 10:-render_size_x // 10]
+
+ tmp_alpha = nn.functional.pad(valid_alphas.unsqueeze(0), [pad_l, pad_r, pad_t, pad_b])
+ tmp_alpha = tmp_alpha[:, :, render_size_y // 10:-render_size_y // 10, render_size_x // 10:-render_size_x // 10]
+ return tmp_foreground, tmp_alpha
+
+
+def get_single_stroke_on_full_image_B(x_id, y_id, valid_foregrounds, valid_alphas, param, original_img, render_size_x,
+ render_size_y, patch_x, patch_y):
+ """
+ get_single_stroke_on_full_image_B
+ """
+ x_expand = patch_x // 2 + render_size_x // 10
+ y_expand = patch_y // 2 + render_size_y // 10
+
+ pad_l = x_id * patch_x
+ pad_r = original_img.shape[3] + 2 * x_expand - (x_id * patch_x + render_size_x)
+ pad_t = y_id * patch_y
+ pad_b = original_img.shape[2] + 2 * y_expand - (y_id * patch_y + render_size_y)
+
+ brush = valid_foregrounds.unsqueeze(0)
+ color_map = param[5:]
+ brush = brush.tile([1, 3, 1, 1])
+ color_map = color_map.unsqueeze(-1).unsqueeze(-1).unsqueeze(0) #.repeat(1, 1, H, W)
+ brush = brush * color_map
+
+ tmp_foreground = nn.functional.pad(brush, [pad_l, pad_r, pad_t, pad_b])
+
+ tmp_foreground = tmp_foreground[:, :, y_expand:-y_expand, x_expand:-x_expand]
+ tmp_alpha = nn.functional.pad(valid_alphas.unsqueeze(0), [pad_l, pad_r, pad_t, pad_b])
+ tmp_alpha = tmp_alpha[:, :, y_expand:-y_expand, x_expand:-x_expand]
+ return tmp_foreground, tmp_alpha
+
+
+def stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num):
+ """
+ stroke_net_predict
+ """
+ img_patch = img_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size])
+ result_patch = result_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size])
+ #*----- Stroke Predictor -----*#
+ shape_param, stroke_decision = net_g(img_patch, result_patch)
+ stroke_decision = (stroke_decision > 0).astype('float32')
+ #*----- sampling color -----*#
+ grid = shape_param[:, :, :2].reshape([img_patch.shape[0] * stroke_num, 1, 1, 2])
+ img_temp = img_patch.unsqueeze(1).tile([1, stroke_num, 1, 1,
+ 1]).reshape([img_patch.shape[0] * stroke_num, 3, patch_size, patch_size])
+ color = nn.functional.grid_sample(
+ img_temp, 2 * grid - 1, align_corners=False).reshape([img_patch.shape[0], stroke_num, 3])
+ stroke_param = paddle.concat([shape_param, color], axis=-1)
+
+ param = stroke_param.reshape([-1, 8])
+ decision = stroke_decision.reshape([-1]).astype('bool')
+ param[:, :2] = param[:, :2] / 1.25 + 0.1
+ param[:, 2:4] = param[:, 2:4] / 1.25
+ return param, decision
+
+
+def sort_strokes(params, decision, scores):
+ """
+ sort_strokes
+ """
+ sorted_scores, sorted_index = paddle.sort(scores, axis=1, descending=False)
+ sorted_params = []
+ for idx in range(8):
+ tmp_pick_params = paddle.gather(params[:, :, idx], axis=1, index=sorted_index)
+ sorted_params.append(tmp_pick_params)
+ sorted_params = paddle.stack(sorted_params, axis=2)
+ sorted_decison = paddle.gather(decision.squeeze(2), axis=1, index=sorted_index)
+ return sorted_params, sorted_decison
+
+
+def render_serial(original_img, net_g, meta_brushes):
+
+ patch_size = 32
+ stroke_num = 8
+ H, W = original_img.shape[-2:]
+ K = max(math.ceil(math.log2(max(H, W) / patch_size)), 0)
+
+ dilation = Dilation2d(m=1)
+ erosion = Erosion2d(m=1)
+ frames_per_layer = [20, 20, 30, 40, 60]
+ final_frame_list = []
+
+ with paddle.no_grad():
+ #* ----- read in image and init canvas ----- *#
+ final_result = paddle.zeros_like(original_img)
+
+ for layer in range(0, K + 1):
+ t0 = time.time()
+ layer_size = patch_size * (2**layer)
+
+ img = nn.functional.interpolate(original_img, (layer_size, layer_size))
+ result = nn.functional.interpolate(final_result, (layer_size, layer_size))
+ img_patch = nn.functional.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size])
+ result_patch = nn.functional.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size])
+ h = (img.shape[2] - patch_size) // patch_size + 1
+ w = (img.shape[3] - patch_size) // patch_size + 1
+ render_size_y = int(1.25 * H // h)
+ render_size_x = int(1.25 * W // w)
+
+ #* -------------------------------------------------------------*#
+ #* -------------generate strokes on window type A---------------*#
+ #* -------------------------------------------------------------*#
+ param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num)
+ wA_xid_list, wA_yid_list, wA_fore_list, wA_alpha_list, wA_error_list, wA_params = \
+ get_single_layer_lists(param, decision, original_img, render_size_x, render_size_y, h, w,
+ meta_brushes, dilation, erosion, stroke_num)
+
+ #* -------------------------------------------------------------*#
+ #* -------------generate strokes on window type B---------------*#
+ #* -------------------------------------------------------------*#
+ #*----- generate input canvas and target patches -----*#
+ wB_error_list = []
+
+ img = nn.functional.pad(img, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2])
+ result = nn.functional.pad(result, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2])
+ img_patch = nn.functional.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size])
+ result_patch = nn.functional.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size])
+ h += 1
+ w += 1
+
+ param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num)
+
+ patch_y = 4 * render_size_y // 5
+ patch_x = 4 * render_size_x // 5
+ expand_img = nn.functional.pad(original_img, [patch_x // 2, patch_x // 2, patch_y // 2, patch_y // 2])
+ wB_xid_list, wB_yid_list, wB_fore_list, wB_alpha_list, wB_error_list, wB_params = \
+ get_single_layer_lists(param, decision, expand_img, render_size_x, render_size_y, h, w,
+ meta_brushes, dilation, erosion, stroke_num)
+ #* -------------------------------------------------------------*#
+ #* -------------rank strokes and plot stroke one by one---------*#
+ #* -------------------------------------------------------------*#
+ numA = len(wA_error_list)
+ numB = len(wB_error_list)
+ total_error_list = wA_error_list + wB_error_list
+ sort_list = list(np.argsort(total_error_list))
+
+ sample = 0
+ samples = np.linspace(0, len(sort_list) - 2, frames_per_layer[layer]).astype(int)
+ for ii in sort_list:
+ ii = int(ii)
+ if ii < numA:
+ x_id = wA_xid_list[ii]
+ y_id = wA_yid_list[ii]
+ valid_foregrounds = wA_fore_list[ii]
+ valid_alphas = wA_alpha_list[ii]
+ sparam = wA_params[ii]
+ tmp_foreground, tmp_alpha = get_single_stroke_on_full_image_A(
+ x_id, y_id, valid_foregrounds, valid_alphas, sparam, original_img, render_size_x, render_size_y,
+ patch_x, patch_y)
+ else:
+ x_id = wB_xid_list[ii - numA]
+ y_id = wB_yid_list[ii - numA]
+ valid_foregrounds = wB_fore_list[ii - numA]
+ valid_alphas = wB_alpha_list[ii - numA]
+ sparam = wB_params[ii - numA]
+ tmp_foreground, tmp_alpha = get_single_stroke_on_full_image_B(
+ x_id, y_id, valid_foregrounds, valid_alphas, sparam, original_img, render_size_x, render_size_y,
+ patch_x, patch_y)
+
+ final_result = tmp_foreground * tmp_alpha + (1 - tmp_alpha) * final_result
+ if sample in samples:
+ saveframe = (final_result.numpy().squeeze().transpose([1, 2, 0])[:, :, ::-1] * 255).astype(np.uint8)
+ final_frame_list.append(saveframe)
+
+ sample += 1
+ print("layer %d cost: %.02f" % (layer, time.time() - t0))
+
+ saveframe = (final_result.numpy().squeeze().transpose([1, 2, 0])[:, :, ::-1] * 255).astype(np.uint8)
+ final_frame_list.append(saveframe)
+ return final_frame_list
diff --git a/modules/image/Image_gan/style_transfer/painttransformer/render_utils.py b/modules/image/Image_gan/style_transfer/painttransformer/render_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..735ac983a343961939fe333b06ac2b1fec01654f
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/painttransformer/render_utils.py
@@ -0,0 +1,111 @@
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import cv2
+import numpy as np
+from PIL import Image
+import math
+
+
+class Erosion2d(nn.Layer):
+ """
+ Erosion2d
+ """
+
+ def __init__(self, m=1):
+ super(Erosion2d, self).__init__()
+ self.m = m
+ self.pad = [m, m, m, m]
+
+ def forward(self, x):
+ batch_size, c, h, w = x.shape
+ x_pad = F.pad(x, pad=self.pad, mode='constant', value=1e9)
+ channel = nn.functional.unfold(x_pad, 2 * self.m + 1, strides=1, paddings=0).reshape([batch_size, c, -1, h, w])
+ result = paddle.min(channel, axis=2)
+ return result
+
+
+class Dilation2d(nn.Layer):
+ """
+ Dilation2d
+ """
+
+ def __init__(self, m=1):
+ super(Dilation2d, self).__init__()
+ self.m = m
+ self.pad = [m, m, m, m]
+
+ def forward(self, x):
+ batch_size, c, h, w = x.shape
+ x_pad = F.pad(x, pad=self.pad, mode='constant', value=-1e9)
+ channel = nn.functional.unfold(x_pad, 2 * self.m + 1, strides=1, paddings=0).reshape([batch_size, c, -1, h, w])
+ result = paddle.max(channel, axis=2)
+ return result
+
+
+def param2stroke(param, H, W, meta_brushes):
+ """
+ param2stroke
+ """
+ b = param.shape[0]
+ param_list = paddle.split(param, 8, axis=1)
+ x0, y0, w, h, theta = [item.squeeze(-1) for item in param_list[:5]]
+ sin_theta = paddle.sin(math.pi * theta)
+ cos_theta = paddle.cos(math.pi * theta)
+ index = paddle.full((b, ), -1, dtype='int64').numpy()
+
+ index[(h > w).numpy()] = 0
+ index[(h <= w).numpy()] = 1
+ meta_brushes_resize = F.interpolate(meta_brushes, (H, W)).numpy()
+ brush = paddle.to_tensor(meta_brushes_resize[index])
+
+ warp_00 = cos_theta / w
+ warp_01 = sin_theta * H / (W * w)
+ warp_02 = (1 - 2 * x0) * cos_theta / w + (1 - 2 * y0) * sin_theta * H / (W * w)
+ warp_10 = -sin_theta * W / (H * h)
+ warp_11 = cos_theta / h
+ warp_12 = (1 - 2 * y0) * cos_theta / h - (1 - 2 * x0) * sin_theta * W / (H * h)
+ warp_0 = paddle.stack([warp_00, warp_01, warp_02], axis=1)
+ warp_1 = paddle.stack([warp_10, warp_11, warp_12], axis=1)
+ warp = paddle.stack([warp_0, warp_1], axis=1)
+    grid = nn.functional.affine_grid(warp, [b, 3, H, W])  # note: paddle's align_corners default differs from torch's
+ brush = nn.functional.grid_sample(brush, grid)
+ return brush
+
+
+def read_img(img_path, img_type='RGB', h=None, w=None):
+ """
+ read img
+ """
+ img = Image.open(img_path).convert(img_type)
+ if h is not None and w is not None:
+ img = img.resize((w, h), resample=Image.NEAREST)
+ img = np.array(img)
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=-1)
+ img = img.transpose((2, 0, 1))
+ img = paddle.to_tensor(img).unsqueeze(0).astype('float32') / 255.
+ return img
+
+
+def preprocess(img, w=512, h=512):
+    # interpolation must be passed by keyword; the third positional argument of cv2.resize is dst
+    image = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
+ image = image.transpose((2, 0, 1))
+ image = paddle.to_tensor(image).unsqueeze(0).astype('float32') / 255.
+ return image
+
+
+def totensor(img):
+ image = img.transpose((2, 0, 1))
+ image = paddle.to_tensor(image).unsqueeze(0).astype('float32') / 255.
+ return image
+
+
+def pad(img, H, W):
+ b, c, h, w = img.shape
+ pad_h = (H - h) // 2
+ pad_w = (W - w) // 2
+ remainder_h = (H - h) % 2
+ remainder_w = (W - w) % 2
+ expand_img = nn.functional.pad(img, [pad_w, pad_w + remainder_w, pad_h, pad_h + remainder_h])
+ return expand_img
diff --git a/modules/image/Image_gan/style_transfer/painttransformer/requirements.txt b/modules/image/Image_gan/style_transfer/painttransformer/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..67e9bb6fa840355e9ed0d44b7134850f1fe22fe1
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/painttransformer/requirements.txt
@@ -0,0 +1 @@
+ppgan
diff --git a/modules/image/Image_gan/style_transfer/painttransformer/util.py b/modules/image/Image_gan/style_transfer/painttransformer/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88ac3562b74cadc1d4d6459a56097ca4a938a0b
--- /dev/null
+++ b/modules/image/Image_gan/style_transfer/painttransformer/util.py
@@ -0,0 +1,10 @@
+import base64
+import cv2
+import numpy as np
+
+
+def base64_to_cv2(b64str):
+ data = base64.b64decode(b64str.encode('utf8'))
+    data = np.frombuffer(data, np.uint8)
+ data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+ return data
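+
+
+def cv2_to_base64(image):
+    # Companion encoder to base64_to_cv2; a minimal sketch mirroring the
+    # client-side helper shown in README.md. Encodes a BGR ndarray as a
+    # base64 JPEG string.
+    data = cv2.imencode('.jpg', image)[1]
+    return base64.b64encode(data.tobytes()).decode('utf8')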