提交 fda5a23f 编写于 作者: L LielinJiang

refine code

上级 a1e739e7
......@@ -3,10 +3,6 @@ matplotlib.use('Agg')
import os
import sys
# cur_path = os.path.abspath(os.path.dirname(__file__))
# root_path = os.path.split(cur_path)[0]
# sys.path.append(root_path)
import yaml
import pickle
from argparse import ArgumentParser
......
......@@ -81,9 +81,10 @@ class OcclusionAwareGenerator(nn.Layer):
deformation = deformation.transpose([0, 3, 1, 2])
deformation = F.interpolate(deformation,
size=(h, w),
mode='bilinear')
mode='bilinear',
align_corners=False)
deformation = deformation.transpose([0, 2, 3, 1])
return F.grid_sample(inp, deformation)
return F.grid_sample(inp, deformation, align_corners=False)
def forward(self, source_image, kp_driving, kp_source):
# Encoding (downsampling) part
......@@ -113,7 +114,8 @@ class OcclusionAwareGenerator(nn.Layer):
3] != occlusion_map.shape[3]:
occlusion_map = F.interpolate(occlusion_map,
size=out.shape[2:],
mode='bilinear')
mode='bilinear',
align_corners=False)
out = out * occlusion_map
output_dict["deformed"] = self.deform_input(source_image,
......
......@@ -2,8 +2,6 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
# from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
def kp2gaussian(kp, spatial_size, kp_variance):
"""
......@@ -44,7 +42,6 @@ def make_coordinate_grid(spatial_size, type):
xx = paddle.tile(x.reshape([1, -1]), [h, 1])
meshed = paddle.concat([xx.unsqueeze(2), yy.unsqueeze(2)], 2)
# meshed = paddle.concat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
return meshed
......@@ -261,10 +258,7 @@ class AntiAliasInterpolation2d(nn.Layer):
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / paddle.sum(kernel)
# Reshape to depthwise convolutional weight
# print('debug shape:', kernel.shape)
# print('debug shape 1:', kernel.dim())
kernel = kernel.reshape([1, 1, *kernel.shape])
# kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
kernel = paddle.tile(kernel, [channels, *[1] * (kernel.dim() - 1)])
self.register_buffer('weight', kernel)
......
tqdm
PyYAML>=5.1
scikit-image>=0.14.0
scipy>=1.1.0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册