Unverified commit 38b08d08, authored by LielinJiang, committed by GitHub

Merge pull request #40 from littletomatodonkey/master

fix some apis to paddle-dev
......@@ -59,15 +59,17 @@ class DictDataLoader():
         place = paddle.CUDAPlace(ParallelEnv().dev_id) \
                     if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)
-        sampler = DistributedBatchSampler(self.dataset,
-                                          batch_size=batch_size,
-                                          shuffle=True if is_train else False,
-                                          drop_last=True if is_train else False)
-
-        self.dataloader = paddle.io.DataLoader(self.dataset,
-                                               batch_sampler=sampler,
-                                               places=place,
-                                               num_workers=num_workers)
+        sampler = DistributedBatchSampler(
+            self.dataset,
+            batch_size=batch_size,
+            shuffle=True if is_train else False,
+            drop_last=True if is_train else False)
+
+        self.dataloader = paddle.io.DataLoader(
+            self.dataset,
+            batch_sampler=sampler,
+            places=place,
+            num_workers=num_workers)
 
         self.batch_size = batch_size
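Note on the reflowed sampler code: in the paddle 2.x data pipeline, DistributedBatchSampler owns batching, shuffling, and multi-card sharding, while paddle.io.DataLoader just consumes the index batches it yields. A minimal, self-contained sketch of the same wiring (the toy dataset and sizes are illustrative, not part of this PR):

    import paddle
    from paddle.io import Dataset, DistributedBatchSampler, DataLoader

    class ToyDataset(Dataset):
        # illustrative stand-in for self.dataset above
        def __getitem__(self, idx):
            return paddle.randn([3, 32, 32]), idx

        def __len__(self):
            return 16

    dataset = ToyDataset()
    sampler = DistributedBatchSampler(
        dataset, batch_size=4, shuffle=True, drop_last=True)
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0)
    for imgs, idxs in loader:
        print(imgs.shape)  # [4, 3, 32, 32]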
......@@ -92,7 +94,7 @@ class DictDataLoader():
         return len(self.dataloader)
 
     def get_items_by_indexs(self, key, indexs):
-        if isinstance(indexs, paddle.Variable):
+        if isinstance(indexs, paddle.Tensor):
             indexs = indexs.numpy()
         current_items = []
         items = getattr(self.dataset, key)
......
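With the move to the 2.x imperative API, dynamic-graph values are paddle.Tensor rather than the old paddle.Variable, which is why the isinstance check changes. A quick sketch of the new behavior:

    import paddle

    indexs = paddle.to_tensor([0, 2, 5])      # name matches the method above
    assert isinstance(indexs, paddle.Tensor)  # old check: paddle.Variable
    print(indexs.numpy())                     # array([0, 2, 5])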
......@@ -13,6 +13,7 @@ class ResnetGenerator(nn.Layer):
     code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
     """
+
     def __init__(self,
                  input_nc,
                  output_nc,
......@@ -37,17 +38,14 @@ class ResnetGenerator(nn.Layer):
         norm_layer = build_norm_layer(norm_type)
         if type(norm_layer) == functools.partial:
-            use_bias = norm_layer.func == nn.InstanceNorm
+            use_bias = norm_layer.func == nn.InstanceNorm2d
         else:
-            use_bias = norm_layer == nn.InstanceNorm
+            use_bias = norm_layer == nn.InstanceNorm2d
 
         model = [
-            nn.ReflectionPad2d([3, 3, 3, 3]),
-            nn.Conv2d(input_nc,
-                      ngf,
-                      kernel_size=7,
-                      padding=0,
-                      bias_attr=use_bias),
+            nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect"),
+            nn.Conv2d(
+                input_nc, ngf, kernel_size=7, padding=0, bias_attr=use_bias),
             norm_layer(ngf),
             nn.ReLU()
         ]
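The padding rewrite above replaces the dedicated nn.ReflectionPad2d layer with the unified nn.Pad2D, which selects the padding rule through its mode argument ("constant", "reflect", "replicate", "circular"). A small shape-only sketch of the replacement, assuming the paddle 2.x layer names used in this diff:

    import paddle
    import paddle.nn as nn

    x = paddle.randn([1, 3, 64, 64])
    pad = nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect")  # was nn.ReflectionPad2d([3, 3, 3, 3])
    print(pad(x).shape)  # [1, 3, 70, 70]: 3 mirrored pixels on each border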
......@@ -56,12 +54,13 @@ class ResnetGenerator(nn.Layer):
         for i in range(n_downsampling):  # add downsampling layers
             mult = 2**i
             model += [
-                nn.Conv2d(ngf * mult,
-                          ngf * mult * 2,
-                          kernel_size=3,
-                          stride=2,
-                          padding=1,
-                          bias_attr=use_bias),
+                nn.Conv2d(
+                    ngf * mult,
+                    ngf * mult * 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias_attr=use_bias),
                 norm_layer(ngf * mult * 2),
                 nn.ReLU()
             ]
......@@ -70,27 +69,29 @@ class ResnetGenerator(nn.Layer):
         for i in range(n_blocks):  # add ResNet blocks
             model += [
-                ResnetBlock(ngf * mult,
-                            padding_type=padding_type,
-                            norm_layer=norm_layer,
-                            use_dropout=use_dropout,
-                            use_bias=use_bias)
+                ResnetBlock(
+                    ngf * mult,
+                    padding_type=padding_type,
+                    norm_layer=norm_layer,
+                    use_dropout=use_dropout,
+                    use_bias=use_bias)
             ]
 
         for i in range(n_downsampling):  # add upsampling layers
             mult = 2**(n_downsampling - i)
             model += [
-                nn.ConvTranspose2d(ngf * mult,
-                                   int(ngf * mult / 2),
-                                   kernel_size=3,
-                                   stride=2,
-                                   padding=1,
-                                   output_padding=1,
-                                   bias_attr=use_bias),
+                nn.ConvTranspose2d(
+                    ngf * mult,
+                    int(ngf * mult / 2),
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    output_padding=1,
+                    bias_attr=use_bias),
                 norm_layer(int(ngf * mult / 2)),
                 nn.ReLU()
             ]
-        model += [nn.ReflectionPad2d([3, 3, 3, 3])]
+        model += [nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect")]
         model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
         model += [nn.Tanh()]
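In the upsampling branch, the transposed-conv hyper-parameters are chosen so each layer exactly doubles the spatial size: out = (in - 1) * stride - 2 * padding + kernel_size + output_padding = 2 * in for stride 2, padding 1, kernel 3, output_padding 1. A quick shape check using the layer name from this diff (the paddle-dev spelling; later releases call it nn.Conv2DTranspose):

    import paddle
    import paddle.nn as nn

    up = nn.ConvTranspose2d(
        256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
    x = paddle.randn([1, 256, 64, 64])
    print(up(x).shape)  # [1, 128, 128, 128]: 64 -> 128 spatially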
......@@ -103,6 +104,7 @@ class ResnetGenerator(nn.Layer):
 class ResnetBlock(nn.Layer):
     """Define a Resnet block"""
+
     def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
         """Initialize the Resnet block
......@@ -130,15 +132,13 @@ class ResnetBlock(nn.Layer):
"""
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d([1, 1, 1, 1])]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d([1, 1, 1, 1])]
if padding_type in ['reflect', 'replicate']:
conv_block += [nn.Pad2D(padding=[1, 1, 1, 1], mode=padding_type)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' %
padding_type)
raise NotImplementedError(
'padding [%s] is not implemented' % padding_type)
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias),
......@@ -149,15 +149,13 @@ class ResnetBlock(nn.Layer):
             conv_block += [nn.Dropout(0.5)]
 
         p = 0
-        if padding_type == 'reflect':
-            conv_block += [nn.ReflectionPad2d([1, 1, 1, 1])]
-        elif padding_type == 'replicate':
-            conv_block += [nn.ReplicationPad2d([1, 1, 1, 1])]
+        if padding_type in ['reflect', 'replicate']:
+            conv_block += [nn.Pad2D(padding=[1, 1, 1, 1], mode=padding_type)]
         elif padding_type == 'zero':
             p = 1
         else:
-            raise NotImplementedError('padding [%s] is not implemented' %
-                                      padding_type)
+            raise NotImplementedError(
+                'padding [%s] is not implemented' % padding_type)
         conv_block += [
             nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias),
             norm_layer(dim)
......
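The forward pass of ResnetBlock is untouched by this diff; for context, the block is the standard CycleGAN residual unit, i.e. a skip connection around conv_block (a sketch of the usual implementation, not a line from this PR):

    def forward(self, x):
        # learn a residual on top of the input instead of a full mapping
        out = x + self.conv_block(x)
        return out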
......@@ -10,6 +10,7 @@ class GANLoss(nn.Layer):
     The GANLoss class abstracts away the need to create the target label tensor
     that has the same size as the input.
     """
+
     def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
         """ Initialize the GANLoss class.
......
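As the docstring says, GANLoss mainly spares callers from building a real/fake target tensor shaped like the discriminator output. A hedged sketch of that core idea for the lsgan mode only (the real class also covers other gan_mode values):

    import paddle
    import paddle.nn as nn

    def lsgan_loss(prediction, target_is_real, real_label=1.0, fake_label=0.0):
        # broadcast the scalar label to the prediction's shape, then use MSE
        value = real_label if target_is_real else fake_label
        target = paddle.full(prediction.shape, value, dtype=prediction.dtype)
        return nn.MSELoss()(prediction, target)

    d_out = paddle.randn([4, 1, 30, 30])  # e.g. a PatchGAN score map
    print(lsgan_loss(d_out, target_is_real=True))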
......@@ -256,10 +256,8 @@ def kaiming_init(layer,
                  distribution='normal'):
     assert distribution in ['uniform', 'normal']
     if distribution == 'uniform':
-        kaiming_uniform_(layer.weight,
-                         a=a,
-                         mode=mode,
-                         nonlinearity=nonlinearity)
+        kaiming_uniform_(
+            layer.weight, a=a, mode=mode, nonlinearity=nonlinearity)
     else:
         kaiming_normal_(layer.weight, a=a, mode=mode, nonlinearity=nonlinearity)
     if hasattr(layer, 'bias') and layer.bias is not None:
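Call sites of this helper are unaffected; only the argument layout changed. For reference, a usage sketch (assuming kaiming_init is imported from this module, with the keyword values chosen for illustration):

    import paddle.nn as nn

    conv = nn.Conv2d(3, 64, kernel_size=3)
    # fills conv.weight in place; bias, if present, is handled by the hasattr check above
    kaiming_init(conv, a=0, mode='fan_out', nonlinearity='relu',
                 distribution='uniform')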
......@@ -275,6 +273,7 @@ def init_weights(net, init_type='normal', init_gain=0.02):
     We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
     work better for some applications. Feel free to try yourself.
     """
+
     def init_func(m):  # define the initialization function
         classname = m.__class__.__name__
         if hasattr(m, 'weight') and (classname.find('Conv') != -1
......
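init_weights visits every sublayer and runs init_func on each, which is why matching on the class name works for both conv and norm layers. A usage sketch, assuming a paddle 2.x nn.Layer so the traversal can be driven by net.apply:

    import paddle.nn as nn

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 3, 3))
    init_weights(net, init_type='normal', init_gain=0.02)  # re-initializes both Conv layers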
......@@ -21,21 +21,22 @@ def build_norm_layer(norm_type='instance'):
     if norm_type == 'batch':
         norm_layer = functools.partial(
             nn.BatchNorm,
-            param_attr=paddle.ParamAttr(
+            weight_attr=paddle.ParamAttr(
                 initializer=nn.initializer.Normal(1.0, 0.02)),
             bias_attr=paddle.ParamAttr(
                 initializer=nn.initializer.Constant(0.0)),
             trainable_statistics=True)
     elif norm_type == 'instance':
         norm_layer = functools.partial(
-            nn.InstanceNorm,
-            param_attr=paddle.ParamAttr(
+            nn.InstanceNorm2d,
+            weight_attr=paddle.ParamAttr(
                 initializer=nn.initializer.Constant(1.0),
                 learning_rate=0.0,
                 trainable=False),
-            bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(0.0),
-                                       learning_rate=0.0,
-                                       trainable=False))
+            bias_attr=paddle.ParamAttr(
+                initializer=nn.initializer.Constant(0.0),
+                learning_rate=0.0,
+                trainable=False))
     elif norm_type == 'spectral':
         norm_layer = functools.partial(Spectralnorm)
     elif norm_type == 'none':
......@@ -43,6 +44,6 @@ def build_norm_layer(norm_type='instance'):
         def norm_layer(x):
             return Identity()
     else:
-        raise NotImplementedError('normalization layer [%s] is not found' %
-                                  norm_type)
+        raise NotImplementedError(
+            'normalization layer [%s] is not found' % norm_type)
     return norm_layer
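Because build_norm_layer returns a functools.partial (or a small factory), call sites only supply the channel count, as in norm_layer(ngf) in the generator above. A usage sketch; note that pinning weight and bias at 1.0/0.0 through learning_rate=0.0 and trainable=False makes the instance norm effectively non-affine:

    norm_layer = build_norm_layer('instance')  # partial over nn.InstanceNorm2d
    norm = norm_layer(64)                      # instance norm for 64 channels, frozen affine params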