未验证 提交 b2f18921 编写于 作者: B Birdylx 提交者: GitHub

fix basic chain (#697)

上级 7f0b34dd
......@@ -25,6 +25,9 @@ model:
pixel_criterion:
name: CharbonnierLoss
export_model:
- {name: 'generator', inputs_num: 1}
dataset:
train:
name: RepeatDataset
......
......@@ -15,6 +15,9 @@ model:
pixel_criterion:
name: L1Loss
export_model:
- {name: 'generator', inputs_num: 1}
dataset:
train:
name: SRDataset
......
......@@ -62,6 +62,7 @@ class ResidualBlockNoBN(nn.Layer):
nf (int): Channel number of intermediate features.
Default: 64.
"""
def __init__(self, nf=64):
super(ResidualBlockNoBN, self).__init__()
self.nf = nf
......@@ -100,6 +101,7 @@ class PredeblurResNetPyramid(nn.Layer):
nf (int): Channel number of intermediate features. Default: 64.
HR_in (bool): Whether the input has high resolution. Default: False.
"""
def __init__(self, in_nf=3, nf=64, HR_in=False):
super(PredeblurResNetPyramid, self).__init__()
self.in_nf = in_nf
......@@ -189,6 +191,7 @@ class TSAFusion(nn.Layer):
nframes (int): Number of frames. Default: 5.
center (int): The index of center frame. Default: 2.
"""
def __init__(self, nf=64, nframes=5, center=2):
super(TSAFusion, self).__init__()
self.nf = nf
......@@ -347,6 +350,7 @@ class DCNPack(nn.Layer):
Ref:
Delving Deep into Deformable Alignment in Video Super-Resolution.
"""
def __init__(self,
num_filters=64,
kernel_size=3,
......@@ -408,6 +412,7 @@ class PCDAlign(nn.Layer):
nf (int): Channel number of middle features. Default: 64.
groups (int): Deformable groups. Defaults: 8.
"""
def __init__(self, nf=64, groups=8):
super(PCDAlign, self).__init__()
self.nf = nf
......@@ -594,6 +599,7 @@ class EDVRNet(nn.Layer):
with_tsa (bool): Whether has TSA module. Default: True.
TSA_only (bool): Whether only use TSA module. Default: False.
"""
def __init__(self,
in_nf=3,
out_nf=3,
......@@ -750,13 +756,13 @@ class EDVRNet(nn.Layer):
L1_fea[:, self.center, :, :, :], L2_fea[:, self.center, :, :, :],
L3_fea[:, self.center, :, :, :]
]
aligned_fea = []
for i in range(N):
nbr_fea_l = [
aligned_fea = [
self.PCDModule([
L1_fea[:, i, :, :, :], L2_fea[:, i, :, :, :], L3_fea[:,
i, :, :, :]
]
aligned_fea.append(self.PCDModule(nbr_fea_l, ref_fea_l))
], ref_fea_l) for i in range(N)
]
# TSA Fusion
aligned_fea = paddle.stack(aligned_fea, axis=1) # [B, N, C, H, W]
......
......@@ -29,6 +29,7 @@ from .generators.builder import build_generator
from .criterions.builder import build_criterion
from .discriminators.builder import build_discriminator
from ..solver import build_lr_scheduler, build_optimizer
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
......@@ -38,22 +39,28 @@ def pad_shape(shape, pad_size):
shape[-1] += 2 * pad_size
return shape
def quant(x, num):
n, c, h, w = x.shape
kmeans = KMeans(num, random_state=0).fit(x.transpose([0, 2, 3, 1]).reshape([-1, c]))
kmeans = KMeans(num, random_state=0).fit(
x.transpose([0, 2, 3, 1]).reshape([-1, c]))
centers = kmeans.cluster_centers_
x = centers[kmeans.labels_].reshape([n, h, w, c]).transpose([0, 3, 1, 2])
return paddle.to_tensor(x, 'float32'), centers
def quant_to_centers(x, centers):
n, c, h, w = x.shape
num = centers.shape[0]
kmeans = KMeans(num, init=centers, n_init=1).fit(x.transpose([0, 2, 3, 1]).reshape([-1, c]))
kmeans = KMeans(num, init=centers,
n_init=1).fit(x.transpose([0, 2, 3, 1]).reshape([-1, c]))
x = centers[kmeans.labels_].reshape([n, h, w, c]).transpose([0, 3, 1, 2])
return paddle.to_tensor(x, 'float32')
@MODELS.register()
class SinGANModel(BaseModel):
def __init__(self,
generator,
discriminator,
......@@ -70,53 +77,58 @@ class SinGANModel(BaseModel):
disc_iters=3,
noise_amp_init=0.1):
super(SinGANModel, self).__init__()
# setup config
self.gen_iters = gen_iters
self.disc_iters = disc_iters
self.min_size = min_size
self.is_finetune = is_finetune
self.noise_amp_init = noise_amp_init
self.train_image = T.Compose([
T.Transpose(),
T.Normalize(127.5, 127.5)
])(cv2.cvtColor(cv2.imread(train_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB))
self.train_image = T.Compose([T.Transpose(),
T.Normalize(127.5, 127.5)])(cv2.cvtColor(
cv2.imread(train_image,
cv2.IMREAD_COLOR),
cv2.COLOR_BGR2RGB))
self.train_image = paddle.to_tensor(self.train_image).unsqueeze(0)
self.scale_num = math.ceil(math.log(
self.min_size / min(self.train_image.shape[-2:]),
scale_factor)) + 1
self.scale_num = math.ceil(
math.log(self.min_size / min(self.train_image.shape[-2:]),
scale_factor)) + 1
self.scale_factor = math.pow(
self.min_size / min(self.train_image.shape[-2:]),
self.min_size / min(self.train_image.shape[-2:]),
1 / (self.scale_num - 1))
self.reals = [
F.interpolate(self.train_image, None, self.scale_factor ** i, 'bicubic')
for i in range(self.scale_num - 1, -1, -1)]
F.interpolate(self.train_image, None, self.scale_factor**i,
'bicubic') for i in range(self.scale_num - 1, -1, -1)
]
# build generator
generator['scale_num'] = self.scale_num
generator['coarsest_shape'] =self.reals[0].shape
generator['coarsest_shape'] = self.reals[0].shape
self.nets['netG'] = build_generator(generator)
self.niose_pad_size = 0 if generator.get('noise_zero_pad', True) \
else self.nets['netG']._pad_size
self.nets['netG'].scale_factor = paddle.to_tensor(self.scale_factor, 'float32')
self.nets['netG'].scale_factor = paddle.to_tensor(
self.scale_factor, 'float32')
# build discriminator
nfc_init = discriminator.pop('nfc_init', 32)
min_nfc_init = discriminator.pop('min_nfc_init', 32)
for i in range(self.scale_num):
discriminator['nfc'] = min(nfc_init * pow(2, math.floor(i / 4)), 128)
discriminator['min_nfc'] = min(min_nfc_init * pow(2, math.floor(i / 4)), 128)
discriminator['nfc'] = min(nfc_init * pow(2, math.floor(i / 4)),
128)
discriminator['min_nfc'] = min(
min_nfc_init * pow(2, math.floor(i / 4)), 128)
self.nets[f'netD{i}'] = build_discriminator(discriminator)
# build criterion
self.gan_criterion = build_criterion(gan_criterion)
self.recon_criterion = build_criterion(recon_criterion)
self.gp_criterion = build_criterion(gp_criterion)
if self.is_finetune:
self.finetune_scale = finetune_scale
self.quant_real, self.quant_centers = quant(self.reals[finetune_scale], color_num)
self.quant_real, self.quant_centers = quant(
self.reals[finetune_scale], color_num)
# setup training config
self.lr_schedulers = OrderedDict()
......@@ -138,9 +150,11 @@ class SinGANModel(BaseModel):
def setup_optimizers(self, lr_schedulers, cfg):
for i in range(self.scale_num):
self.optimizers[f'optim_netG{i}'] = build_optimizer(
cfg['optimizer_G'], lr_schedulers[f"lr{i}"], self.nets[f'netG'].generators[i].parameters())
cfg['optimizer_G'], lr_schedulers[f"lr{i}"],
self.nets[f'netG'].generators[i].parameters())
self.optimizers[f'optim_netD{i}'] = build_optimizer(
cfg['optimizer_D'], lr_schedulers[f"lr{i}"], self.nets[f'netD{i}'].parameters())
cfg['optimizer_D'], lr_schedulers[f"lr{i}"],
self.nets[f'netD{i}'].parameters())
return self.optimizers
def setup_input(self, input):
......@@ -149,17 +163,18 @@ class SinGANModel(BaseModel):
def backward_D(self):
self.loss_D_real = self.gan_criterion(self.pred_real, True, True)
self.loss_D_fake = self.gan_criterion(self.pred_fake, False, True)
self.loss_D_gp = self.gp_criterion(self.nets[f'netD{self.current_scale}'],
self.real_img,
self.fake_img)
self.loss_D_gp = self.gp_criterion(
self.nets[f'netD{self.current_scale}'], self.real_img,
self.fake_img)
self.loss_D = self.loss_D_real + self.loss_D_fake + self.loss_D_gp
self.loss_D.backward()
self.losses[f'scale{self.current_scale}/D_total_loss'] = self.loss_D
self.losses[f'scale{self.current_scale}/D_real_loss'] = self.loss_D_real
self.losses[f'scale{self.current_scale}/D_fake_loss'] = self.loss_D_fake
self.losses[f'scale{self.current_scale}/D_gradient_penalty'] = self.loss_D_gp
self.losses[
f'scale{self.current_scale}/D_gradient_penalty'] = self.loss_D_gp
def backward_G(self):
self.loss_G_gan = self.gan_criterion(self.pred_fake, True, False)
self.loss_G_recon = self.recon_criterion(self.recon_img, self.real_img)
......@@ -167,7 +182,8 @@ class SinGANModel(BaseModel):
self.loss_G.backward()
self.losses[f'scale{self.current_scale}/G_adv_loss'] = self.loss_G_gan
self.losses[f'scale{self.current_scale}/G_recon_loss'] = self.loss_G_recon
self.losses[
f'scale{self.current_scale}/G_recon_loss'] = self.loss_G_recon
def scale_prepare(self):
self.real_img = self.reals[self.current_scale]
......@@ -180,7 +196,7 @@ class SinGANModel(BaseModel):
self.visual_items[f'real_img_scale{self.current_scale}'] = self.real_img
if self.is_finetune:
self.visual_items['quant_real'] = self.quant_real
self.recon_prev = paddle.zeros_like(self.reals[0])
if self.current_scale > 0:
z_pyramid = []
......@@ -189,71 +205,77 @@ class SinGANModel(BaseModel):
z = self.nets['netG'].z_fixed
else:
z = paddle.zeros(
pad_shape(
self.reals[i].shape, self.niose_pad_size))
pad_shape(self.reals[i].shape, self.niose_pad_size))
z_pyramid.append(z)
self.recon_prev = self.nets['netG'](
z_pyramid, self.recon_prev,
self.current_scale - 1, 0).detach()
self.recon_prev = F.interpolate(
self.recon_prev, self.real_img.shape[-2:], None, 'bicubic')
self.recon_prev = self.nets['netG'](z_pyramid, self.recon_prev,
self.current_scale - 1,
0).detach()
self.recon_prev = F.interpolate(self.recon_prev,
self.real_img.shape[-2:], None,
'bicubic')
if self.is_finetune:
self.recon_prev = quant_to_centers(self.recon_prev, self.quant_centers)
self.recon_prev = quant_to_centers(self.recon_prev,
self.quant_centers)
self.nets['netG'].sigma[self.current_scale] = F.mse_loss(
self.real_img, self.recon_prev
).sqrt() * self.noise_amp_init
self.real_img, self.recon_prev).sqrt() * self.noise_amp_init
for i in range(self.scale_num):
self.set_requires_grad(self.nets['netG'].generators[i], i == self.current_scale)
self.set_requires_grad(self.nets['netG'].generators[i],
i == self.current_scale)
def forward(self):
if not self.is_finetune:
self.fake_img = self.nets['netG'](
self.z_pyramid,
paddle.zeros(
pad_shape(self.z_pyramid[0].shape, -self.niose_pad_size)),
self.current_scale, 0)
self.fake_img = self.nets['netG'](self.z_pyramid,
paddle.zeros(
pad_shape(
self.z_pyramid[0].shape,
-self.niose_pad_size)),
self.current_scale, 0)
else:
x_prev = self.nets['netG'](
self.z_pyramid[:self.finetune_scale],
paddle.zeros(
pad_shape(self.z_pyramid[0].shape, -self.niose_pad_size)),
self.finetune_scale - 1, 0)
x_prev = F.interpolate(x_prev, self.z_pyramid[self.finetune_scale].shape[-2:], None, 'bicubic')
x_prev = self.nets['netG'](self.z_pyramid[:self.finetune_scale],
paddle.zeros(
pad_shape(self.z_pyramid[0].shape,
-self.niose_pad_size)),
self.finetune_scale - 1, 0)
x_prev = F.interpolate(
x_prev, self.z_pyramid[self.finetune_scale].shape[-2:], None,
'bicubic')
x_prev_quant = quant_to_centers(x_prev, self.quant_centers)
self.fake_img = self.nets['netG'](
self.z_pyramid[self.finetune_scale:],
x_prev_quant,
self.z_pyramid[self.finetune_scale:], x_prev_quant,
self.current_scale, self.finetune_scale)
self.recon_img = self.nets['netG'](
[(paddle.randn if self.current_scale == 0 else paddle.zeros)(
pad_shape(self.real_img.shape, self.niose_pad_size))],
self.recon_prev,
self.current_scale,
self.current_scale)
self.recon_prev, self.current_scale, self.current_scale)
self.pred_real = self.nets[f'netD{self.current_scale}'](self.real_img)
self.pred_fake = self.nets[f'netD{self.current_scale}'](
self.fake_img.detach() if self.update_D else self.fake_img)
self.visual_items[f'fake_img_scale{self.current_scale}'] = self.fake_img
self.visual_items[f'recon_img_scale{self.current_scale}'] = self.recon_img
self.visual_items[
f'recon_img_scale{self.current_scale}'] = self.recon_img
if self.is_finetune:
self.visual_items[f'prev_img_scale{self.current_scale}'] = x_prev
self.visual_items[f'quant_prev_img_scale{self.current_scale}'] = x_prev_quant
self.visual_items[
f'quant_prev_img_scale{self.current_scale}'] = x_prev_quant
def train_iter(self, optimizers=None):
if self.current_iter % self.scale_iters == 0:
self.current_scale += 1
self.scale_prepare()
self.z_pyramid = [paddle.randn(
pad_shape(self.reals[i].shape, self.niose_pad_size))
for i in range(self.current_scale + 1)]
self.z_pyramid = [
paddle.randn(pad_shape(self.reals[i].shape, self.niose_pad_size))
for i in range(self.current_scale + 1)
]
self.update_D = (self.current_iter % (self.disc_iters + self.gen_iters) < self.disc_iters)
self.set_requires_grad(self.nets[f'netD{self.current_scale}'], self.update_D)
self.update_D = (self.current_iter %
(self.disc_iters + self.gen_iters) < self.disc_iters)
self.set_requires_grad(self.nets[f'netD{self.current_scale}'],
self.update_D)
self.forward()
if self.update_D:
optimizers[f'optim_netD{self.current_scale}'].clear_grad()
......@@ -267,22 +289,25 @@ class SinGANModel(BaseModel):
self.current_iter += 1
def test_iter(self, metrics=None):
z_pyramid = [paddle.randn(
pad_shape(self.reals[i].shape, self.niose_pad_size))
for i in range(self.scale_num)]
z_pyramid = [
paddle.randn(pad_shape(self.reals[i].shape, self.niose_pad_size))
for i in range(self.scale_num)
]
self.nets['netG'].eval()
fake_img = self.nets['netG'](
z_pyramid,
paddle.zeros(pad_shape(z_pyramid[0].shape, -self.niose_pad_size)),
self.scale_num - 1, 0)
fake_img = self.nets['netG'](z_pyramid,
paddle.zeros(
pad_shape(z_pyramid[0].shape,
-self.niose_pad_size)),
self.scale_num - 1, 0)
self.visual_items['fake_img_test'] = fake_img
with paddle.no_grad():
if metrics is not None:
for metric in metrics.values():
metric.update(fake_img, self.train_image)
self.nets['netG'].train()
class InferGenerator(paddle.nn.Layer):
def set_config(self, generator, noise_shapes, scale_num):
self.generator = generator
self.noise_shapes = noise_shapes
......@@ -299,10 +324,14 @@ class SinGANModel(BaseModel):
export_model=None,
output_dir=None,
inputs_size=None,
export_serving_model=False):
noise_shapes = [pad_shape(x.shape, self.niose_pad_size) for x in self.reals]
export_serving_model=False,
model_name=None):
noise_shapes = [
pad_shape(x.shape, self.niose_pad_size) for x in self.reals
]
infer_generator = self.InferGenerator()
infer_generator.set_config(self.nets['netG'], noise_shapes, self.scale_num)
infer_generator.set_config(self.nets['netG'], noise_shapes,
self.scale_num)
paddle.jit.save(infer_generator,
os.path.join(output_dir, "singan_random_sample"),
input_spec=[1])
......@@ -13,7 +13,7 @@ train_infer_img_dir:./data/basicvsr_reds/test
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/edvr_m_wo_tsa.yaml --seed 123 -o log_config.interval=5
norm_train:tools/main.py -c configs/edvr_m_wo_tsa.yaml --seed 123 -o log_config.interval=5 snapshot_config.interval=25
pact_train:null
fpgm_train:null
distill_train:null
......@@ -24,6 +24,31 @@ null:null
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/edvr_m_wo_tsa.yaml --inputs_size="1,5,3,180,320" --model_name inference --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:inference
train_model:./inference/edvr/edvrmodel_generator
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type edvr -c configs/edvr_m_wo_tsa.yaml --seed 123 -o dataset.test.num_frames=5 --output_path test_tipc/output/
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
===========================train_benchmark_params==========================
batch_size:64
fp_items:fp32
......
......@@ -4,16 +4,16 @@ python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=10
total_iters:lite_train_lite_infer=100
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=100
dataset.train.batch_size:lite_train_lite_infer=2
pretrained_model:null
train_model_name:null
train_model_name:esrgan_psnr_x4_div2k*/*checkpoint.pdparams
train_infer_img_dir:null
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 -o log_config.interval=5 dataset.train.num_workers=0
norm_train:tools/main.py -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 -o log_config.interval=10 snapshot_config.interval=25
pact_train:null
fpgm_train:null
distill_train:null
......@@ -24,6 +24,31 @@ null:null
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/esrgan_psnr_x4_div2k.yaml --inputs_size="1,3,128,128" --model_name inference --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:inference
train_model:./inference/esrgan/esrganmodel_generator
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type esrgan -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 --output_path test_tipc/output/
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
===========================train_benchmark_params==========================
batch_size:32|64
fp_items:fp32
......
......@@ -256,7 +256,15 @@ def main():
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/esrgan/{}.png".format(i))
gt_numpy = tensor2img(data['gt'][0], min_max)
save_image(
image_numpy,
os.path.join(args.output_path, "esrgan/{}.png".format(i)))
metric_file = os.path.join(args.output_path, model_type,
"metric.txt")
for metric in metrics.values():
metric.update(image_numpy, gt_numpy)
break
elif model_type == "edvr":
lq = data['lq'].numpy()
input_handles[0].copy_from_cpu(lq)
......@@ -264,7 +272,14 @@ def main():
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction[0])
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, "infer_output/edvr/{}.png".format(i))
gt_numpy = tensor2img(data['gt'][0, 0], min_max)
save_image(image_numpy,
os.path.join(args.output_path, "edvr/{}.png".format(i)))
metric_file = os.path.join(args.output_path, model_type,
"metric.txt")
for metric in metrics.values():
metric.update(image_numpy, gt_numpy)
break
elif model_type == "stylegan2":
noise = paddle.randn([1, 1, 512]).cpu().numpy()
input_handles[0].copy_from_cpu(noise)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册