Unverified commit b2f18921, authored by Birdylx, committed by GitHub

fix basic chain (#697)

Parent: 7f0b34dd
@@ -25,6 +25,9 @@ model:
   pixel_criterion:
     name: CharbonnierLoss
 
+export_model:
+  - {name: 'generator', inputs_num: 1}
+
 dataset:
   train:
     name: RepeatDataset
...
@@ -15,6 +15,9 @@ model:
   pixel_criterion:
     name: L1Loss
 
+export_model:
+  - {name: 'generator', inputs_num: 1}
+
 dataset:
   train:
     name: SRDataset
...
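Both configs gain the same `export_model` block: it names the sub-network to serialize at export time ('generator') and how many tensor inputs its forward expects. As a rough illustration of how such an entry can be consumed (a minimal sketch only; the helper name, the `nets` dict, and the input shape are assumptions, not PaddleGAN's actual export code):

import paddle
from paddle.static import InputSpec

def export_from_config(nets, export_model_cfg, output_dir,
                       input_shape=(1, 3, 128, 128)):
    """Serialize each sub-network named in an export_model config entry."""
    for item in export_model_cfg:
        net = nets[item['name']]                      # e.g. the 'generator' layer
        specs = [InputSpec(shape=list(input_shape), dtype='float32')
                 for _ in range(item['inputs_num'])]
        paddle.jit.save(net, f"{output_dir}/{item['name']}", input_spec=specs)

# Usage with the entry added above (hypothetical):
# export_from_config(model.nets, [{'name': 'generator', 'inputs_num': 1}], './inference')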
@@ -62,6 +62,7 @@ class ResidualBlockNoBN(nn.Layer):
         nf (int): Channel number of intermediate features.
             Default: 64.
     """
+
     def __init__(self, nf=64):
         super(ResidualBlockNoBN, self).__init__()
         self.nf = nf
@@ -100,6 +101,7 @@ class PredeblurResNetPyramid(nn.Layer):
         nf (int): Channel number of intermediate features. Default: 64.
         HR_in (bool): Whether the input has high resolution. Default: False.
     """
+
     def __init__(self, in_nf=3, nf=64, HR_in=False):
         super(PredeblurResNetPyramid, self).__init__()
         self.in_nf = in_nf
@@ -189,6 +191,7 @@ class TSAFusion(nn.Layer):
         nframes (int): Number of frames. Default: 5.
         center (int): The index of center frame. Default: 2.
     """
+
     def __init__(self, nf=64, nframes=5, center=2):
         super(TSAFusion, self).__init__()
         self.nf = nf
@@ -347,6 +350,7 @@ class DCNPack(nn.Layer):
     Ref:
         Delving Deep into Deformable Alignment in Video Super-Resolution.
     """
+
     def __init__(self,
                  num_filters=64,
                  kernel_size=3,
@@ -408,6 +412,7 @@ class PCDAlign(nn.Layer):
         nf (int): Channel number of middle features. Default: 64.
         groups (int): Deformable groups. Defaults: 8.
     """
+
     def __init__(self, nf=64, groups=8):
         super(PCDAlign, self).__init__()
         self.nf = nf
@@ -594,6 +599,7 @@ class EDVRNet(nn.Layer):
         with_tsa (bool): Whether has TSA module. Default: True.
         TSA_only (bool): Whether only use TSA module. Default: False.
     """
+
     def __init__(self,
                  in_nf=3,
                  out_nf=3,
@@ -750,13 +756,13 @@ class EDVRNet(nn.Layer):
             L1_fea[:, self.center, :, :, :], L2_fea[:, self.center, :, :, :],
             L3_fea[:, self.center, :, :, :]
         ]
-        aligned_fea = []
-        for i in range(N):
-            nbr_fea_l = [
-                L1_fea[:, i, :, :, :], L2_fea[:, i, :, :, :],
-                L3_fea[:, i, :, :, :]
-            ]
-            aligned_fea.append(self.PCDModule(nbr_fea_l, ref_fea_l))
+        aligned_fea = [
+            self.PCDModule([
+                L1_fea[:, i, :, :, :], L2_fea[:, i, :, :, :],
+                L3_fea[:, i, :, :, :]
+            ], ref_fea_l) for i in range(N)
+        ]

         # TSA Fusion
         aligned_fea = paddle.stack(aligned_fea, axis=1)  # [B, N, C, H, W]
...
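In `EDVRNet.forward`, the append-in-a-loop construction of `aligned_fea` is replaced by an equivalent list comprehension feeding `paddle.stack`; the computed features are unchanged. A small self-contained sketch of the pattern (placeholder tensors and shapes, not the EDVR features themselves):

import paddle

# Placeholder per-frame features standing in for the aligned features.
feats = [paddle.rand([2, 64, 8, 8]) for _ in range(5)]

# Loop-and-append form ...
stacked_loop = []
for i in range(len(feats)):
    stacked_loop.append(feats[i] * 2.0)
stacked_loop = paddle.stack(stacked_loop, axis=1)   # [B, N, C, H, W]

# ... and the equivalent comprehension form used after this change.
stacked_comp = paddle.stack([feats[i] * 2.0 for i in range(len(feats))], axis=1)

assert stacked_loop.shape == stacked_comp.shape == [2, 5, 64, 8, 8]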
@@ -29,6 +29,7 @@ from .generators.builder import build_generator
 from .criterions.builder import build_criterion
 from .discriminators.builder import build_discriminator
 from ..solver import build_lr_scheduler, build_optimizer
+
 warnings.filterwarnings('ignore', category=DeprecationWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -38,22 +39,28 @@ def pad_shape(shape, pad_size):
     shape[-1] += 2 * pad_size
     return shape


 def quant(x, num):
     n, c, h, w = x.shape
-    kmeans = KMeans(num, random_state=0).fit(x.transpose([0, 2, 3, 1]).reshape([-1, c]))
+    kmeans = KMeans(num, random_state=0).fit(
+        x.transpose([0, 2, 3, 1]).reshape([-1, c]))
     centers = kmeans.cluster_centers_
     x = centers[kmeans.labels_].reshape([n, h, w, c]).transpose([0, 3, 1, 2])
     return paddle.to_tensor(x, 'float32'), centers


 def quant_to_centers(x, centers):
     n, c, h, w = x.shape
     num = centers.shape[0]
-    kmeans = KMeans(num, init=centers, n_init=1).fit(x.transpose([0, 2, 3, 1]).reshape([-1, c]))
+    kmeans = KMeans(num, init=centers,
+                    n_init=1).fit(x.transpose([0, 2, 3, 1]).reshape([-1, c]))
     x = centers[kmeans.labels_].reshape([n, h, w, c]).transpose([0, 3, 1, 2])
     return paddle.to_tensor(x, 'float32')


 @MODELS.register()
 class SinGANModel(BaseModel):
     def __init__(self,
                  generator,
                  discriminator,
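`quant` clusters all pixels of an NCHW tensor into `num` colors with scikit-learn's KMeans and snaps each pixel to its cluster center, while `quant_to_centers` re-applies an existing palette to a new image. The same idea in a standalone NumPy/scikit-learn sketch (HWC layout and names are illustrative):

import numpy as np
from sklearn.cluster import KMeans

def quantize_colors(img, num_colors):
    # img: (H, W, C) float array; returns the quantized image and its palette.
    h, w, c = img.shape
    pixels = img.reshape(-1, c)
    km = KMeans(n_clusters=num_colors, random_state=0).fit(pixels)
    palette = km.cluster_centers_                      # (num_colors, C)
    return palette[km.labels_].reshape(h, w, c), palette

img = np.random.rand(32, 32, 3)
quantized, palette = quantize_colors(img, num_colors=8)
assert len(np.unique(quantized.reshape(-1, 3), axis=0)) <= 8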
@@ -70,53 +77,58 @@ class SinGANModel(BaseModel):
                  disc_iters=3,
                  noise_amp_init=0.1):
         super(SinGANModel, self).__init__()

         # setup config
         self.gen_iters = gen_iters
         self.disc_iters = disc_iters
         self.min_size = min_size
         self.is_finetune = is_finetune
         self.noise_amp_init = noise_amp_init
-        self.train_image = T.Compose([
-            T.Transpose(),
-            T.Normalize(127.5, 127.5)
-        ])(cv2.cvtColor(cv2.imread(train_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB))
+        self.train_image = T.Compose([T.Transpose(),
+                                      T.Normalize(127.5, 127.5)])(cv2.cvtColor(
+                                          cv2.imread(train_image,
+                                                     cv2.IMREAD_COLOR),
+                                          cv2.COLOR_BGR2RGB))
         self.train_image = paddle.to_tensor(self.train_image).unsqueeze(0)
-        self.scale_num = math.ceil(math.log(
-            self.min_size / min(self.train_image.shape[-2:]),
-            scale_factor)) + 1
+        self.scale_num = math.ceil(
+            math.log(self.min_size / min(self.train_image.shape[-2:]),
+                     scale_factor)) + 1
         self.scale_factor = math.pow(
             self.min_size / min(self.train_image.shape[-2:]),
             1 / (self.scale_num - 1))
         self.reals = [
-            F.interpolate(self.train_image, None, self.scale_factor ** i, 'bicubic')
-            for i in range(self.scale_num - 1, -1, -1)]
+            F.interpolate(self.train_image, None, self.scale_factor**i,
+                          'bicubic') for i in range(self.scale_num - 1, -1, -1)
+        ]

         # build generator
         generator['scale_num'] = self.scale_num
-        generator['coarsest_shape'] =self.reals[0].shape
+        generator['coarsest_shape'] = self.reals[0].shape
         self.nets['netG'] = build_generator(generator)
         self.niose_pad_size = 0 if generator.get('noise_zero_pad', True) \
             else self.nets['netG']._pad_size
-        self.nets['netG'].scale_factor = paddle.to_tensor(self.scale_factor, 'float32')
+        self.nets['netG'].scale_factor = paddle.to_tensor(
+            self.scale_factor, 'float32')

         # build discriminator
         nfc_init = discriminator.pop('nfc_init', 32)
         min_nfc_init = discriminator.pop('min_nfc_init', 32)
         for i in range(self.scale_num):
-            discriminator['nfc'] = min(nfc_init * pow(2, math.floor(i / 4)), 128)
-            discriminator['min_nfc'] = min(min_nfc_init * pow(2, math.floor(i / 4)), 128)
+            discriminator['nfc'] = min(nfc_init * pow(2, math.floor(i / 4)),
+                                       128)
+            discriminator['min_nfc'] = min(
+                min_nfc_init * pow(2, math.floor(i / 4)), 128)
             self.nets[f'netD{i}'] = build_discriminator(discriminator)

         # build criterion
         self.gan_criterion = build_criterion(gan_criterion)
         self.recon_criterion = build_criterion(recon_criterion)
         self.gp_criterion = build_criterion(gp_criterion)

         if self.is_finetune:
             self.finetune_scale = finetune_scale
-            self.quant_real, self.quant_centers = quant(self.reals[finetune_scale], color_num)
+            self.quant_real, self.quant_centers = quant(
+                self.reals[finetune_scale], color_num)

         # setup training config
         self.lr_schedulers = OrderedDict()
@@ -138,9 +150,11 @@ class SinGANModel(BaseModel):
     def setup_optimizers(self, lr_schedulers, cfg):
         for i in range(self.scale_num):
             self.optimizers[f'optim_netG{i}'] = build_optimizer(
-                cfg['optimizer_G'], lr_schedulers[f"lr{i}"], self.nets[f'netG'].generators[i].parameters())
+                cfg['optimizer_G'], lr_schedulers[f"lr{i}"],
+                self.nets[f'netG'].generators[i].parameters())
             self.optimizers[f'optim_netD{i}'] = build_optimizer(
-                cfg['optimizer_D'], lr_schedulers[f"lr{i}"], self.nets[f'netD{i}'].parameters())
+                cfg['optimizer_D'], lr_schedulers[f"lr{i}"],
+                self.nets[f'netD{i}'].parameters())
         return self.optimizers

     def setup_input(self, input):
@@ -149,17 +163,18 @@ class SinGANModel(BaseModel):
     def backward_D(self):
         self.loss_D_real = self.gan_criterion(self.pred_real, True, True)
         self.loss_D_fake = self.gan_criterion(self.pred_fake, False, True)
-        self.loss_D_gp = self.gp_criterion(self.nets[f'netD{self.current_scale}'],
-                                           self.real_img,
-                                           self.fake_img)
+        self.loss_D_gp = self.gp_criterion(
+            self.nets[f'netD{self.current_scale}'], self.real_img,
+            self.fake_img)
         self.loss_D = self.loss_D_real + self.loss_D_fake + self.loss_D_gp
         self.loss_D.backward()

         self.losses[f'scale{self.current_scale}/D_total_loss'] = self.loss_D
         self.losses[f'scale{self.current_scale}/D_real_loss'] = self.loss_D_real
         self.losses[f'scale{self.current_scale}/D_fake_loss'] = self.loss_D_fake
-        self.losses[f'scale{self.current_scale}/D_gradient_penalty'] = self.loss_D_gp
+        self.losses[
+            f'scale{self.current_scale}/D_gradient_penalty'] = self.loss_D_gp

     def backward_G(self):
         self.loss_G_gan = self.gan_criterion(self.pred_fake, True, False)
         self.loss_G_recon = self.recon_criterion(self.recon_img, self.real_img)
@@ -167,7 +182,8 @@ class SinGANModel(BaseModel):
         self.loss_G.backward()

         self.losses[f'scale{self.current_scale}/G_adv_loss'] = self.loss_G_gan
-        self.losses[f'scale{self.current_scale}/G_recon_loss'] = self.loss_G_recon
+        self.losses[
+            f'scale{self.current_scale}/G_recon_loss'] = self.loss_G_recon

     def scale_prepare(self):
         self.real_img = self.reals[self.current_scale]
@@ -180,7 +196,7 @@ class SinGANModel(BaseModel):
         self.visual_items[f'real_img_scale{self.current_scale}'] = self.real_img
         if self.is_finetune:
             self.visual_items['quant_real'] = self.quant_real

         self.recon_prev = paddle.zeros_like(self.reals[0])
         if self.current_scale > 0:
             z_pyramid = []
@@ -189,71 +205,77 @@ class SinGANModel(BaseModel):
                     z = self.nets['netG'].z_fixed
                 else:
                     z = paddle.zeros(
-                        pad_shape(
-                            self.reals[i].shape, self.niose_pad_size))
+                        pad_shape(self.reals[i].shape, self.niose_pad_size))
                 z_pyramid.append(z)
-            self.recon_prev = self.nets['netG'](
-                z_pyramid, self.recon_prev,
-                self.current_scale - 1, 0).detach()
-            self.recon_prev = F.interpolate(
-                self.recon_prev, self.real_img.shape[-2:], None, 'bicubic')
+            self.recon_prev = self.nets['netG'](z_pyramid, self.recon_prev,
+                                                self.current_scale - 1,
+                                                0).detach()
+            self.recon_prev = F.interpolate(self.recon_prev,
+                                            self.real_img.shape[-2:], None,
+                                            'bicubic')
             if self.is_finetune:
-                self.recon_prev = quant_to_centers(self.recon_prev, self.quant_centers)
+                self.recon_prev = quant_to_centers(self.recon_prev,
+                                                   self.quant_centers)
             self.nets['netG'].sigma[self.current_scale] = F.mse_loss(
-                self.real_img, self.recon_prev
-            ).sqrt() * self.noise_amp_init
+                self.real_img, self.recon_prev).sqrt() * self.noise_amp_init

         for i in range(self.scale_num):
-            self.set_requires_grad(self.nets['netG'].generators[i], i == self.current_scale)
+            self.set_requires_grad(self.nets['netG'].generators[i],
+                                   i == self.current_scale)

     def forward(self):
         if not self.is_finetune:
-            self.fake_img = self.nets['netG'](
-                self.z_pyramid,
-                paddle.zeros(
-                    pad_shape(self.z_pyramid[0].shape, -self.niose_pad_size)),
-                self.current_scale, 0)
+            self.fake_img = self.nets['netG'](self.z_pyramid,
+                                              paddle.zeros(
+                                                  pad_shape(
+                                                      self.z_pyramid[0].shape,
+                                                      -self.niose_pad_size)),
+                                              self.current_scale, 0)
         else:
-            x_prev = self.nets['netG'](
-                self.z_pyramid[:self.finetune_scale],
-                paddle.zeros(
-                    pad_shape(self.z_pyramid[0].shape, -self.niose_pad_size)),
-                self.finetune_scale - 1, 0)
-            x_prev = F.interpolate(x_prev, self.z_pyramid[self.finetune_scale].shape[-2:], None, 'bicubic')
+            x_prev = self.nets['netG'](self.z_pyramid[:self.finetune_scale],
+                                       paddle.zeros(
+                                           pad_shape(self.z_pyramid[0].shape,
+                                                     -self.niose_pad_size)),
+                                       self.finetune_scale - 1, 0)
+            x_prev = F.interpolate(
+                x_prev, self.z_pyramid[self.finetune_scale].shape[-2:], None,
+                'bicubic')
             x_prev_quant = quant_to_centers(x_prev, self.quant_centers)
             self.fake_img = self.nets['netG'](
-                self.z_pyramid[self.finetune_scale:],
-                x_prev_quant,
+                self.z_pyramid[self.finetune_scale:], x_prev_quant,
                 self.current_scale, self.finetune_scale)

         self.recon_img = self.nets['netG'](
             [(paddle.randn if self.current_scale == 0 else paddle.zeros)(
                 pad_shape(self.real_img.shape, self.niose_pad_size))],
-            self.recon_prev,
-            self.current_scale,
-            self.current_scale)
+            self.recon_prev, self.current_scale, self.current_scale)

         self.pred_real = self.nets[f'netD{self.current_scale}'](self.real_img)
         self.pred_fake = self.nets[f'netD{self.current_scale}'](
             self.fake_img.detach() if self.update_D else self.fake_img)

         self.visual_items[f'fake_img_scale{self.current_scale}'] = self.fake_img
-        self.visual_items[f'recon_img_scale{self.current_scale}'] = self.recon_img
+        self.visual_items[
+            f'recon_img_scale{self.current_scale}'] = self.recon_img
         if self.is_finetune:
             self.visual_items[f'prev_img_scale{self.current_scale}'] = x_prev
-            self.visual_items[f'quant_prev_img_scale{self.current_scale}'] = x_prev_quant
+            self.visual_items[
+                f'quant_prev_img_scale{self.current_scale}'] = x_prev_quant

     def train_iter(self, optimizers=None):
         if self.current_iter % self.scale_iters == 0:
             self.current_scale += 1
             self.scale_prepare()

-        self.z_pyramid = [paddle.randn(
-            pad_shape(self.reals[i].shape, self.niose_pad_size))
-            for i in range(self.current_scale + 1)]
+        self.z_pyramid = [
+            paddle.randn(pad_shape(self.reals[i].shape, self.niose_pad_size))
+            for i in range(self.current_scale + 1)
+        ]

-        self.update_D = (self.current_iter % (self.disc_iters + self.gen_iters) < self.disc_iters)
-        self.set_requires_grad(self.nets[f'netD{self.current_scale}'], self.update_D)
+        self.update_D = (self.current_iter %
+                         (self.disc_iters + self.gen_iters) < self.disc_iters)
+        self.set_requires_grad(self.nets[f'netD{self.current_scale}'],
+                               self.update_D)
         self.forward()

         if self.update_D:
             optimizers[f'optim_netD{self.current_scale}'].clear_grad()
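`update_D` implements the alternating schedule visible above: within every cycle of `disc_iters + gen_iters` iterations, the first `disc_iters` steps update the discriminator of the current scale and the remaining steps update the generator. A quick check of the schedule with, say, disc_iters=3 and gen_iters=3:

disc_iters, gen_iters = 3, 3
schedule = [
    'D' if it % (disc_iters + gen_iters) < disc_iters else 'G'
    for it in range(12)
]
print(schedule)  # ['D', 'D', 'D', 'G', 'G', 'G', 'D', 'D', 'D', 'G', 'G', 'G']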
@@ -267,22 +289,25 @@ class SinGANModel(BaseModel):
         self.current_iter += 1

     def test_iter(self, metrics=None):
-        z_pyramid = [paddle.randn(
-            pad_shape(self.reals[i].shape, self.niose_pad_size))
-            for i in range(self.scale_num)]
+        z_pyramid = [
+            paddle.randn(pad_shape(self.reals[i].shape, self.niose_pad_size))
+            for i in range(self.scale_num)
+        ]
         self.nets['netG'].eval()
-        fake_img = self.nets['netG'](
-            z_pyramid,
-            paddle.zeros(pad_shape(z_pyramid[0].shape, -self.niose_pad_size)),
-            self.scale_num - 1, 0)
+        fake_img = self.nets['netG'](z_pyramid,
+                                     paddle.zeros(
+                                         pad_shape(z_pyramid[0].shape,
+                                                   -self.niose_pad_size)),
+                                     self.scale_num - 1, 0)
         self.visual_items['fake_img_test'] = fake_img

         with paddle.no_grad():
             if metrics is not None:
                 for metric in metrics.values():
                     metric.update(fake_img, self.train_image)
         self.nets['netG'].train()

     class InferGenerator(paddle.nn.Layer):
         def set_config(self, generator, noise_shapes, scale_num):
             self.generator = generator
             self.noise_shapes = noise_shapes
@@ -299,10 +324,14 @@ class SinGANModel(BaseModel):
                export_model=None,
                output_dir=None,
                inputs_size=None,
-               export_serving_model=False):
-        noise_shapes = [pad_shape(x.shape, self.niose_pad_size) for x in self.reals]
+               export_serving_model=False,
+               model_name=None):
+        noise_shapes = [
+            pad_shape(x.shape, self.niose_pad_size) for x in self.reals
+        ]
         infer_generator = self.InferGenerator()
-        infer_generator.set_config(self.nets['netG'], noise_shapes, self.scale_num)
+        infer_generator.set_config(self.nets['netG'], noise_shapes,
+                                   self.scale_num)
         paddle.jit.save(infer_generator,
                         os.path.join(output_dir, "singan_random_sample"),
                         input_spec=[1])
...
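`export` now also accepts a `model_name` argument, presumably so the shared export entry point can pass it uniformly across models; the multi-scale generator itself is still wrapped in `InferGenerator` and serialized with `paddle.jit.save`. A minimal, self-contained sketch of that save pattern (the tiny layer, output path, and shapes are illustrative, not SinGAN's generator):

import paddle
from paddle.static import InputSpec

class TinyGenerator(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

net = TinyGenerator()
# paddle.jit.save converts the dynamic-graph layer to a static program and
# writes the .pdmodel/.pdiparams files consumed by the inference script.
paddle.jit.save(net,
                "output/tiny_generator",
                input_spec=[InputSpec(shape=[1, 3, 64, 64], dtype='float32')])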
@@ -13,7 +13,7 @@ train_infer_img_dir:./data/basicvsr_reds/test
 null:null
 ##
 trainer:norm_train
-norm_train:tools/main.py -c configs/edvr_m_wo_tsa.yaml --seed 123 -o log_config.interval=5
+norm_train:tools/main.py -c configs/edvr_m_wo_tsa.yaml --seed 123 -o log_config.interval=5 snapshot_config.interval=25
 pact_train:null
 fpgm_train:null
 distill_train:null
@@ -24,6 +24,31 @@ null:null
 eval:null
 null:null
 ##
+===========================infer_params===========================
+--output_dir:./output/
+load:null
+norm_export:tools/export_model.py -c configs/edvr_m_wo_tsa.yaml --inputs_size="1,5,3,180,320" --model_name inference --load
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+inference_dir:inference
+train_model:./inference/edvr/edvrmodel_generator
+infer_export:null
+infer_quant:False
+inference:tools/inference.py --model_type edvr -c configs/edvr_m_wo_tsa.yaml --seed 123 -o dataset.test.num_frames=5 --output_path test_tipc/output/
+--device:gpu
+null:null
+null:null
+null:null
+null:null
+null:null
+--model_path:
+null:null
+null:null
+--benchmark:True
+null:null
 ===========================train_benchmark_params==========================
 batch_size:64
 fp_items:fp32
...
@@ -4,16 +4,16 @@ python:python3.7
 gpu_list:0
 ##
 auto_cast:null
-total_iters:lite_train_lite_infer=10
+total_iters:lite_train_lite_infer=100
 output_dir:./output/
-dataset.train.batch_size:lite_train_lite_infer=100
+dataset.train.batch_size:lite_train_lite_infer=2
 pretrained_model:null
-train_model_name:null
+train_model_name:esrgan_psnr_x4_div2k*/*checkpoint.pdparams
 train_infer_img_dir:null
 null:null
 ##
 trainer:norm_train
-norm_train:tools/main.py -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 -o log_config.interval=5 dataset.train.num_workers=0
+norm_train:tools/main.py -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 -o log_config.interval=10 snapshot_config.interval=25
 pact_train:null
 fpgm_train:null
 distill_train:null
@@ -24,6 +24,31 @@ null:null
 eval:null
 null:null
 ##
+===========================infer_params===========================
+--output_dir:./output/
+load:null
+norm_export:tools/export_model.py -c configs/esrgan_psnr_x4_div2k.yaml --inputs_size="1,3,128,128" --model_name inference --load
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+inference_dir:inference
+train_model:./inference/esrgan/esrganmodel_generator
+infer_export:null
+infer_quant:False
+inference:tools/inference.py --model_type esrgan -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 --output_path test_tipc/output/
+--device:gpu
+null:null
+null:null
+null:null
+null:null
+null:null
+--model_path:
+null:null
+null:null
+--benchmark:True
+null:null
 ===========================train_benchmark_params==========================
 batch_size:32|64
 fp_items:fp32
...
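Both TIPC files use the same plain `key:value` layout: the text before the first colon is the parameter the test driver looks up, and the text after it is a value, a command template, or `null` for an unused slot. A tiny sketch of that split (illustrative only; the real parsing is done by the shell scripts under test_tipc):

line = ('inference:tools/inference.py --model_type esrgan '
        '-c configs/esrgan_psnr_x4_div2k.yaml --seed 123 '
        '--output_path test_tipc/output/')
key, _, value = line.partition(':')
print(key)    # inference
print(value)  # the command template the driver will run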
@@ -256,7 +256,15 @@ def main():
             prediction = output_handle.copy_to_cpu()
             prediction = paddle.to_tensor(prediction[0])
             image_numpy = tensor2img(prediction, min_max)
-            save_image(image_numpy, "infer_output/esrgan/{}.png".format(i))
+            gt_numpy = tensor2img(data['gt'][0], min_max)
+            save_image(
+                image_numpy,
+                os.path.join(args.output_path, "esrgan/{}.png".format(i)))
+            metric_file = os.path.join(args.output_path, model_type,
+                                       "metric.txt")
+            for metric in metrics.values():
+                metric.update(image_numpy, gt_numpy)
+            break
         elif model_type == "edvr":
             lq = data['lq'].numpy()
             input_handles[0].copy_from_cpu(lq)
@@ -264,7 +272,14 @@ def main():
             prediction = output_handle.copy_to_cpu()
             prediction = paddle.to_tensor(prediction[0])
             image_numpy = tensor2img(prediction, min_max)
-            save_image(image_numpy, "infer_output/edvr/{}.png".format(i))
+            gt_numpy = tensor2img(data['gt'][0, 0], min_max)
+            save_image(image_numpy,
+                       os.path.join(args.output_path, "edvr/{}.png".format(i)))
+            metric_file = os.path.join(args.output_path, model_type,
+                                       "metric.txt")
+            for metric in metrics.values():
+                metric.update(image_numpy, gt_numpy)
+            break
         elif model_type == "stylegan2":
             noise = paddle.randn([1, 1, 512]).cpu().numpy()
             input_handles[0].copy_from_cpu(noise)
...
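For both the esrgan and edvr branches, the inference script now also converts the ground truth to an image, saves the prediction under `--output_path`, and updates the configured metrics (typically PSNR/SSIM) on the prediction/ground-truth pair; `metric_file` points at a metric.txt under the output path, presumably written further down in the truncated hunk. As a reference, the standard PSNR computation on two 8-bit images looks like this (placeholder arrays):

import numpy as np

def psnr_uint8(pred, gt):
    # Peak signal-to-noise ratio for 8-bit images: 20 * log10(255 / RMSE).
    mse = np.mean((pred.astype(np.float64) - gt.astype(np.float64))**2)
    return float('inf') if mse == 0 else 20.0 * np.log10(255.0 / np.sqrt(mse))

pred = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
gt = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
print(f"PSNR: {psnr_uint8(pred, gt):.2f} dB")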