Unverified commit 2ab96cb8 authored by lzzyzlbb, committed by GitHub

Add static model and inference of stylegan2, fom, basicvsr (#491)

* Update benchmark.yaml

* Update benchmark.yaml

* add static model and inference of fom, basicvsr, stylegan2

* add static model and inference of fom, basicvsr, stylegan2

* fix basicvsr dataset for small datasets

* fix basicvsr dataset for small datasets

* fix basicvsr dataset for small datasets

* fix basicvsr dataset for small datasets
Parent 6e3dad37
......@@ -12,7 +12,7 @@ FOMM:
fp_item: fp32
bs_item: 8 16
epochs: 1
log_interval: 11
log_interval: 1
esrgan:
dataset_web: https://paddlegan.bj.bcebos.com/datasets/DIV2KandSet14.tar
......
......@@ -38,6 +38,7 @@ dataset:
use_rot: True
scale: 4
val_partition: REDS4
num_clips: 270
test:
name: SRREDSMultipleGTDataset
......@@ -90,3 +91,6 @@ log_config:
snapshot_config:
interval: 5000
export_model:
- {name: 'generator', inputs_num: 1}
......@@ -119,3 +119,7 @@ log_config:
snapshot_config:
interval: 5
export_model:
- {name: 'netG_A', inputs_num: 1}
- {name: 'netG_B', inputs_num: 1}
......@@ -123,3 +123,6 @@ snapshot_config:
optimizer:
name: Adam
export_model:
- {}
......@@ -115,3 +115,6 @@ validate:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8
export_model:
- {name: 'netG', inputs_num: 1}
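The new `export_model` blocks declare, per config, which sub-networks get exported to a static graph and how many inputs each takes (the bare `- {}` entry falls back to defaults). A minimal sketch of how an export entry point might consume these entries; the helper name, shapes, and defaults below are illustrative assumptions, not this repo's actual export tool:

```python
# hypothetical consumer of the export_model entries (names/shapes are assumptions)
import os
import paddle

def export_nets(model, export_entries, inputs_size, output_dir):
    for entry in export_entries:
        name = entry.get('name', 'generator')
        # build one InputSpec per declared input of the sub-network
        specs = [paddle.static.InputSpec(shape=inputs_size, dtype='float32')
                 for _ in range(entry.get('inputs_num', 1))]
        paddle.jit.save(model.nets[name], os.path.join(output_dir, name),
                        input_spec=specs)
```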
......@@ -55,7 +55,8 @@ class SRREDSMultipleGTDataset(Dataset):
use_rot=False,
scale=4,
val_partition='REDS4',
batch_size=4):
batch_size=4,
num_clips=270):
super(SRREDSMultipleGTDataset, self).__init__()
self.mode = mode
self.fileroot = str(lq_folder)
......@@ -69,6 +70,7 @@ class SRREDSMultipleGTDataset(Dataset):
self.scale = scale
self.val_partition = val_partition
self.batch_size = batch_size
self.num_clips = num_clips  # number of training clips of LQ/GT pairs
self.data_infos = self.load_annotations()
def __getitem__(self, idx):
......@@ -93,7 +95,7 @@ class SRREDSMultipleGTDataset(Dataset):
dict: Returned dict for LQ and GT pairs.
"""
# generate keys
keys = [f'{i:03d}' for i in range(0, 270)]
keys = [f'{i:03d}' for i in range(0, self.num_clips)]
if self.val_partition == 'REDS4':
val_partition = ['000', '011', '015', '020']
......@@ -170,7 +172,9 @@ class SRREDSMultipleGTDataset(Dataset):
gt_list = rlt[number_frames:]
# stack LQ images to NHWC, N is the frame number
frame_list = [v.transpose(2, 0, 1).astype('float32') for v in frame_list]
frame_list = [
v.transpose(2, 0, 1).astype('float32') for v in frame_list
]
gt_list = [v.transpose(2, 0, 1).astype('float32') for v in gt_list]
img_LQs = np.stack(frame_list, axis=0)
......
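This `num_clips` parameter is what the "fix basicvsr dataset for small datasets" commits add: the clip keys used to be hard-coded to REDS' 270 clip folders, so training on a smaller dataset referenced clips that do not exist. A minimal sketch of what the parameter controls:

```python
# illustrative: enumerate clip keys for a small dataset (REDS itself has 270)
num_clips = 10
keys = [f'{i:03d}' for i in range(num_clips)]    # ['000', ..., '009']
val_partition = ['000', '011', '015', '020']     # the REDS4 validation split
train_keys = [k for k in keys if k not in val_partition]
```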
......@@ -160,7 +160,9 @@ def flow_warp(x,
Returns:
Tensor: Warped image or feature map.
"""
if x.shape[-2:] != flow.shape[1:3]:
x_h, x_w = x.shape[-2:]
flow_h, flow_w = flow.shape[1:3]
if x_h != flow_h or x_w != flow_w:
raise ValueError(f'The spatial sizes of input ({x.shape[-2:]}) and '
f'flow ({flow.shape[1:3]}) are not the same.')
_, _, h, w = x.shape
......@@ -293,7 +295,7 @@ class SPyNet(nn.Layer):
supp = supp[::-1]
# flow computation
flow = paddle.to_tensor(np.zeros([n, 2, h // 32, w // 32], 'float32'))
flow = paddle.zeros([n, 2, h // 32, w // 32])
# level=0
flow_up = flow
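These rewrites (and the matching ones in `BasicVSRNet.forward` below) replace numpy round-trips and tuple shape comparisons with pure Paddle ops; presumably this is what lets the network trace cleanly under `paddle.jit.save` for the new static-graph export, since shape entries may be tensors rather than plain ints during tracing. A small sketch of the pattern, under that assumption:

```python
# assumption: under dynamic-to-static tracing, prefer graph ops over numpy
import paddle

flow = paddle.zeros([1, 2, 8, 8])  # stays in the graph, unlike to_tensor(np.zeros(...))
flow_h, flow_w = flow.shape[-2:]   # compare element-wise, not tuple vs. tuple
assert flow_h == 8 and flow_w == 8
```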
......@@ -555,6 +557,7 @@ class BasicVSRNet(nn.Layer):
"""
n, t, c, h, w = lrs.shape
t = paddle.to_tensor(t)
assert h >= 64 and w >= 64, (
'The height and width of inputs should be at least 64, '
f'but got {h} and {w}.')
......@@ -567,19 +570,18 @@ class BasicVSRNet(nn.Layer):
# backward-time propagation
outputs = []
feat_prop = paddle.to_tensor(
np.zeros([n, self.mid_channels, h, w], 'float32'))
feat_prop = paddle.zeros([n, self.mid_channels, h, w])
for i in range(t - 1, -1, -1):
if i < t - 1: # no warping required for the last timestep
flow = flows_backward[:, i, :, :, :]
feat_prop = flow_warp(feat_prop, flow.transpose([0, 2, 3, 1]))
flow1 = flows_backward[:, i, :, :, :]
feat_prop = flow_warp(feat_prop, flow1.transpose([0, 2, 3, 1]))
feat_prop = paddle.concat([lrs[:, i, :, :, :], feat_prop], axis=1)
feat_prop = self.backward_resblocks(feat_prop)
outputs.append(feat_prop)
outputs = outputs[::-1]
# forward-time propagation and upsampling
feat_prop = paddle.zeros_like(feat_prop)
for i in range(0, t):
......@@ -610,6 +612,7 @@ class BasicVSRNet(nn.Layer):
class SecondOrderDeformableAlignment(nn.Layer):
"""Second-order deformable alignment module.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
......
......@@ -32,9 +32,9 @@ class PixelNorm(nn.Layer):
def __init__(self):
super().__init__()
def forward(self, input):
return input * paddle.rsqrt(
paddle.mean(input * input, 1, keepdim=True) + 1e-8)
def forward(self, inputs):
return inputs * paddle.rsqrt(
paddle.mean(inputs * inputs, 1, keepdim=True) + 1e-8)
class ModulatedConv2D(nn.Layer):
......@@ -93,8 +93,8 @@ class ModulatedConv2D(nn.Layer):
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
f"upsample={self.upsample}, downsample={self.downsample})")
def forward(self, input, style):
batch, in_channel, height, width = input.shape
def forward(self, inputs, style):
batch, in_channel, height, width = inputs.shape
style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
weight = self.scale * self.weight * style
......@@ -107,13 +107,13 @@ class ModulatedConv2D(nn.Layer):
self.kernel_size, self.kernel_size))
if self.upsample:
input = input.reshape((1, batch * in_channel, height, width))
inputs = inputs.reshape((1, batch * in_channel, height, width))
weight = weight.reshape((batch, self.out_channel, in_channel,
self.kernel_size, self.kernel_size))
weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
(batch * in_channel, self.out_channel, self.kernel_size,
self.kernel_size))
out = F.conv2d_transpose(input,
out = F.conv2d_transpose(inputs,
weight,
padding=0,
stride=2,
......@@ -123,16 +123,16 @@ class ModulatedConv2D(nn.Layer):
out = self.blur(out)
elif self.downsample:
input = self.blur(input)
_, _, height, width = input.shape
input = input.reshape((1, batch * in_channel, height, width))
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
inputs = self.blur(inputs)
_, _, height, width = inputs.shape
inputs = inputs.reshape((1, batch * in_channel, height, width))
out = F.conv2d(inputs, weight, padding=0, stride=2, groups=batch)
_, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width))
else:
input = input.reshape((1, batch * in_channel, height, width))
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
inputs = inputs.reshape((1, batch * in_channel, height, width))
out = F.conv2d(inputs, weight, padding=self.padding, groups=batch)
_, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width))
......@@ -165,8 +165,8 @@ class ConstantInput(nn.Layer):
(1, channel, size, size),
default_initializer=nn.initializer.Normal())
def forward(self, input):
batch = input.shape[0]
def forward(self, inputs):
batch = inputs.shape[0]
out = self.input.tile((batch, 1, 1, 1))
return out
......@@ -198,8 +198,8 @@ class StyledConv(nn.Layer):
self.activate = FusedLeakyReLU(out_channel *
2 if is_concat else out_channel)
def forward(self, input, style, noise=None):
out = self.conv(input, style)
def forward(self, inputs, style, noise=None):
out = self.conv(inputs, style)
out = self.noise(out, noise=noise)
out = self.activate(out)
......@@ -225,8 +225,8 @@ class ToRGB(nn.Layer):
self.bias = self.create_parameter((1, 3, 1, 1),
nn.initializer.Constant(0.0))
def forward(self, input, style, skip=None):
out = self.conv(input, style)
def forward(self, inputs, style, skip=None):
out = self.conv(inputs, style)
out = out + self.bias
if skip is not None:
......@@ -349,15 +349,28 @@ class StyleGANv2Generator(nn.Layer):
return latent
def get_latent(self, input):
return self.style(input)
def get_latent(self, inputs):
return self.style(inputs)
def get_mean_style(self):
mean_style = None
with paddle.no_grad():
for i in range(10):
style = self.mean_latent(1024)
if mean_style is None:
mean_style = style
else:
mean_style += style
mean_style /= 10
return mean_style
def forward(
self,
styles,
return_latents=False,
inject_index=None,
truncation=1,
truncation=1.0,
truncation_latent=None,
input_is_latent=False,
noise=None,
......@@ -375,9 +388,10 @@ class StyleGANv2Generator(nn.Layer):
for i in range(self.num_layers)
]
if truncation < 1:
if truncation < 1.0:
style_t = []
if truncation_latent is None:
truncation_latent = self.get_mean_style()
for style in styles:
style_t.append(truncation_latent + truncation *
(style - truncation_latent))
......
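`get_mean_style` gives inference callers a default `truncation_latent`: it averages ten batches of `mean_latent(1024)` and is substituted whenever `truncation < 1.0` but no latent was supplied. The truncation trick itself pulls each style vector toward that mean; a minimal sketch with illustrative stand-in names:

```python
import paddle

# stand-ins: w_mean would come from generator.get_mean_style()
w_mean = paddle.zeros([1, 512])
w = paddle.randn([1, 512])
truncation = 0.7  # 1.0 disables the guard above
w_trunc = w_mean + truncation * (w - w_mean)
```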
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import random
import paddle
......@@ -24,6 +25,7 @@ from .discriminators.builder import build_discriminator
from ..solver import build_lr_scheduler, build_optimizer
def r1_penalty(real_pred, real_img):
"""
R1 regularization for discriminator. The core idea is to
......@@ -195,7 +197,6 @@ class StyleGAN2Model(BaseModel):
noises = []
for _ in range(num_noise):
noises.append(paddle.randn([batch, self.num_style_feat]))
return noises
def mixing_noise(self, batch, prob):
......@@ -294,3 +295,25 @@ class StyleGAN2Model(BaseModel):
metric.update(fake_img, self.real_img)
self.nets['gen_ema'].train()
class InferGenerator(paddle.nn.Layer):
def set_generator(self, generator):
self.generator = generator
def forward(self, style, truncation):
truncation_latent = self.generator.get_mean_style()
out = self.generator(styles=style,
truncation=truncation,
truncation_latent=truncation_latent)
return out[0]
def export_model(self,
export_model=None,
output_dir=None,
inputs_size=[[1, 1, 512], [1, 1]]):
infer_generator = self.InferGenerator()
infer_generator.set_generator(self.nets['gen'])
style = paddle.rand(shape=inputs_size[0], dtype='float32')
truncation = paddle.rand(shape=inputs_size[1], dtype='float32')
paddle.jit.save(infer_generator,
os.path.join(output_dir, "stylegan2model_gen"),
input_spec=[style, truncation])
......@@ -9,15 +9,19 @@ import paddle.fluid as fluid
import os
from functools import reduce
import paddle
from ppgan.utils.filesystem import makedirs
from pathlib import Path
def read_img(path):
img = imageio.imread(path)
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:,:,:3]
return img
img = imageio.imread(path)
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return img
def read_video(path):
reader = imageio.get_reader(path)
......@@ -32,10 +36,10 @@ def read_video(path):
reader.close()
return driving_video, fps
def face_detection(img_ori, weight_path):
config = paddle_infer.Config(
os.path.join(weight_path, '__model__'),
os.path.join(weight_path, '__params__'))
config = paddle_infer.Config(os.path.join(weight_path, '__model__'),
os.path.join(weight_path, '__params__'))
config.disable_gpu()
# disable print log when predict
config.disable_glog_info()
......@@ -44,10 +48,11 @@ def face_detection(img_ori, weight_path):
# disable feed, fetch OP, needed by zero_copy_run
config.switch_use_feed_fetch_ops(False)
predictor = paddle_infer.create_predictor(config)
img = img_ori.astype(np.float32)
mean = np.array([123, 117, 104])[np.newaxis, np.newaxis, :]
std = np.array([127.502231, 127.502231, 127.502231])[np.newaxis, np.newaxis, :]
std = np.array([127.502231, 127.502231, 127.502231])[np.newaxis,
np.newaxis, :]
img -= mean
img /= std
img = img[:, :, [2, 1, 0]]
......@@ -82,101 +87,146 @@ def face_detection(img_ori, weight_path):
return int(y1), int(y2), int(x1), int(x2)
def main():
args = parse_args()
source_path = args.source_path
driving_path = args.driving_path
source_img = read_img(source_path)
#Todo:add blazeface static model
#left, right, up, bottom = face_detection(source_img, "/workspace/PaddleDetection/static/inference_model/blazeface/")
source = source_img #[left:right, up:bottom]
source = cv2.resize(source, (256, 256)) / 255.0
source = source[np.newaxis].astype(np.float32).transpose([0, 3, 1, 2])
driving_video, fps = read_video(driving_path)
driving_video = [cv2.resize(frame, (256, 256)) / 255.0 for frame in driving_video]
driving_len = len(driving_video)
driving_video = np.array(driving_video).astype(np.float32).transpose([0, 3, 1, 2])
source_path = args.source_path
driving_path = Path(args.driving_path)
makedirs(args.output_path)
if driving_path.is_dir():
driving_paths = list(driving_path.iterdir())
else:
driving_paths = [driving_path]
# create the inference configs
kp_detector_config = paddle_infer.Config(args.model_profix+"/kp_detector.pdmodel", args.model_profix+"/kp_detector.pdiparams")
generator_config = paddle_infer.Config(args.model_profix+"/generator.pdmodel", args.model_profix+"/generator.pdiparams")
kp_detector_config.set_mkldnn_cache_capacity(10)
kp_detector_config.enable_mkldnn()
generator_config.set_mkldnn_cache_capacity(10)
generator_config.enable_mkldnn()
kp_detector_config.disable_gpu()
kp_detector_config.set_cpu_math_library_num_threads(6)
generator_config.disable_gpu()
generator_config.set_cpu_math_library_num_threads(6)
# no leading "/" in the filenames, otherwise os.path.join discards model_path
kp_detector_config = paddle_infer.Config(
os.path.join(args.model_path, "kp_detector.pdmodel"),
os.path.join(args.model_path, "kp_detector.pdiparams"))
generator_config = paddle_infer.Config(
os.path.join(args.model_path, "generator.pdmodel"),
os.path.join(args.model_path, "generator.pdiparams"))
if args.device == "gpu":
kp_detector_config.enable_use_gpu(100, 0)
generator_config.enable_use_gpu(100, 0)
else:
kp_detector_config.set_mkldnn_cache_capacity(10)
kp_detector_config.enable_mkldnn()
generator_config.set_mkldnn_cache_capacity(10)
generator_config.enable_mkldnn()
kp_detector_config.disable_gpu()
kp_detector_config.set_cpu_math_library_num_threads(6)
generator_config.disable_gpu()
generator_config.set_cpu_math_library_num_threads(6)
# create predictors from the configs
kp_detector_predictor = paddle_infer.create_predictor(kp_detector_config)
generator_predictor = paddle_infer.create_predictor(generator_config)
# get input names
kp_detector_input_names = kp_detector_predictor.get_input_names()
kp_detector_input_handle = kp_detector_predictor.get_input_handle(kp_detector_input_names[0])
kp_detector_input_handle.reshape([args.batch_size, 3, 256, 256])
kp_detector_input_handle.copy_from_cpu(source)
kp_detector_predictor.run()
kp_detector_output_names = kp_detector_predictor.get_output_names()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[0])
source_j = kp_detector_output_handle.copy_to_cpu()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[1])
source_v = kp_detector_output_handle.copy_to_cpu()
kp_detector_input_handle.reshape([args.batch_size, 3, 256, 256])
kp_detector_input_handle.copy_from_cpu(driving_video[0:1])
kp_detector_predictor.run()
kp_detector_output_names = kp_detector_predictor.get_output_names()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[0])
driving_init_j = kp_detector_output_handle.copy_to_cpu()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[1])
driving_init_v = kp_detector_output_handle.copy_to_cpu()
start_time = time.time()
results = []
for i in tqdm(range(0, driving_len)):
kp_detector_input_handle.copy_from_cpu(driving_video[i:i+1])
for k in range(len(driving_paths)):
driving_path = driving_paths[k]
driving_video, fps = read_video(driving_path)
driving_video = [
cv2.resize(frame, (256, 256)) / 255.0 for frame in driving_video
]
driving_len = len(driving_video)
driving_video = np.array(driving_video).astype(np.float32).transpose(
[0, 3, 1, 2])
if source_path is None:
source = driving_video[0:1]
else:
source_img = read_img(source_path)
# TODO: add blazeface static model
#left, right, up, bottom = face_detection(source_img, "/workspace/PaddleDetection/static/inference_model/blazeface/")
source = source_img #[left:right, up:bottom]
source = cv2.resize(source, (256, 256)) / 255.0
source = source[np.newaxis].astype(np.float32).transpose(
[0, 3, 1, 2])
# get input names
kp_detector_input_names = kp_detector_predictor.get_input_names()
kp_detector_input_handle = kp_detector_predictor.get_input_handle(
kp_detector_input_names[0])
kp_detector_input_handle.reshape([args.batch_size, 3, 256, 256])
kp_detector_input_handle.copy_from_cpu(source)
kp_detector_predictor.run()
kp_detector_output_names = kp_detector_predictor.get_output_names()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(
kp_detector_output_names[0])
source_j = kp_detector_output_handle.copy_to_cpu()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(
kp_detector_output_names[1])
source_v = kp_detector_output_handle.copy_to_cpu()
kp_detector_input_handle.reshape([args.batch_size, 3, 256, 256])
kp_detector_input_handle.copy_from_cpu(driving_video[0:1])
kp_detector_predictor.run()
kp_detector_output_names = kp_detector_predictor.get_output_names()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[0])
driving_j = kp_detector_output_handle.copy_to_cpu()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(kp_detector_output_names[1])
driving_v = kp_detector_output_handle.copy_to_cpu()
generator_inputs = [source, source_j, source_v, driving_j, driving_v, driving_init_j, driving_init_v]
generator_input_names = generator_predictor.get_input_names()
for i in range(len(generator_input_names)):
generator_input_handle = generator_predictor.get_input_handle(generator_input_names[i])
generator_input_handle.copy_from_cpu(generator_inputs[i])
generator_predictor.run()
generator_output_names = generator_predictor.get_output_names()
generator_output_handle = generator_predictor.get_output_handle(generator_output_names[0])
output_data = generator_output_handle.copy_to_cpu()
output_data = np.transpose(output_data, [0, 2, 3, 1])[0] * 255.0
#Todo:add blazeface static model
#frame = source_img.copy()
#frame[left:right, up:bottom] = cv2.resize(output_data.astype(np.uint8), (bottom - up, right - left), cv2.INTER_AREA)
results.append(output_data.astype(np.uint8))
print(time.time() - start_time)
imageio.mimsave(args.output_path, [frame for frame in results], fps=fps)
kp_detector_output_handle = kp_detector_predictor.get_output_handle(
kp_detector_output_names[0])
driving_init_j = kp_detector_output_handle.copy_to_cpu()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(
kp_detector_output_names[1])
driving_init_v = kp_detector_output_handle.copy_to_cpu()
start_time = time.time()
results = []
for i in tqdm(range(0, driving_len)):
kp_detector_input_handle.copy_from_cpu(driving_video[i:i + 1])
kp_detector_predictor.run()
kp_detector_output_names = kp_detector_predictor.get_output_names()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(
kp_detector_output_names[0])
driving_j = kp_detector_output_handle.copy_to_cpu()
kp_detector_output_handle = kp_detector_predictor.get_output_handle(
kp_detector_output_names[1])
driving_v = kp_detector_output_handle.copy_to_cpu()
generator_inputs = [
source, source_j, source_v, driving_j, driving_v,
driving_init_j, driving_init_v
]
generator_input_names = generator_predictor.get_input_names()
for j in range(len(generator_input_names)):  # avoid shadowing the frame index i
generator_input_handle = generator_predictor.get_input_handle(
generator_input_names[j])
generator_input_handle.copy_from_cpu(generator_inputs[j])
generator_predictor.run()
generator_output_names = generator_predictor.get_output_names()
generator_output_handle = generator_predictor.get_output_handle(
generator_output_names[0])
output_data = generator_output_handle.copy_to_cpu()
output_data = np.transpose(output_data, [0, 2, 3, 1])[0] * 255.0
# TODO: add blazeface static model
#frame = source_img.copy()
#frame[left:right, up:bottom] = cv2.resize(output_data.astype(np.uint8), (bottom - up, right - left), cv2.INTER_AREA)
results.append(output_data.astype(np.uint8))
print(time.time() - start_time)
imageio.mimsave(os.path.join(args.output_path,
"result_" + str(k) + ".mp4"),
[frame for frame in results],
fps=fps)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_profix", type=str, help="model filename profix")
parser.add_argument("--model_path", type=str, help="model filename profix")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--source_path", type=str, default=1, help="source_path")
parser.add_argument("--driving_path", type=str, default=1, help="driving_path")
parser.add_argument("--output_path", type=str, default=1, help="output_path")
parser.add_argument("--source_path",
type=str,
default=None,
help="source_path")
parser.add_argument("--driving_path",
type=str,
default=None,
help="driving_path")
parser.add_argument("--output_path",
type=str,
default="infer_output/fom/",
help="output_path")
parser.add_argument("--device", type=str, default="gpu", help="device")
return parser.parse_args()
if __name__ == "__main__":
main()
\ No newline at end of file
main()
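A hypothetical invocation, assuming this script is saved as `fom_infer.py`: `python fom_infer.py --model_path inference/fom --driving_path driving.mp4 --device gpu`. When `--source_path` is omitted, the first driving frame doubles as the source image, and one `result_<k>.mp4` is written to `--output_path` per driving video.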
......@@ -19,13 +19,14 @@ def parse_args():
default=None,
type=str,
required=True,
help="The path prefix of inference model to be used.", )
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES))
help="The path prefix of inference model to be used.",
)
parser.add_argument("--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " +
", ".join(MODEL_CLASSES))
parser.add_argument(
"--device",
default="gpu",
......@@ -65,12 +66,14 @@ def main():
args = parse_args()
cfg = get_config(args.config_file, args.opt)
predictor = create_predictor(args.model_path, args.device)
input_handles = [predictor.get_input_handle(
name) for name in predictor.get_input_names()]
output_handle = predictor.get_output_handle(
predictor.get_output_names()[0])
test_dataloader = build_dataloader(
cfg.dataset.test, is_train=False, distributed=False)
input_handles = [
predictor.get_input_handle(name)
for name in predictor.get_input_names()
]
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
test_dataloader = build_dataloader(cfg.dataset.test,
is_train=False,
distributed=False)
max_eval_steps = len(test_dataloader)
iter_loader = IterLoader(test_dataloader)
......@@ -110,8 +113,8 @@ def main():
prediction[j] = prediction[j][::-1, :, :]
image_numpy = paddle.to_tensor(prediction[j])
image_numpy = tensor2img(image_numpy, (0, 1))
save_image(
image_numpy, "infer_output/wav2lip/{}_{}.png".format(i, j))
save_image(image_numpy,
"infer_output/wav2lip/{}_{}.png".format(i, j))
elif model_type == "esrgan":
lq = data['lq'].numpy()
input_handles[0].copy_from_cpu(lq)
......@@ -128,6 +131,23 @@ def main():
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/edvr/{}.png".format(i))
elif model_type == "stylegan2":
noise = paddle.randn([1, 1, 512]).cpu().numpy()
input_handles[0].copy_from_cpu(noise)
input_handles[1].copy_from_cpu(np.array([0.7]).astype('float32'))
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/stylegan2/{}.png".format(i))
elif model_type == "basicvsr":
lq = data['lq'].numpy()
input_handles[0].copy_from_cpu(lq)
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/basicvsr/{}.png".format(i))
if __name__ == '__main__':
......
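With these two branches the exported StyleGAN2 and BasicVSR models can be smoke-tested end to end, e.g. `python tools/inference.py --model_path inference/stylegan2 --model_type stylegan2 --device gpu` (the script path and the config flag consumed by `get_config` are assumptions; only `--model_path`, `--model_type`, and `--device` are visible in this hunk).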