Unverified · Commit 38b08d08 · authored by LielinJiang, committed by GitHub

Merge pull request #40 from littletomatodonkey/master

fix some apis to paddle-dev
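In short, the change set renames fluid-era APIs to their paddle-dev (2.x) counterparts: nn.ReflectionPad2d / nn.ReplicationPad2d become nn.Pad2D with a mode argument, nn.InstanceNorm becomes nn.InstanceNorm2d, param_attr becomes weight_attr, and runtime values are checked against paddle.Tensor instead of paddle.Variable. A minimal before/after sketch, assuming the dev-branch names used in the diff below:

    import paddle
    import paddle.nn as nn

    # old: nn.ReflectionPad2d([3, 3, 3, 3])
    pad = nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect")

    # old: nn.InstanceNorm(64, param_attr=...)
    norm = nn.InstanceNorm2d(64)

    # old: isinstance(x, paddle.Variable)
    x = paddle.ones([1, 64, 8, 8])
    assert isinstance(x, paddle.Tensor)
    print(pad(x).shape)   # [1, 64, 14, 14] -- 3 px of reflect padding per side
    print(norm(x).shape)  # [1, 64, 8, 8]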
@@ -59,15 +59,17 @@ class DictDataLoader():
         place = paddle.CUDAPlace(ParallelEnv().dev_id) \
                     if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)

-        sampler = DistributedBatchSampler(self.dataset,
-                                          batch_size=batch_size,
-                                          shuffle=True if is_train else False,
-                                          drop_last=True if is_train else False)
+        sampler = DistributedBatchSampler(
+            self.dataset,
+            batch_size=batch_size,
+            shuffle=True if is_train else False,
+            drop_last=True if is_train else False)

-        self.dataloader = paddle.io.DataLoader(self.dataset,
-                                               batch_sampler=sampler,
-                                               places=place,
-                                               num_workers=num_workers)
+        self.dataloader = paddle.io.DataLoader(
+            self.dataset,
+            batch_sampler=sampler,
+            places=place,
+            num_workers=num_workers)

         self.batch_size = batch_size
@@ -92,7 +94,7 @@ class DictDataLoader():
         return len(self.dataloader)

     def get_items_by_indexs(self, key, indexs):
-        if isinstance(indexs, paddle.Variable):
+        if isinstance(indexs, paddle.Tensor):
             indexs = indexs.numpy()
         current_items = []
         items = getattr(self.dataset, key)
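For reference: in dygraph mode the indices handed in here are paddle.Tensor objects, so the old paddle.Variable test never matched on the 2.x dev branch. A tiny sketch of the updated check, assuming a 2.x-style build:

    import paddle

    indexs = paddle.to_tensor([0, 2, 5])
    if isinstance(indexs, paddle.Tensor):   # previously: paddle.Variable
        indexs = indexs.numpy()             # numpy array, usable for list indexing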
...
@@ -13,6 +13,7 @@ class ResnetGenerator(nn.Layer):
     code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
     """
+
     def __init__(self,
                  input_nc,
                  output_nc,
@@ -37,17 +38,14 @@ class ResnetGenerator(nn.Layer):
         norm_layer = build_norm_layer(norm_type)
         if type(norm_layer) == functools.partial:
-            use_bias = norm_layer.func == nn.InstanceNorm
+            use_bias = norm_layer.func == nn.InstanceNorm2d
         else:
-            use_bias = norm_layer == nn.InstanceNorm
+            use_bias = norm_layer == nn.InstanceNorm2d

         model = [
-            nn.ReflectionPad2d([3, 3, 3, 3]),
-            nn.Conv2d(input_nc,
-                      ngf,
-                      kernel_size=7,
-                      padding=0,
-                      bias_attr=use_bias),
+            nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect"),
+            nn.Conv2d(
+                input_nc, ngf, kernel_size=7, padding=0, bias_attr=use_bias),
             norm_layer(ngf),
             nn.ReLU()
         ]
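The padding swap is behavior-preserving: nn.Pad2D with mode="reflect" computes the same reflection padding as the removed nn.ReflectionPad2d, just through the unified Pad2D entry point. A quick sketch, assuming the dev-branch Pad2D signature:

    import paddle
    import paddle.nn as nn

    x = paddle.arange(9, dtype="float32").reshape([1, 1, 3, 3])
    pad = nn.Pad2D(padding=[1, 1, 1, 1], mode="reflect")
    print(pad(x).shape)  # [1, 1, 5, 5]; borders mirror the interior rows/cols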
@@ -56,12 +54,13 @@ class ResnetGenerator(nn.Layer):
         for i in range(n_downsampling):  # add downsampling layers
             mult = 2**i
             model += [
-                nn.Conv2d(ngf * mult,
-                          ngf * mult * 2,
-                          kernel_size=3,
-                          stride=2,
-                          padding=1,
-                          bias_attr=use_bias),
+                nn.Conv2d(
+                    ngf * mult,
+                    ngf * mult * 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias_attr=use_bias),
                 norm_layer(ngf * mult * 2),
                 nn.ReLU()
             ]
@@ -70,27 +69,29 @@ class ResnetGenerator(nn.Layer):
         for i in range(n_blocks):  # add ResNet blocks
             model += [
-                ResnetBlock(ngf * mult,
-                            padding_type=padding_type,
-                            norm_layer=norm_layer,
-                            use_dropout=use_dropout,
-                            use_bias=use_bias)
+                ResnetBlock(
+                    ngf * mult,
+                    padding_type=padding_type,
+                    norm_layer=norm_layer,
+                    use_dropout=use_dropout,
+                    use_bias=use_bias)
             ]

         for i in range(n_downsampling):  # add upsampling layers
             mult = 2**(n_downsampling - i)
             model += [
-                nn.ConvTranspose2d(ngf * mult,
-                                   int(ngf * mult / 2),
-                                   kernel_size=3,
-                                   stride=2,
-                                   padding=1,
-                                   output_padding=1,
-                                   bias_attr=use_bias),
+                nn.ConvTranspose2d(
+                    ngf * mult,
+                    int(ngf * mult / 2),
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    output_padding=1,
+                    bias_attr=use_bias),
                 norm_layer(int(ngf * mult / 2)),
                 nn.ReLU()
             ]
-        model += [nn.ReflectionPad2d([3, 3, 3, 3])]
+        model += [nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect")]
         model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
         model += [nn.Tanh()]
@@ -103,6 +104,7 @@ class ResnetGenerator(nn.Layer):
 class ResnetBlock(nn.Layer):
     """Define a Resnet block"""
+
     def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
         """Initialize the Resnet block
@@ -130,15 +132,13 @@ class ResnetBlock(nn.Layer):
         """
         conv_block = []
         p = 0
-        if padding_type == 'reflect':
-            conv_block += [nn.ReflectionPad2d([1, 1, 1, 1])]
-        elif padding_type == 'replicate':
-            conv_block += [nn.ReplicationPad2d([1, 1, 1, 1])]
+        if padding_type in ['reflect', 'replicate']:
+            conv_block += [nn.Pad2D(padding=[1, 1, 1, 1], mode=padding_type)]
         elif padding_type == 'zero':
             p = 1
         else:
-            raise NotImplementedError('padding [%s] is not implemented' %
-                                      padding_type)
+            raise NotImplementedError(
+                'padding [%s] is not implemented' % padding_type)

         conv_block += [
             nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias),
@@ -149,15 +149,13 @@ class ResnetBlock(nn.Layer):
             conv_block += [nn.Dropout(0.5)]

         p = 0
-        if padding_type == 'reflect':
-            conv_block += [nn.ReflectionPad2d([1, 1, 1, 1])]
-        elif padding_type == 'replicate':
-            conv_block += [nn.ReplicationPad2d([1, 1, 1, 1])]
+        if padding_type in ['reflect', 'replicate']:
+            conv_block += [nn.Pad2D(padding=[1, 1, 1, 1], mode=padding_type)]
         elif padding_type == 'zero':
             p = 1
         else:
-            raise NotImplementedError('padding [%s] is not implemented' %
-                                      padding_type)
+            raise NotImplementedError(
+                'padding [%s] is not implemented' % padding_type)
         conv_block += [
             nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias),
             norm_layer(dim)
...
@@ -10,6 +10,7 @@ class GANLoss(nn.Layer):
     The GANLoss class abstracts away the need to create the target label tensor
     that has the same size as the input.
     """
+
     def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
         """ Initialize the GANLoss class.
...
@@ -256,10 +256,8 @@ def kaiming_init(layer,
                  distribution='normal'):
     assert distribution in ['uniform', 'normal']
     if distribution == 'uniform':
-        kaiming_uniform_(layer.weight,
-                         a=a,
-                         mode=mode,
-                         nonlinearity=nonlinearity)
+        kaiming_uniform_(
+            layer.weight, a=a, mode=mode, nonlinearity=nonlinearity)
     else:
         kaiming_normal_(layer.weight, a=a, mode=mode, nonlinearity=nonlinearity)
     if hasattr(layer, 'bias') and layer.bias is not None:
@@ -275,6 +273,7 @@ def init_weights(net, init_type='normal', init_gain=0.02):
     We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
     work better for some applications. Feel free to try yourself.
     """
+
     def init_func(m):  # define the initialization function
         classname = m.__class__.__name__
         if hasattr(m, 'weight') and (classname.find('Conv') != -1
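Per the docstring, switching initializers is a one-argument change. A usage sketch, assuming net is an nn.Layer and that 'xavier'/'kaiming' are among the init_type branches of init_func (the truncated diff does not show that dispatch in full):

    init_weights(net, init_type='normal', init_gain=0.02)  # pix2pix/CycleGAN default
    init_weights(net, init_type='kaiming')                  # assumed accepted value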
...
@@ -21,21 +21,22 @@ def build_norm_layer(norm_type='instance'):
     if norm_type == 'batch':
         norm_layer = functools.partial(
             nn.BatchNorm,
-            param_attr=paddle.ParamAttr(
+            weight_attr=paddle.ParamAttr(
                 initializer=nn.initializer.Normal(1.0, 0.02)),
             bias_attr=paddle.ParamAttr(
                 initializer=nn.initializer.Constant(0.0)),
             trainable_statistics=True)
     elif norm_type == 'instance':
         norm_layer = functools.partial(
-            nn.InstanceNorm,
-            param_attr=paddle.ParamAttr(
+            nn.InstanceNorm2d,
+            weight_attr=paddle.ParamAttr(
                 initializer=nn.initializer.Constant(1.0),
                 learning_rate=0.0,
                 trainable=False),
-            bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(0.0),
-                                       learning_rate=0.0,
-                                       trainable=False))
+            bias_attr=paddle.ParamAttr(
+                initializer=nn.initializer.Constant(0.0),
+                learning_rate=0.0,
+                trainable=False))
     elif norm_type == 'spectral':
         norm_layer = functools.partial(Spectralnorm)
     elif norm_type == 'none':
@@ -43,6 +44,6 @@ def build_norm_layer(norm_type='instance'):
         def norm_layer(x):
             return Identity()
     else:
-        raise NotImplementedError('normalization layer [%s] is not found' %
-                                  norm_type)
+        raise NotImplementedError(
+            'normalization layer [%s] is not found' % norm_type)
     return norm_layer