提交 27457a04 编写于 作者: C ceci3

update

上级 dc110e31
......@@ -13,7 +13,7 @@
# limitations under the License.
import paddle.fluid as fluid
from data_reader import data_reader
from data_reader import DataReader
def create_data(cfgs, direction='AtoB', eval_mode=False):
......@@ -21,15 +21,17 @@ def create_data(cfgs, direction='AtoB', eval_mode=False):
mode = 'TRAIN'
else:
mode = 'EVAL'
reader = data_reader(cfgs, mode=mode)
data, id2name = reader.make_data(direction)
reader = DataReader(cfgs, mode=mode)
dreader, id2name = reader.make_data(direction)
if cfgs.use_parallel:
dreader = fluid.contrib.reader.distributed_batch_reader(dreader)
#### id2name has something wrong when use_multiprocess
loader = fluid.io.DataLoader.from_generator(
capacity=4, iterable=True, use_double_buffer=True)
capacity=4, return_list=True, use_multiprocess=cfgs.use_multiprocess)
loader.set_batch_generator(
data,
places=fluid.CUDAPlace(0)
if cfgs.use_gpu else fluid.cpu_places()) ### fluid.cuda_places()
loader.set_batch_generator(dreader, places=cfgs.place)
return loader, id2name
......
......@@ -48,7 +48,7 @@ def RandomHorizonFlip(img):
return img
class reader_creator:
class ReaderCreator:
def __init__(self, *args, **kwcfgs):
raise NotImplementedError
......@@ -56,7 +56,7 @@ class reader_creator:
raise NotImplementedError
class single_datareader(reader_creator):
class SingleDatareader(ReaderCreator):
def __init__(self, list_filename, cfgs, mode='TEST'):
self.cfgs = cfgs
self.mode = mode
......@@ -114,7 +114,7 @@ class single_datareader(reader_creator):
return reader
class cycle_datareader(reader_creator):
class CycleDatareader(ReaderCreator):
def __init__(self, list_filename_A, list_filename_B, cfgs, mode='TRAIN'):
self.cfgs = cfgs
self.mode = mode
......@@ -202,7 +202,7 @@ class cycle_datareader(reader_creator):
return reader
class data_reader(object):
class DataReader(object):
def __init__(self, cfgs, mode='TRAIN'):
self.mode = mode
self.cfgs = cfgs
......
......@@ -17,7 +17,7 @@ import itertools
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.dygraph.base import to_variable
from models.super_modules import SuperConv2D
from paddleslim.core.layers import SuperConv2D
from models import loss
from models import network
from models.base_model import BaseModel
......
......@@ -18,7 +18,6 @@ from paddle.fluid.dygraph.nn import Conv2D
from .base_resnet_distiller import BaseResnetDistiller
from utils import util
from utils.weight_transfer import load_pretrained_weight
from metric import compute_fid
from models import loss
from metric import get_fid
......@@ -161,14 +160,13 @@ class ResnetDistiller(BaseResnetDistiller):
fakes = []
cnt = 0
for i, data_i in enumerate(self.eval_dataloader):
id2name = self.name
self.set_single_input(data_i)
self.test()
fakes.append(self.Sfake_B.detach().numpy())
for j in range(len(self.Sfake_B)):
if cnt < 10:
Sname = 'Sfake_' + str(id2name[i + j]) + '.png'
Tname = 'Tfake_' + str(id2name[i + j]) + '.png'
Sname = 'Sfake_' + str(i + j) + '.png'
Tname = 'Tfake_' + str(i + j) + '.png'
Sfake_im = util.tensor2img(self.Sfake_B[j])
Tfake_im = util.tensor2img(self.Tfake_B[j])
util.save_image(Sfake_im, os.path.join(save_dir, Sname))
......
......@@ -23,18 +23,43 @@ from utils.get_args import configs
class gan_compression:
def __init__(self, cfgs, **kwargs):
self.cfgs = cfgs
use_gpu, use_parallel = self._get_device()
if not use_gpu:
place = fluid.CPUPlace()
else:
if not use_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
setattr(self.cfgs, 'use_gpu', use_gpu)
setattr(self.cfgs, 'use_parallel', use_parallel)
setattr(self.cfgs, 'place', place)
for k, v in kwargs.items():
setattr(self, k, v)
def _get_device(self):
num = self.cfgs.gpu_num
use_gpu, use_parallel = False, False
if num == -1:
use_gpu = False
else:
use_gpu = True
if num > 1:
use_parallel = True
return use_gpu, use_parallel
def start_train(self):
steps = self.cfgs.task.split('+')
for step in steps:
if step == 'mobile':
from models import create_model
elif step == 'distiller':
from distiller import create_distiller as create_model
from distillers import create_distiller as create_model
elif step == 'supernet':
from supernet import create_supernet as create_model
from supernets import create_supernet as create_model
else:
raise NotImplementedError
......@@ -65,8 +90,11 @@ class gan_compression:
message += '%s: %.3f ' % (k, v)
logging.info(message)
save_model = (not self.cfgs.use_parallel) or (
self.cfgs.use_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if epoch_id % self.cfgs.save_freq == 0 or epoch_id == (
epochs - 1):
epochs - 1) and save_model:
model.evaluate_model(epoch_id)
model.save_network(epoch_id)
if epoch_id == (epochs - 1):
......
import importlib
from .modules import *
from .base_model import BaseModel
......
......@@ -280,7 +280,6 @@ class CycleGAN(BaseModel):
self.netG_B.eval()
for direction in ['AtoB', 'BtoA']:
eval_dataloader = getattr(self, 'eval_dataloader_' + direction)
id2name = getattr(self, 'name_' + direction)
fakes = []
cnt = 0
for i, data_i in enumerate(eval_dataloader):
......@@ -289,8 +288,7 @@ class CycleGAN(BaseModel):
fakes.append(self.fake_B.detach().numpy())
for j in range(len(self.fake_B)):
if cnt < 10:
name = 'fake_' + direction + str(id2name[i +
j]) + '.png'
name = 'fake_' + direction + str(i + j) + '.png'
save_path = os.path.join(save_dir, name)
fake_im = util.tensor2img(self.fake_B[j])
util.save_image(fake_im, save_path)
......
......@@ -14,7 +14,7 @@
import functools
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose, BatchNorm
from paddle.nn.layer import Leaky_ReLU, ReLU, Pad2D
from paddle.nn.layer import LeakyReLU, ReLU, Pad2D
class NLayerDiscriminator(fluid.dygraph.Layer):
......@@ -31,7 +31,7 @@ class NLayerDiscriminator(fluid.dygraph.Layer):
self.model = fluid.dygraph.LayerList([
Conv2D(
input_channel, ndf, filter_size=kw, stride=2, padding=padw),
Leaky_ReLU(0.2)
LeakyReLU(0.2)
])
nf_mult = 1
nf_mult_prev = 1
......@@ -45,19 +45,8 @@ class NLayerDiscriminator(fluid.dygraph.Layer):
filter_size=kw,
stride=2,
padding=padw,
bias_attr=use_bias),
#norm_layer(ndf * nf_mult),
InstanceNorm(
ndf * nf_mult,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.0),
learning_rate=0.0,
trainable=False),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0),
learning_rate=0.0,
trainable=False)),
Leaky_ReLU(0.2)
bias_attr=use_bias), norm_layer(ndf * nf_mult),
LeakyReLU(0.2)
])
nf_mult_prev = nf_mult
......@@ -69,19 +58,7 @@ class NLayerDiscriminator(fluid.dygraph.Layer):
filter_size=kw,
stride=1,
padding=padw,
bias_attr=use_bias),
#norm_layer(ndf * nf_mult),
InstanceNorm(
ndf * nf_mult,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.0),
learning_rate=0.0,
trainable=False),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0),
learning_rate=0.0,
trainable=False)),
Leaky_ReLU(0.2)
bias_attr=use_bias), norm_layer(ndf * nf_mult), LeakyReLU(0.2)
])
self.model.extend([
......
import functools
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose
from paddle.nn.layer import Leaky_ReLU, ReLU, Pad2D
from ..modules import MobileResnetBlock
from paddle.nn.layer import ReLU, Pad2D
from paddleslim.models.dygraph.modules import MobileResnetBlock
use_cudnn = False
......
import functools
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose
from paddle.nn.layer import Leaky_ReLU, ReLU, Pad2D
from ..modules import ResnetBlock
from paddle.nn.layer import ReLU, Pad2D
from paddleslim.models.dygraph.modules import ResnetBlock
class ResnetGenerator(fluid.dygraph.Layer):
......
......@@ -2,10 +2,8 @@ import functools
import paddle.fluid as fluid
import paddle.tensor as tensor
from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose
from paddle.nn.layer import Leaky_ReLU, ReLU, Pad2D
from .modules import SeparableConv2D, MobileResnetBlock
use_cudnn = False
from paddle.nn.layer import ReLU, Pad2D
from paddleslim.models.dygraph.modules import SeparableConv2D, MobileResnetBlock
class SubMobileResnetGenerator(fluid.dygraph.Layer):
......
......@@ -2,8 +2,8 @@ import functools
import paddle.fluid as fluid
import paddle.tensor as tensor
from paddle.fluid.dygraph.nn import BatchNorm, InstanceNorm, Dropout
from paddle.nn.layer import Leaky_ReLU, ReLU, Pad2D
from ..super_modules import SuperConv2D, SuperConv2DTranspose, SuperSeparableConv2D, SuperInstanceNorm
from paddle.nn.layer import ReLU, Pad2D
from paddleslim.core.layers import SuperConv2D, SuperConv2DTranspose, SuperSeparableConv2D, SuperInstanceNorm
class SuperMobileResnetBlock(fluid.dygraph.Layer):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, BatchNorm, InstanceNorm, Dropout
from paddle.nn.layer import Leaky_ReLU, ReLU, Pad2D
__all__ = ['SeparableConv2D', 'MobileResnetBlock', 'ResnetBlock']
use_cudnn = False
class SeparableConv2D(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
padding=0,
norm_layer=InstanceNorm,
use_bias=True,
scale_factor=1,
stddev=0.02,
use_cudnn=use_cudnn):
super(SeparableConv2D, self).__init__()
self.conv = fluid.dygraph.LayerList([
Conv2D(
num_channels=num_channels,
num_filters=num_channels * scale_factor,
filter_size=filter_size,
stride=stride,
padding=padding,
use_cudnn=False,
groups=num_channels,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=stddev)),
bias_attr=use_bias)
])
self.conv.extend([norm_layer(num_channels * scale_factor)])
self.conv.extend([
Conv2D(
num_channels=num_channels * scale_factor,
num_filters=num_filters,
filter_size=1,
stride=1,
use_cudnn=use_cudnn,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=stddev)),
bias_attr=use_bias)
])
def forward(self, inputs):
for sublayer in self.conv:
inputs = sublayer(inputs)
return inputs
class MobileResnetBlock(fluid.dygraph.Layer):
def __init__(self, in_c, out_c, padding_type, norm_layer, dropout_rate,
use_bias):
super(MobileResnetBlock, self).__init__()
self.padding_type = padding_type
self.dropout_rate = dropout_rate
self.conv_block = fluid.dygraph.LayerList([])
p = 0
if self.padding_type == 'reflect':
self.conv_block.extend(
[Pad2D(
paddings=[1, 1, 1, 1], mode='reflect')])
elif self.padding_type == 'replicate':
self.conv_block.extend(
[Pad2D(
inputs, paddings=[1, 1, 1, 1], mode='edge')])
elif self.padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' %
self.padding_type)
self.conv_block.extend([
SeparableConv2D(
num_channels=in_c,
num_filters=out_c,
filter_size=3,
padding=p,
stride=1), norm_layer(out_c), ReLU()
])
self.conv_block.extend([Dropout(p=self.dropout_rate)])
if self.padding_type == 'reflect':
self.conv_block.extend(
[Pad2D(
paddings=[1, 1, 1, 1], mode='reflect')])
elif self.padding_type == 'replicate':
self.conv_block.extend(
[Pad2D(
inputs, paddings=[1, 1, 1, 1], mode='edge')])
elif self.padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' %
self.padding_type)
self.conv_block.extend([
SeparableConv2D(
num_channels=out_c,
num_filters=in_c,
filter_size=3,
padding=p,
stride=1), norm_layer(in_c)
])
def forward(self, inputs):
y = inputs
for sublayer in self.conv_block:
y = sublayer(y)
out = inputs + y
return out
class ResnetBlock(fluid.dygraph.Layer):
def __init__(self,
dim,
padding_type,
norm_layer,
dropout_rate,
use_bias=False):
super(ResnetBlock, self).__init__()
self.conv_block = fluid.dygraph.LayerList([])
p = 0
if padding_type == 'reflect':
self.conv_block.extend(
[Pad2D(
paddings=[1, 1, 1, 1], mode='reflect')])
elif padding_type == 'replicate':
self.conv_block.extend([Pad2D(paddings=[1, 1, 1, 1], mode='edge')])
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' %
padding_type)
self.conv_block.extend([
Conv2D(
dim, dim, filter_size=3, padding=p, bias_attr=use_bias),
norm_layer(dim), ReLU()
])
self.conv_block.extend([Dropout(dropout_rate)])
p = 0
if padding_type == 'reflect':
self.conv_block.extend(
[Pad2D(
paddings=[1, 1, 1, 1], mode='reflect')])
elif padding_type == 'replicate':
self.conv_block.extend([Pad2D(paddings=[1, 1, 1, 1], mode='edge')])
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' %
padding_type)
self.conv_block.extend([
Conv2D(
dim, dim, filter_size=3, padding=p, bias_attr=use_bias),
norm_layer(dim)
])
def forward(self, inputs):
y = inputs
for sublayer in self.conv_block:
y = sublayer(y)
return y + inputs
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.dygraph_utils as dygraph_utils
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose
import paddle.fluid.core as core
import numpy as np
class SuperInstanceNorm(fluid.dygraph.InstanceNorm):
def __init__(self,
num_channels,
epsilon=1e-5,
param_attr=None,
bias_attr=None,
dtype='float32'):
super(SuperInstanceNorm, self).__init__(
num_channels,
epsilon=1e-5,
param_attr=None,
bias_attr=None,
dtype='float32')
def forward(self, input):
in_nc = int(input.shape[1])
scale = self.scale[:in_nc]
bias = self.scale[:in_nc]
if in_dygraph_mode():
out, _, _ = core.ops.instance_norm(input, scale, bias, 'epsilon',
self._epsilon)
return out
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
"SuperInstanceNorm")
attrs = {"epsilon": self._epsilon}
inputs = {"X": [input], "Scale": [scale], "Bias": [bias]}
saved_mean = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
saved_variance = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
instance_norm_out = self._helper.create_variable_for_type_inference(
self._dtype)
outputs = {
"Y": [instance_norm_out],
"SavedMean": [saved_mean],
"SavedVariance": [saved_variance]
}
self._helper.append_op(
type="instance_norm", inputs=inputs, outputs=outputs, attrs=attrs)
return instance_norm_out
class SuperConv2D(fluid.dygraph.Conv2D):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
padding=0,
dilation=1,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
dtype='float32'):
super(SuperConv2D, self).__init__(
num_channels, num_filters, filter_size, stride, padding, dilation,
groups, param_attr, bias_attr, use_cudnn, act, dtype)
def forward(self, input, config):
in_nc = int(input.shape[1])
out_nc = config['channel']
weight = self.weight[:out_nc, :in_nc, :, :]
#print('super conv shape', weight.shape)
if in_dygraph_mode():
if self._l_type == 'conv2d':
attrs = ('strides', self._stride, 'paddings', self._padding,
'dilations', self._dilation, 'groups', self._groups
if self._groups else 1, 'use_cudnn', self._use_cudnn)
out = core.ops.conv2d(input, weight, *attrs)
elif self._l_type == 'depthwise_conv2d':
attrs = ('strides', self._stride, 'paddings', self._padding,
'dilations', self._dilation, 'groups', self._groups,
'use_cudnn', self._use_cudnn)
out = core.ops.depthwise_conv2d(input, weight, *attrs)
else:
raise ValueError("conv type error")
pre_bias = out
if self.bias is not None:
bias = self.bias[:out_nc]
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias,
1)
else:
pre_act = pre_bias
return dygraph_utils._append_activation_in_dygraph(pre_act,
self._act)
inputs = {'Input': [input], 'Filter': [weight]}
attrs = {
'strides': self._stride,
'paddings': self._padding,
'dilations': self._dilation,
'groups': self._groups if self._groups else 1,
'use_cudnn': self._use_cudnn,
'use_mkldnn': False,
}
check_variable_and_dtype(
input, 'input', ['float16', 'float32', 'float64'], 'SuperConv2D')
pre_bias = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type=self._l_type,
inputs={
'Input': input,
'Filter': weight,
},
outputs={"Output": pre_bias},
attrs=attrs)
if self.bias is not None:
bias = self.bias[:out_nc]
pre_act = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias],
'Y': [bias]},
outputs={'Out': [pre_act]},
attrs={'axis': 1})
else:
pre_act = pre_bias
# Currently, we don't support inplace in dygraph mode
return self._helper.append_activation(pre_act, act=self._act)
class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
def __init__(self,
num_channels,
num_filters,
filter_size,
output_size=None,
padding=0,
stride=1,
dilation=1,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
dtype='float32'):
super(SuperConv2DTranspose,
self).__init__(num_channels, num_filters, filter_size,
output_size, padding, stride, dilation, groups,
param_attr, bias_attr, use_cudnn, act, dtype)
def forward(self, input, config):
in_nc = int(input.shape[1])
out_nc = int(config['channel'])
weight = self.weight[:in_nc, :out_nc, :, :]
if in_dygraph_mode():
op = getattr(core.ops, self._op_type)
out = op(input, weight, 'output_size', self._output_size,
'strides', self._stride, 'paddings', self._padding,
'dilations', self._dilation, 'groups', self._groups,
'use_cudnn', self._use_cudnn)
pre_bias = out
if self.bias is not None:
bias = self.bias[:out_nc]
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias,
1)
else:
pre_act = pre_bias
return dygraph_utils._append_activation_in_dygraph(
pre_act, act=self._act)
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'],
"SuperConv2DTranspose")
inputs = {'Input': [input], 'Filter': [weight]}
attrs = {
'output_size': self._output_size,
'strides': self._stride,
'paddings': self._padding,
'dilations': self._dilation,
'groups': self._groups,
'use_cudnn': self._use_cudnn
}
pre_bias = self._helper.create_variable_for_type_inference(
dtype=input.dtype)
self._helper.append_op(
type=self._op_type,
inputs=inputs,
outputs={'Output': pre_bias},
attrs=attrs)
if self.bias is not None:
pre_act = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias],
'Y': [bias]},
outputs={'Out': [pre_act]},
attrs={'axis': 1})
else:
pre_act = pre_bias
out = self._helper.append_activation(pre_act, act=self._act)
return out
class SuperSeparableConv2D(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
padding=0,
dilation=1,
norm_layer=InstanceNorm,
bias_attr=None,
scale_factor=1,
use_cudnn=False):
super(SuperSeparableConv2D, self).__init__()
self.conv = fluid.dygraph.LayerList([
fluid.dygraph.nn.Conv2D(
num_channels=num_channels,
num_filters=num_channels * scale_factor,
filter_size=filter_size,
stride=stride,
padding=padding,
use_cudnn=False,
groups=num_channels,
bias_attr=bias_attr)
])
if norm_layer == InstanceNorm:
self.conv.extend([
SuperInstanceNorm(
num_channels * scale_factor,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.0),
learning_rate=0.0,
trainable=False),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.0),
learning_rate=0.0,
trainable=False))
])
else:
raise NotImplementedError
self.conv.extend([
Conv2D(
num_channels=num_channels * scale_factor,
num_filters=num_filters,
filter_size=1,
stride=1,
use_cudnn=use_cudnn,
bias_attr=bias_attr)
])
def forward(self, input, config):
in_nc = int(input.shape[1])
out_nc = int(config['channel'])
weight = self.conv[0].weight[:in_nc]
### conv1
if in_dygraph_mode():
if self.conv[0]._l_type == 'conv2d':
attrs = ('strides', self.conv[0]._stride, 'paddings',
self.conv[0]._padding, 'dilations',
self.conv[0]._dilation, 'groups', in_nc, 'use_cudnn',
self.conv[0]._use_cudnn)
out = core.ops.conv2d(input, weight, *attrs)
elif self.conv[0]._l_type == 'depthwise_conv2d':
attrs = ('strides', self.conv[0]._stride, 'paddings',
self.conv[0]._padding, 'dilations',
self.conv[0]._dilation, 'groups', in_nc, 'use_cudnn',
self.conv[0]._use_cudnn)
out = core.ops.depthwise_conv2d(input, weight, *attrs)
else:
raise ValueError("conv type error")
pre_bias = out
if self.conv[0].bias is not None:
bias = self.conv[0].bias[:in_nc]
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias,
1)
else:
pre_act = pre_bias
conv0_out = dygraph_utils._append_activation_in_dygraph(
pre_act, self.conv[0]._act)
norm_out = self.conv[1](conv0_out)
weight = self.conv[2].weight[:out_nc, :in_nc, :, :]
if in_dygraph_mode():
if self.conv[2]._l_type == 'conv2d':
attrs = ('strides', self.conv[2]._stride, 'paddings',
self.conv[2]._padding, 'dilations',
self.conv[2]._dilation, 'groups', self.conv[2]._groups
if self.conv[2]._groups else 1, 'use_cudnn',
self.conv[2]._use_cudnn)
out = core.ops.conv2d(norm_out, weight, *attrs)
elif self.conv[2]._l_type == 'depthwise_conv2d':
attrs = ('strides', self.conv[2]._stride, 'paddings',
self.conv[2]._padding, 'dilations',
self.conv[2]._dilation, 'groups',
self.conv[2]._groups, 'use_cudnn',
self.conv[2]._use_cudnn)
out = core.ops.depthwise_conv2d(norm_out, weight, *attrs)
else:
raise ValueError("conv type error")
pre_bias = out
if self.conv[2].bias is not None:
bias = self.conv[2].bias[:out_nc]
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias,
1)
else:
pre_act = pre_bias
conv1_out = dygraph_utils._append_activation_in_dygraph(
pre_act, self.conv[2]._act)
return conv1_out
if __name__ == '__main__':
class Net(fluid.dygraph.Layer):
def __init__(self, in_cn=3):
super(Net, self).__init__()
self.myconv = SuperSeparableConv2D(
num_channels=in_cn, num_filters=3, filter_size=3)
def forward(self, input, config):
print(input.shape[1])
conv = self.myconv(input, config)
return conv
config = {'channel': 2}
with fluid.dygraph.guard():
net = Net()
data_A = np.random.random((1, 3, 256, 256)).astype("float32")
data_A = to_variable(data_A)
out = net(data_A, config)
print(out.numpy())
......@@ -149,14 +149,13 @@ class ResnetSupernet(BaseResnetDistiller):
config = self.configs(config_name)
fakes, names = [], []
for i, data_i in enumerate(self.eval_dataloader):
id2name = self.name
self.set_single_input(data_i)
self.test(config)
fakes.append(self.Sfake_B.detach().numpy())
for j in range(len(self.Sfake_B)):
if i < 10:
Sname = 'Sfake_' + str(id2name[i + j]) + '.png'
Tname = 'Tfake_' + str(id2name[i + j]) + '.png'
Sname = 'Sfake_' + str(i + j) + '.png'
Tname = 'Tfake_' + str(i + j) + '.png'
Sfake_im = util.tensor2img(self.Sfake_B[j])
Tfake_im = util.tensor2img(self.Tfake_B[j])
util.save_image(Sfake_im,
......
......@@ -44,10 +44,7 @@ class configs:
default='resnet',
help="generator network in supernet")
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='Whether to use GPU in train/test model.')
'--gpu_num', type=int, default='0', help='GPU number.')
### data
parser.add_argument(
'--batch_size', type=int, default=1, help="Minbatch size")
......
......@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, InstanceNorm
from models.modules import SeparableConv2D, MobileResnetBlock, ResnetBlock
from paddle.fluid.dygraph.base import to_variable
import numpy as np
from paddleslim.models.dygraph.modules import SeparableConv2D, MobileResnetBlock, ResnetBlock
### CoutCinKhKw
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册