Commit ae28ccf8 authored by ceci3

add multi-gpu

Parent 27457a04
......@@ -88,16 +88,32 @@ class BaseResnetDistiller(BaseModel):
self.netG_pretrained = network.define_G(
cfgs.input_nc, cfgs.output_nc, cfgs.pretrained_ngf,
cfgs.pretrained_netG, cfgs.norm_type, 0)
if self.cfgs.use_parallel:
self.netG_pretrained = fluid.dygraph.parallel.DataParallel(
self.netG_pretrained, self.cfgs.strategy)
self.netD = network.define_D(cfgs.output_nc, cfgs.ndf, cfgs.netD,
cfgs.norm_type, cfgs.n_layer_D)
if self.cfgs.use_parallel:
self.netG_teacher = fluid.dygraph.parallel.DataParallel(
self.netG_teacher, self.cfgs.strategy)
self.netG_student = fluid.dygraph.parallel.DataParallel(
self.netG_student, self.cfgs.strategy)
self.netD = fluid.dygraph.parallel.DataParallel(self.netD,
self.cfgs.strategy)
self.netG_teacher.eval()
self.netG_student.train()
self.netD.train()
### [9, 12, 15, 18]
self.mapping_layers = ['model.%d' % i for i in range(9, 21, 3)]
self.mapping_layers = [
'_layers.model.%d' % i for i in range(9, 21, 3)
] if self.cfgs.use_parallel else [
'model.%d' % i for i in range(9, 21, 3)
]
self.netAs = []
self.Tacts, self.Sacts = {}, {}
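Wrapping a dygraph Layer in fluid.dygraph.parallel.DataParallel re-registers the original network under a `_layers` child, so every sublayer path gains a `_layers.` prefix; that is why mapping_layers above switches between 'model.%d' and '_layers.model.%d'. A minimal sketch of the effect (the toy Net is an assumption, not repo code, and it has to run under paddle.distributed.launch so prepare_context can initialize NCCL):

    import paddle.fluid as fluid

    class Net(fluid.dygraph.Layer):
        def __init__(self):
            super(Net, self).__init__()
            # LayerList children are addressed as 'model.0', 'model.1', ...
            self.model = fluid.dygraph.LayerList(
                [fluid.dygraph.Linear(4, 4) for _ in range(2)])

    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
    with fluid.dygraph.guard(place):
        strategy = fluid.dygraph.parallel.prepare_context()
        net = fluid.dygraph.parallel.DataParallel(Net(), strategy)
        # Prints '_layers.model.0', '_layers.model.1', ...: name-based
        # activation hooks must use the '_layers.' prefix when parallel.
        print([name for name, _ in net.named_sublayers()])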
......@@ -157,8 +173,8 @@ class BaseResnetDistiller(BaseModel):
self.is_best = False
def setup(self):
self.load_networks()
def setup(self, model_weight=None):
self.load_networks(model_weight)
if self.cfgs.lambda_distill > 0:
......@@ -183,30 +199,37 @@ class BaseResnetDistiller(BaseModel):
def set_single_input(self, inputs):
self.real_A = inputs[0]
def load_networks(self):
def load_networks(self, model_weight=None):
if self.cfgs.restore_teacher_G_path is None:
assert len(
model_weight
) != 0, "restore_teacher_G_path and model_weight cannot be None at the same time."
if self.cfgs.direction == 'AtoB':
teacher_G_path = os.path.join(self.cfgs.save_dir, 'mobile',
'last_netG_A')
key = 'netG_A' if 'netG_A' in model_weight else 'netG_teacher'
self.netG_teacher.set_dict(model_weight[key])
else:
teacher_G_path = os.path.join(self.cfgs.save_dir, 'mobile',
'last_netG_B')
key = 'netG_B' if 'netG_B' in model_weight else 'netG_teacher'
self.netG_teacher.set_dict(model_weight[key])
else:
teacher_G_path = self.cfgs.restore_teacher_G_path
util.load_network(self.netG_teacher, teacher_G_path)
util.load_network(self.netG_teacher, self.cfgs.restore_teacher_G_path)
if self.cfgs.restore_student_G_path is not None:
util.load_network(self.netG_student,
self.cfgs.restore_student_G_path)
else:
if self.task == 'supernet':
student_G_path = os.path.join(self.cfgs.save_dir, 'distiller',
'last_stu_netG')
util.load_network(self.netG_student, student_G_path)
self.netG_student.set_dict(model_weight['netG_student'])
if self.cfgs.restore_D_path is not None:
util.load_network(self.netD, self.cfgs.restore_D_path)
else:
if self.cfgs.direction == 'AtoB':
key = 'netD_A' if 'netD_A' in model_weight else 'netD'
self.netD.set_dict(model_weight[key])
else:
key = 'netD_B' if 'netD_B' in model_weight else 'netD'
self.netD.set_dict(model_weight[key])
if self.cfgs.restore_A_path is not None:
for i, netA in enumerate(self.netAs):
netA_path = '%s-%d.pth' % (self.cfgs.restore_A_path, i)
......@@ -232,6 +255,9 @@ class BaseResnetDistiller(BaseModel):
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
self.loss_D.backward()
if self.cfgs.use_parallel:
self.netD.apply_collective_grads()
def calc_distill_loss(self):
raise NotImplementedError
......
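For context, the canonical dygraph DataParallel step also scales the loss before backpropagation; the hunk above only adds the gradient all-reduce. A sketch of the full pattern (an assumed recipe, not this repo's verbatim code):

    def parallel_backward(model, optimizer, loss):
        # Assumed full DataParallel step: scale the loss by the trainer
        # count, backprop locally, then all-reduce gradients over NCCL.
        loss = model.scale_loss(loss)
        loss.backward()
        model.apply_collective_grads()
        optimizer.minimize(loss)
        model.clear_gradients()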
......@@ -117,6 +117,9 @@ class ResnetDistiller(BaseResnetDistiller):
self.loss_G = self.loss_G_gan + self.loss_G_recon + self.loss_G_distill
self.loss_G.backward()
if self.cfgs.use_parallel:
self.netG_student.apply_collective_grads()
def optimize_parameter(self):
self.forward()
......@@ -130,24 +133,27 @@ class ResnetDistiller(BaseResnetDistiller):
self.optimizer_G.optimizer.minimize(self.loss_G)
self.optimizer_G.optimizer.clear_gradients()
def load_networks(self):
load_pretrain = True
if self.cfgs.restore_pretrained_G_path is not None:
pretrained_G_path = self.cfgs.restore_pretrained_G_path
else:
pretrained_G_path = os.path.join(self.cfgs.save_dir, 'mobile',
'last_netG_B')
if not os.path.exists(os.path.join(pretrained_G_path, 'pdparams')):
load_pretrain = False
def load_networks(self, model_weight=None):
if self.cfgs.restore_pretrained_G_path is not False:
if self.cfgs.restore_pretrained_G_path is not None:
pretrained_G_path = self.cfgs.restore_pretrained_G_path
util.load_network(self.netG_pretrained, pretrained_G_path)
else:
assert len(
model_weight
) != 0, "restore_pretrained_G_path and model_weight cannot be None at the same time; if you do not want to load the pretrained model, please set restore_pretrained_G_path=False"
if self.cfgs.direction == 'AtoB':
self.netG_pretrained.set_dict(model_weight['netG_A'])
else:
self.netG_pretrained.set_dict(model_weight['netG_B'])
if load_pretrain:
util.load_network(self.netG_pretrained, pretrained_G_path)
load_pretrained_weight(
self.cfgs.pretrained_netG, self.cfgs.student_netG,
self.cfgs.pretrained_netG, self.cfgs.distiller_student_netG,
self.netG_pretrained, self.netG_student,
self.cfgs.pretrained_ngf, self.cfgs.student_ngf)
del self.netG_pretrained
super(ResnetDistiller, self).load_networks()
super(ResnetDistiller, self).load_networks(model_weight)
def evaluate_model(self, step):
ret = {}
......
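The reworked load_networks boils down to one precedence rule: an explicit restore_*_path always wins, otherwise the net falls back to the state dict the previous pipeline stage left in model_weight. A hypothetical condensation (resolve_weight is not a repo function; util.load_network is the helper already used above):

    def resolve_weight(net, restore_path, model_weight, key):
        # Hypothetical helper summarizing the branches above.
        if restore_path is not None:
            util.load_network(net, restore_path)  # repo helper, as above
        else:
            assert len(model_weight) != 0, (
                "restore path and model_weight cannot both be empty")
            net.set_dict(model_weight[key])  # in-memory handoff between steps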
......@@ -53,7 +53,8 @@ class gan_compression:
def start_train(self):
steps = self.cfgs.task.split('+')
for step in steps:
model_weight = {}
for idx, step in enumerate(steps):
if step == 'mobile':
from models import create_model
elif step == 'distiller':
......@@ -66,9 +67,16 @@ class gan_compression:
print(
"============================= start train {} ==============================".
format(step))
fluid.enable_imperative()
fluid.enable_imperative(place=self.cfgs.place)
if self.cfgs.use_parallel and idx == 0:
strategy = fluid.dygraph.parallel.prepare_context()
setattr(self.cfgs, 'strategy', strategy)
model = create_model(self.cfgs)
model.setup()
model.setup(model_weight)
### clear model_weight every step
model_weight = {}
_train_dataloader, _ = create_data(self.cfgs)
......@@ -90,6 +98,11 @@ class gan_compression:
message += '%s: %.3f ' % (k, v)
logging.info(message)
if epoch_id == (epochs - 1):
for name in model.model_names:
model_weight[name] = model._sub_layers[
name].state_dict()
save_model = (not self.cfgs.use_parallel) or (
self.cfgs.use_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
......@@ -97,8 +110,6 @@ class gan_compression:
epochs - 1) and save_model:
model.evaluate_model(epoch_id)
model.save_network(epoch_id)
if epoch_id == (epochs - 1):
model.save_network('last')
print("=" * 80)
......
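The driver changes follow the standard dygraph multi-GPU recipe: bind imperative mode to this trainer's own device, call prepare_context once per process, and pass the resulting strategy to every DataParallel wrapper, with only rank 0 saving checkpoints. A minimal sketch of that recipe (assumed setup, launched via paddle.distributed.launch):

    # Run with: python -m paddle.distributed.launch --selected_gpus=0,1 train.py
    import paddle.fluid as fluid

    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
    fluid.enable_imperative(place=place)
    strategy = fluid.dygraph.parallel.prepare_context()  # once per process
    # ... build models, wrap each in DataParallel(net, strategy), train ...
    if fluid.dygraph.parallel.Env().local_rank == 0:
        pass  # rank 0 alone evaluates and saves, as in the hunk above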
......@@ -24,7 +24,7 @@ class BaseModel(fluid.dygraph.Layer):
def set_input(self, inputs):
raise NotImplementedError
def setup(self):
def setup(self, model_weight=None):
self.load_network()
def load_network(self):
......
......@@ -94,7 +94,7 @@ class CycleGAN(BaseModel):
'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B',
'G_idt_B'
]
self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
self.model_names = ['netG_A', 'netG_B', 'netD_A', 'netD_B']
self.netG_A = network.define_G(cfgs.input_nc, cfgs.output_nc, cfgs.ngf,
cfgs.netG, cfgs.norm_type,
......@@ -107,6 +107,16 @@ class CycleGAN(BaseModel):
self.netD_B = network.define_D(cfgs.input_nc, cfgs.ndf, cfgs.netD,
cfgs.norm_type, cfgs.n_layer_D)
if self.cfgs.use_parallel:
self.netG_A = fluid.dygraph.parallel.DataParallel(
self.netG_A, self.cfgs.strategy)
self.netG_B = fluid.dygraph.parallel.DataParallel(
self.netG_B, self.cfgs.strategy)
self.netD_A = fluid.dygraph.parallel.DataParallel(
self.netD_A, self.cfgs.strategy)
self.netD_B = fluid.dygraph.parallel.DataParallel(
self.netD_B, self.cfgs.strategy)
if cfgs.lambda_identity > 0.0:
assert (cfgs.input_nc == cfgs.output_nc)
self.fake_A_pool = ImagePool(cfgs.pool_size)
......@@ -159,12 +169,12 @@ class CycleGAN(BaseModel):
def set_single_input(self, inputs):
self.real_A = inputs[0]
def setup(self):
def setup(self, model_weight=None):
self.load_network()
def load_network(self):
for name in self.model_names:
net = getattr(self, 'net' + name, None)
net = getattr(self, name, None)
path = getattr(self.cfgs, 'restore_%s_path' % name, None)
if path is not None:
util.load_network(net, path)
......@@ -172,10 +182,10 @@ class CycleGAN(BaseModel):
def save_network(self, epoch):
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_net%s' % (epoch, name)
save_filename = '%s_%s' % (epoch, name)
save_path = os.path.join(self.cfgs.save_dir, 'mobile',
save_filename)
net = getattr(self, 'net' + name)
net = getattr(self, name)
fluid.save_dygraph(net.state_dict(), save_path)
def forward(self):
......
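Storing the full attribute names in model_names ('netG_A' instead of 'G_A') lets save and load address each sublayer with a plain getattr, and the same names double as the keys later stages look up in model_weight. A condensed sketch of the resulting save loop (an assumed simplification):

    import os
    import paddle.fluid as fluid

    def save_networks(model, save_dir, epoch):
        # Assumed simplification: model_names now hold full attribute names,
        # so no 'net' prefix is re-added and files become e.g. 'last_netG_A'.
        for name in model.model_names:
            net = getattr(model, name)
            save_path = os.path.join(save_dir, 'mobile', '%s_%s' % (epoch, name))
            fluid.save_dygraph(net.state_dict(), save_path)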
import functools
import paddle.fluid as fluid
import paddle.tensor as tensor
from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose
from paddle.nn.layer import ReLU, Pad2D
from paddleslim.models.dygraph.modules import ResnetBlock
......@@ -56,6 +57,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
output_size = (i + 1) * (self.cfgs.crop_size / 2)
self.model.extend([
Conv2DTranspose(
ngf * mult,
......@@ -63,6 +65,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
filter_size=3,
stride=2,
padding=1,
output_size=output_size,
bias_attr=use_bias), Pad2D(
paddings=[0, 1, 0, 1], mode='constant', pad_value=0.0),
norm_layer(int(ngf * mult / 2)), ReLU()
......@@ -72,7 +75,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
self.model.extend([Conv2D(ngf, output_nc, filter_size=7, padding=0)])
def forward(self, inputs):
y = fluid.layers.clamp(inputs, min=-1.0, max=1.0)
y = tensor.clamp(inputs, min=-1.0, max=1.0)
for sublayer in self.model:
y = sublayer(y)
y = fluid.layers.tanh(y)
......
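With stride 2 a transposed convolution's output size is ambiguous: any value in [(H-1)*stride - 2*pad + filter, (H-1)*stride - 2*pad + filter + stride) is geometrically valid, so passing output_size pins the shape explicitly. That is what the new output_size = (i + 1) * (crop_size / 2) argument does (note that under Python 3 the / should arguably be // to keep the value integral). A toy check with assumed sizes:

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.nn import Conv2DTranspose

    with fluid.dygraph.guard():
        # Without output_size this layer yields 127 = (64-1)*2 - 2*1 + 3;
        # output_size picks 128 from the valid range [127, 129).
        up = Conv2DTranspose(8, 4, filter_size=3, stride=2, padding=1,
                             output_size=128)
        x = fluid.dygraph.to_variable(
            np.random.rand(1, 8, 64, 64).astype('float32'))
        print(up(x).shape)  # [1, 4, 128, 128]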
......@@ -75,7 +75,7 @@ class SubMobileResnetGenerator(fluid.dygraph.Layer):
for i in range(n_downsampling):
out_c = config['channels'][offset + i]
mult = 2**(n_downsampling - i)
output_size = (i + 1) * 128
output_size = (i + 1) * (self.cfgs.crop_size / 2)
self.model.extend([
Conv2DTranspose(
in_c * mult,
......
......@@ -75,19 +75,19 @@ class SuperMobileResnetGenerator(fluid.dygraph.Layer):
input_channel,
output_nc,
ngf,
norm_layer=BatchNorm,
norm_layer=InstanceNorm,
dropout_rate=0,
n_blocks=6,
padding_type='reflect'):
assert n_blocks >= 0
super(SuperMobileResnetGenerator, self).__init__()
use_bias = norm_layer == InstanceNorm
if norm_layer.func == InstanceNorm or norm_layer == InstanceNorm:
norm_layer = SuperInstanceNorm
else:
raise NotImplementedError
use_bias = norm_layer == InstanceNorm
self.model = fluid.dygraph.LayerList([])
self.model.extend([
Pad2D(
......
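The reordered check matters because norm_layer often arrives as functools.partial(InstanceNorm, ...) rather than the bare class, so norm_layer == InstanceNorm alone misses the wrapped case; testing .func covers it. A small illustration (assumed helper; guarding with isinstance also avoids an AttributeError when the plain class is passed):

    import functools
    from paddle.fluid.dygraph.nn import InstanceNorm

    def is_instance_norm(norm_layer):
        # A partial wrapping InstanceNorm is not equal to the class itself,
        # so inspect .func, as the diff's condition does.
        if isinstance(norm_layer, functools.partial):
            return norm_layer.func == InstanceNorm
        return norm_layer == InstanceNorm

    print(is_instance_norm(InstanceNorm))                      # True
    print(is_instance_norm(functools.partial(InstanceNorm)))   # True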
......@@ -18,6 +18,7 @@ from discrimitor import NLayerDiscriminator
from generator.resnet_generator import ResnetGenerator
from generator.mobile_generator import MobileResnetGenerator
from generator.super_generator import SuperMobileResnetGenerator
from generator.sub_mobile_generator import SubMobileResnetGenerator
class Identity(fluid.dygraph.Layer):
......@@ -88,6 +89,16 @@ def define_G(input_nc,
norm_layer=norm_layer,
dropout_rate=dropout_rate,
n_blocks=9)
elif netG == 'sub_mobile_resnet_9blocks':
assert self.cfgs.config_str is not None
config = decode_config(self.cfgs.config_str)
net = SubMobileResnetGenerator(
input_nc,
output_nc,
config,
norm_layer=norm_layer,
dropout_rate=dropout_rate,
n_blocks=9)
return net
......
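The new branch selects SubMobileResnetGenerator from a decoded channel configuration (note the assert reads self.cfgs inside what is a module-level function, so in practice the config string has to be threaded in as an argument). decode_config is assumed to follow the upstream gan-compression convention of parsing an underscore-separated channel list; a hedged sketch:

    def decode_config(config_str):
        # Assumed behavior, following upstream gan-compression (not verified
        # against this repo): '16_16_32_16' -> {'channels': [16, 16, 32, 16]}
        return {'channels': [int(c) for c in config_str.split('_')]}

    config = decode_config('16_16_32_16_32_32_16_16')
    print(config['channels'][0])  # consumed as config['channels'][offset + i]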
......@@ -89,7 +89,10 @@ class ResnetSupernet(BaseResnetDistiller):
with fluid.dygraph.no_grad():
self.Tfake_B = self.netG_teacher(self.real_A)
self.Tfake_B.stop_gradient = True
self.netG_student.configs = config
if self.cfgs.use_parallel:
self.netG_student._layers.configs = config
else:
self.netG_student.configs = config
self.Sfake_B = self.netG_student(self.real_A)
def calc_distill_loss(self):
......@@ -137,7 +140,8 @@ class ResnetSupernet(BaseResnetDistiller):
def evaluate_model(self, step):
ret = {}
self.is_best = False
save_dir = os.path.join(self.cfgs.save_dir, 'eval', str(step))
save_dir = os.path.join(self.cfgs.save_dir, 'supernet', 'eval',
str(step))
if not os.path.exists(save_dir):
os.makedirs(save_dir)
self.netG_student.eval()
......
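DataParallel does not forward attribute writes to the wrapped module, so the sampled supernet config has to be written through ._layers when parallel, exactly as the branch above does. A hypothetical helper that hides the branching:

    import paddle.fluid as fluid

    def set_config(net, config):
        # Hypothetical helper (not repo code): write through ._layers when
        # `net` is a DataParallel wrapper, directly otherwise.
        target = net._layers if isinstance(
            net, fluid.dygraph.parallel.DataParallel) else net
        target.configs = config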