Commit ae28ccf8 authored by ceci3

add multi-gpu

Parent 27457a04
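
This commit adds multi-GPU support to the dygraph GAN-compression pipeline: networks are wrapped in fluid.dygraph.parallel.DataParallel, gradients are all-reduced with apply_collective_grads(), and the weights of each finished stage (mobile -> distiller -> supernet) are handed to the next stage through an in-memory model_weight dict instead of `last_*` checkpoint files.

Below is a minimal sketch of the fluid 1.x dygraph data-parallel pattern this commit follows; MyNet and reader() are placeholders, not part of the commit:

    import paddle.fluid as fluid

    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
    with fluid.dygraph.guard(place):
        strategy = fluid.dygraph.parallel.prepare_context()  # init the NCCL context
        net = fluid.dygraph.parallel.DataParallel(MyNet(), strategy)
        opt = fluid.optimizer.Adam(0.001, parameter_list=net.parameters())
        for batch in reader():
            loss = net(batch)
            loss = net.scale_loss(loss)   # scale the loss by the trainer count
            loss.backward()
            net.apply_collective_grads()  # all-reduce gradients across GPUs
            opt.minimize(loss)
            net.clear_gradients()

Such a script is typically launched with `python -m paddle.distributed.launch --selected_gpus=0,1 train.py`.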
@@ -88,16 +88,32 @@ class BaseResnetDistiller(BaseModel):
         self.netG_pretrained = network.define_G(
             cfgs.input_nc, cfgs.output_nc, cfgs.pretrained_ngf,
             cfgs.pretrained_netG, cfgs.norm_type, 0)
+        if self.cfgs.use_parallel:
+            self.netG_pretrained = fluid.dygraph.parallel.DataParallel(
+                self.netG_pretrained, self.cfgs.strategy)
         self.netD = network.define_D(cfgs.output_nc, cfgs.ndf, cfgs.netD,
                                      cfgs.norm_type, cfgs.n_layer_D)
+        if self.cfgs.use_parallel:
+            self.netG_teacher = fluid.dygraph.parallel.DataParallel(
+                self.netG_teacher, self.cfgs.strategy)
+            self.netG_student = fluid.dygraph.parallel.DataParallel(
+                self.netG_student, self.cfgs.strategy)
+            self.netD = fluid.dygraph.parallel.DataParallel(self.netD,
+                                                            self.cfgs.strategy)
         self.netG_teacher.eval()
         self.netG_student.train()
         self.netD.train()
         ### [9, 12, 15, 18]
-        self.mapping_layers = ['model.%d' % i for i in range(9, 21, 3)]
+        self.mapping_layers = [
+            '_layers.model.%d' % i for i in range(9, 21, 3)
+        ] if self.cfgs.use_parallel else [
+            'model.%d' % i for i in range(9, 21, 3)
+        ]
         self.netAs = []
         self.Tacts, self.Sacts = {}, {}
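
Note on the mapping_layers change above: DataParallel keeps the wrapped network in its `_layers` attribute, so every sublayer path gains a `_layers.` prefix, and the names used to hook distillation activations must be prefixed accordingly. A sketch (illustrative, not from the commit):

    net = fluid.dygraph.parallel.DataParallel(net, strategy)
    # a sublayer formerly reachable as 'model.9' now lives at '_layers.model.9'
    inner = net._layers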
@@ -157,8 +173,8 @@ class BaseResnetDistiller(BaseModel):
         self.is_best = False
 
-    def setup(self):
-        self.load_networks()
+    def setup(self, model_weight=None):
+        self.load_networks(model_weight)
         if self.cfgs.lambda_distill > 0:
@@ -183,30 +199,37 @@ class BaseResnetDistiller(BaseModel):
     def set_single_input(self, inputs):
         self.real_A = inputs[0]
 
-    def load_networks(self):
+    def load_networks(self, model_weight=None):
         if self.cfgs.restore_teacher_G_path is None:
+            assert len(
+                model_weight
+            ) != 0, "restore_teacher_G_path and model_weight cannot be None at the same time."
             if self.cfgs.direction == 'AtoB':
-                teacher_G_path = os.path.join(self.cfgs.save_dir, 'mobile',
-                                              'last_netG_A')
+                key = 'netG_A' if 'netG_A' in model_weight else 'netG_teacher'
+                self.netG_teacher.set_dict(model_weight[key])
             else:
-                teacher_G_path = os.path.join(self.cfgs.save_dir, 'mobile',
-                                              'last_netG_B')
+                key = 'netG_B' if 'netG_B' in model_weight else 'netG_teacher'
+                self.netG_teacher.set_dict(model_weight[key])
         else:
-            teacher_G_path = self.cfgs.restore_teacher_G_path
-        util.load_network(self.netG_teacher, teacher_G_path)
+            util.load_network(self.netG_teacher, self.cfgs.teacher_G_path)
 
         if self.cfgs.restore_student_G_path is not None:
             util.load_network(self.netG_student,
                               self.cfgs.restore_student_G_path)
         else:
             if self.task == 'supernet':
-                student_G_path = os.path.join(self.cfgs.save_dir, 'distiller',
-                                              'last_stu_netG')
-                util.load_network(self.netG_student, student_G_path)
+                self.netG_student.set_dict(model_weight['netG_student'])
 
         if self.cfgs.restore_D_path is not None:
             util.load_network(self.netD, self.cfgs.restore_D_path)
+        else:
+            if self.cfgs.direction == 'AtoB':
+                key = 'netD_A' if 'netD_A' in model_weight else 'netD'
+                self.netD.set_dict(model_weight[key])
+            else:
+                key = 'netD_B' if 'netD_B' in model_weight else 'netD'
+                self.netD.set_dict(model_weight[key])
 
         if self.cfgs.restore_A_path is not None:
             for i, netA in enumerate(self.netAs):
                 netA_path = '%s-%d.pth' % (self.cfgs.restore_A_path, i)
@@ -232,6 +255,9 @@ class BaseResnetDistiller(BaseModel):
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
         self.loss_D.backward()
+        if self.cfgs.use_parallel:
+            self.netD.apply_collective_grads()
 
     def calc_distill_loss(self):
         raise NotImplementedError
...
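
In the fluid dygraph API, apply_collective_grads() all-reduces the local gradients after backward(); it is normally paired with scale_loss() before backward(), which the hunks in this commit do not show. A sketch of the usual pairing:

    loss = net.scale_loss(loss)   # divide the loss by the number of trainers
    loss.backward()               # compute local gradients
    net.apply_collective_grads()  # all-reduce gradients before the optimizer step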
@@ -117,6 +117,9 @@ class ResnetDistiller(BaseResnetDistiller):
         self.loss_G = self.loss_G_gan + self.loss_G_recon + self.loss_G_distill
         self.loss_G.backward()
+        if self.cfgs.use_parallel:
+            self.netG_student.apply_collective_grads()
 
     def optimize_parameter(self):
         self.forward()
@@ -130,24 +133,27 @@ class ResnetDistiller(BaseResnetDistiller):
         self.optimizer_G.optimizer.minimize(self.loss_G)
         self.optimizer_G.optimizer.clear_gradients()
 
-    def load_networks(self):
-        load_pretrain = True
-        if self.cfgs.restore_pretrained_G_path is not None:
-            pretrained_G_path = self.cfgs.restore_pretrained_G_path
-        else:
-            pretrained_G_path = os.path.join(self.cfgs.save_dir, 'mobile',
-                                             'last_netG_B')
-            if not os.path.exists(os.path.join(pretrained_G_path, 'pdparams')):
-                load_pretrain = False
-
-        if load_pretrain:
-            util.load_network(self.netG_pretrained, pretrained_G_path)
+    def load_networks(self, model_weight=None):
+        if self.cfgs.restore_pretrained_G_path != False:
+            if self.cfgs.restore_pretrained_G_path != None:
+                pretrained_G_path = self.cfgs.restore_pretrained_G_path
+                util.load_network(self.netG_pretrained, pretrained_G_path)
+            else:
+                assert len(
+                    model_weight
+                ) != 0, "restore_pretrained_G_path and model_weight cannot be None at the same time; if you do not want to load a pretrained model, set restore_pretrained_G_path=False"
+                if self.cfgs.direction == 'AtoB':
+                    self.netG_pretrained.set_dict(model_weight['netG_A'])
+                else:
+                    self.netG_pretrained.set_dict(model_weight['netG_B'])
 
         load_pretrained_weight(
-            self.cfgs.pretrained_netG, self.cfgs.student_netG,
+            self.cfgs.pretrained_netG, self.cfgs.distiller_student_netG,
             self.netG_pretrained, self.netG_student,
             self.cfgs.pretrained_ngf, self.cfgs.student_ngf)
         del self.netG_pretrained
-        super(ResnetDistiller, self).load_networks()
+        super(ResnetDistiller, self).load_networks(model_weight)
 
     def evaluate_model(self, step):
         ret = {}
...
@@ -53,7 +53,8 @@ class gan_compression:
     def start_train(self):
         steps = self.cfgs.task.split('+')
-        for step in steps:
+        model_weight = {}
+        for idx, step in enumerate(steps):
             if step == 'mobile':
                 from models import create_model
             elif step == 'distiller':
@@ -66,9 +67,16 @@ class gan_compression:
             print(
                 "============================= start train {} ==============================".
                 format(step))
-            fluid.enable_imperative()
+            fluid.enable_imperative(place=self.cfgs.place)
+            if self.cfgs.use_parallel and idx == 0:
+                strategy = fluid.dygraph.parallel.prepare_context()
+                setattr(self.cfgs, 'strategy', strategy)
             model = create_model(self.cfgs)
-            model.setup()
+            model.setup(model_weight)
+            ### clear model_weight every step
+            model_weight = {}
 
             _train_dataloader, _ = create_data(self.cfgs)
@@ -90,6 +98,11 @@ class gan_compression:
                     message += '%s: %.3f ' % (k, v)
                     logging.info(message)
+                if epoch_id == (epochs - 1):
+                    for name in model.model_names:
+                        model_weight[name] = model._sub_layers[
+                            name].state_dict()
                 save_model = (not self.cfgs.use_parallel) or (
                     self.cfgs.use_parallel and
                     fluid.dygraph.parallel.Env().local_rank == 0)
@@ -97,8 +110,6 @@ class gan_compression:
                         epochs - 1) and save_model:
                     model.evaluate_model(epoch_id)
                     model.save_network(epoch_id)
-                if epoch_id == (epochs - 1):
-                    model.save_network('last')
             print("=" * 80)
...
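
The start_train() changes above replace the on-disk `last_*` checkpoints with an in-memory hand-off: at the final epoch of each stage, every sub-network's state_dict() is stashed in model_weight, and the next stage restores it in setup(model_weight) via set_dict(). A condensed sketch of the hand-off (names follow the commit; `cfgs` stands in for the full config object):

    model_weight = {}
    for name in model.model_names:          # e.g. 'netG_A', 'netD_A'
        model_weight[name] = model._sub_layers[name].state_dict()

    next_model = create_model(cfgs)         # next pipeline stage
    next_model.setup(model_weight)          # restores weights via set_dict()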
@@ -24,7 +24,7 @@ class BaseModel(fluid.dygraph.Layer):
     def set_input(self, inputs):
         raise NotImplementedError
 
-    def setup(self):
+    def setup(self, model_weight=None):
         self.load_network()
 
     def load_network(self):
...
@@ -94,7 +94,7 @@ class CycleGAN(BaseModel):
             'D_A', 'G_A', 'G_cycle_A', 'G_idt_A', 'D_B', 'G_B', 'G_cycle_B',
             'G_idt_B'
         ]
-        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
+        self.model_names = ['netG_A', 'netG_B', 'netD_A', 'netD_B']
         self.netG_A = network.define_G(cfgs.input_nc, cfgs.output_nc, cfgs.ngf,
                                        cfgs.netG, cfgs.norm_type,
@@ -107,6 +107,16 @@ class CycleGAN(BaseModel):
         self.netD_B = network.define_D(cfgs.input_nc, cfgs.ndf, cfgs.netD,
                                        cfgs.norm_type, cfgs.n_layer_D)
+        if self.cfgs.use_parallel:
+            self.netG_A = fluid.dygraph.parallel.DataParallel(
+                self.netG_A, self.cfgs.strategy)
+            self.netG_B = fluid.dygraph.parallel.DataParallel(
+                self.netG_B, self.cfgs.strategy)
+            self.netD_A = fluid.dygraph.parallel.DataParallel(
+                self.netD_A, self.cfgs.strategy)
+            self.netD_B = fluid.dygraph.parallel.DataParallel(
+                self.netD_B, self.cfgs.strategy)
         if cfgs.lambda_identity > 0.0:
             assert (cfgs.input_nc == cfgs.output_nc)
         self.fake_A_pool = ImagePool(cfgs.pool_size)
@@ -159,12 +169,12 @@ class CycleGAN(BaseModel):
     def set_single_input(self, inputs):
         self.real_A = inputs[0]
 
-    def setup(self):
+    def setup(self, model_weight=None):
         self.load_network()
 
     def load_network(self):
         for name in self.model_names:
-            net = getattr(self, 'net' + name, None)
+            net = getattr(self, name, None)
             path = getattr(self.cfgs, 'restore_%s_path' % name, None)
             if path is not None:
                 util.load_network(net, path)
@@ -172,10 +182,10 @@ class CycleGAN(BaseModel):
     def save_network(self, epoch):
         for name in self.model_names:
             if isinstance(name, str):
-                save_filename = '%s_net%s' % (epoch, name)
+                save_filename = '%s_%s' % (epoch, name)
                 save_path = os.path.join(self.cfgs.save_dir, 'mobile',
                                          save_filename)
-                net = getattr(self, 'net' + name)
+                net = getattr(self, name)
                 fluid.save_dygraph(net.state_dict(), save_path)
 
     def forward(self):
...
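
The renaming above ('G_A' -> 'netG_A', and so on) makes model_names match the actual attribute names, so lookup no longer needs the 'net' + name concatenation, and checkpoint filenames line up with the keys stored into model_weight by start_train(). Sketch:

    for name in ['netG_A', 'netG_B', 'netD_A', 'netD_B']:
        net = getattr(model, name)          # direct attribute lookup
        # saved as '<epoch>_<name>', e.g. 'last_netG_A'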
 import functools
 import paddle.fluid as fluid
+import paddle.tensor as tensor
 from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose
 from paddle.nn.layer import ReLU, Pad2D
 from paddleslim.models.dygraph.modules import ResnetBlock
@@ -56,6 +57,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
         for i in range(n_downsampling):
             mult = 2**(n_downsampling - i)
+            output_size = (i + 1) * (self.cfgs.crop_size / 2)
             self.model.extend([
                 Conv2DTranspose(
                     ngf * mult,
@@ -63,6 +65,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
                     filter_size=3,
                     stride=2,
                     padding=1,
+                    output_size=output_size,
                     bias_attr=use_bias), Pad2D(
                         paddings=[0, 1, 0, 1], mode='constant', pad_value=0.0),
                 norm_layer(int(ngf * mult / 2)), ReLU()
@@ -72,7 +75,7 @@ class ResnetGenerator(fluid.dygraph.Layer):
         self.model.extend([Conv2D(ngf, output_nc, filter_size=7, padding=0)])
 
     def forward(self, inputs):
-        y = fluid.layers.clamp(inputs, min=-1.0, max=1.0)
+        y = tensor.clamp(inputs, min=-1.0, max=1.0)
         for sublayer in self.model:
             y = sublayer(y)
         y = fluid.layers.tanh(y)
...
@@ -75,7 +75,7 @@ class SubMobileResnetGenerator(fluid.dygraph.Layer):
         for i in range(n_downsampling):
             out_c = config['channels'][offset + i]
             mult = 2**(n_downsampling - i)
-            output_size = (i + 1) * 128
+            output_size = (i + 1) * (self.cfgs.crop_size / 2)
             self.model.extend([
                 Conv2DTranspose(
                     in_c * mult,
...
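
Why the explicit output_size: a stride-2 Conv2DTranspose has an ambiguous output shape (several input sizes map to the same output size), and the hardcoded 128 assumed a 256px crop. With the fixed n_downsampling = 2 used by these generators, the i-th upsampling layer must produce (i + 1) * crop_size / 2 pixels. A quick check (note that true division yields a float; integer division `// 2` may be the safer choice):

    crop_size = 256                          # training crop
    for i in range(2):                       # n_downsampling = 2
        print(i, (i + 1) * (crop_size / 2))  # 0 -> 128.0, 1 -> 256.0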
@@ -75,19 +75,19 @@ class SuperMobileResnetGenerator(fluid.dygraph.Layer):
                  input_channel,
                  output_nc,
                  ngf,
-                 norm_layer=BatchNorm,
+                 norm_layer=InstanceNorm,
                  dropout_rate=0,
                  n_blocks=6,
                  padding_type='reflect'):
         assert n_blocks >= 0
         super(SuperMobileResnetGenerator, self).__init__()
-        use_bias = norm_layer == InstanceNorm
         if norm_layer.func == InstanceNorm or norm_layer == InstanceNorm:
             norm_layer = SuperInstanceNorm
         else:
             raise NotImplementedError
+        use_bias = norm_layer == InstanceNorm
 
         self.model = fluid.dygraph.LayerList([])
         self.model.extend([
             Pad2D(
...
@@ -18,6 +18,7 @@ from discrimitor import NLayerDiscriminator
 from generator.resnet_generator import ResnetGenerator
 from generator.mobile_generator import MobileResnetGenerator
 from generator.super_generator import SuperMobileResnetGenerator
+from generator.sub_mobile_generator import SubMobileResnetGenerator
 
 
 class Identity(fluid.dygraph.Layer):
@@ -88,6 +89,16 @@ def define_G(input_nc,
             norm_layer=norm_layer,
             dropout_rate=dropout_rate,
             n_blocks=9)
+    elif netG == 'sub_mobile_resnet_9blocks':
+        assert self.cfgs.config_str is not None
+        config = decode_config(self.cfgs.config_str)
+        net = SubMobileResnetGenerator(
+            input_nc,
+            output_nc,
+            config,
+            norm_layer=norm_layer,
+            dropout_rate=dropout_rate,
+            n_blocks=9)
     return net
...
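
decode_config turns the channel string produced by the search step into the per-layer channel plan consumed by SubMobileResnetGenerator. The exact format is defined elsewhere in the repo; the sketch below only assumes a plausible shape:

    config_str = '16_16_32_16_32_32_16_16'   # hypothetical channel string
    config = {'channels': [int(c) for c in config_str.split('_')]}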
@@ -89,7 +89,10 @@ class ResnetSupernet(BaseResnetDistiller):
         with fluid.dygraph.no_grad():
             self.Tfake_B = self.netG_teacher(self.real_A)
             self.Tfake_B.stop_gradient = True
-        self.netG_student.configs = config
+        if self.cfgs.use_parallel:
+            self.netG_student._layers.configs = config
+        else:
+            self.netG_student.configs = config
         self.Sfake_B = self.netG_student(self.real_A)
 
     def calc_distill_loss(self):
@@ -137,7 +140,8 @@ class ResnetSupernet(BaseResnetDistiller):
     def evaluate_model(self, step):
         ret = {}
         self.is_best = False
-        save_dir = os.path.join(self.cfgs.save_dir, 'eval', str(step))
+        save_dir = os.path.join(self.cfgs.save_dir, 'supernet', 'eval',
+                                str(step))
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
         self.netG_student.eval()
...
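
Under DataParallel, per-batch attribute writes must target the inner network, since setting an attribute on the wrapper does not reach the wrapped layer. The change above reduces to:

    target = net._layers if use_parallel else net
    target.configs = config   # select the sub-network for this batch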