Unverified commit c56dbd8f, authored by LielinJiang, committed by GitHub

Merge pull request #5 from LielinJiang/benchmark

for Benchmark test
@@ -28,6 +28,7 @@ dataset:
   train:
     name: UnpairedDataset
     dataroot: data/cityscapes
+    num_workers: 4
     phase: train
     max_dataset_size: inf
     direction: AtoB
...
@@ -25,6 +25,7 @@ dataset:
   train:
     name: PairedDataset
     dataroot: data/cityscapes
+    num_workers: 4
     phase: train
     max_dataset_size: inf
     direction: BtoA
...
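Both training configs gain a `num_workers` key under `dataset.train`, the knob that `build_dataloader` (next diff) forwards to the data loader. A minimal sketch of reading the key back with a plain YAML load; the file name here is illustrative, not taken from the diff:

    import yaml

    # Illustrative path; the diff does not show the config file's name.
    with open('configs/cyclegan_cityscapes.yaml') as f:
        cfg = yaml.safe_load(f)

    # The new key sits beside dataroot, phase and direction.
    assert cfg['dataset']['train']['num_workers'] == 4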
@@ -111,6 +111,6 @@ def build_dataloader(cfg, is_train=True):
     batch_size = cfg.get('batch_size', 1)
     num_workers = cfg.get('num_workers', 0)
-    dataloader = DictDataLoader(dataset, batch_size, is_train)
+    dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)
     return dataloader
\ No newline at end of file
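`build_dataloader` already read `num_workers` from the config (defaulting to 0, i.e. loading in the main process) but then dropped it; the one-line fix forwards it to `DictDataLoader`. That class is defined elsewhere in the repo and not shown in this diff; a minimal sketch of the signature the call site now implies, with the internals left abstract:

    class DictDataLoader:
        # Sketch only: the real DictDataLoader lives elsewhere in this repo.
        # The diff implies the extra constructor argument, with num_workers
        # controlling how many worker subprocesses prefetch batches.
        def __init__(self, dataset, batch_size, is_train, num_workers=0):
            self.dataset = dataset
            self.batch_size = batch_size
            self.shuffle = is_train         # shuffle only during training
            self.num_workers = num_workers  # 0 keeps loading synchronous

With `num_workers: 4` in the configs, batches are prefetched by four subprocesses, which is what the reader-cost logging added in the trainer below is meant to verify.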
@@ -2,8 +2,9 @@ import os
 import time
 import logging
 
-from paddle.imperative import ParallelEnv
+import paddle
+from paddle.imperative import ParallelEnv, DataParallel
 
 from ..datasets.builder import build_dataloader
 from ..models.builder import build_model
@@ -22,10 +23,13 @@ class Trainer:
         # build model
         self.model = build_model(cfg)
 
+        # multiple gpus prepare
+        if ParallelEnv().nranks > 1:
+            self.distributed_data_parallel()
+
         self.logger = logging.getLogger(__name__)
 
         # base config
-        # self.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
         self.output_dir = cfg.output_dir
         self.epochs = cfg.epochs
         self.start_epoch = 0
@@ -38,24 +42,38 @@ class Trainer:
         self.local_rank = ParallelEnv().local_rank
 
+        # time count
+        self.time_count = {}
+
+    def distributed_data_parallel(self):
+        strategy = paddle.imperative.prepare_context()
+        for name in self.model.model_names:
+            if isinstance(name, str):
+                net = getattr(self.model, 'net' + name)
+                setattr(self.model, 'net' + name, DataParallel(net, strategy))
+
     def train(self):
         for epoch in range(self.start_epoch, self.epochs):
-            start_time = time.time()
             self.current_epoch = epoch
+            start_time = step_start_time = time.time()
             for i, data in enumerate(self.train_dataloader):
+                data_time = time.time()
                 self.batch_id = i
                 # unpack data from dataset and apply preprocessing
                 # data input should be dict
                 self.model.set_input(data)
                 self.model.optimize_parameters()
 
+                self.data_time = data_time - step_start_time
+                self.step_time = time.time() - step_start_time
                 if i % self.log_interval == 0:
                     self.print_log()
 
                 if i % self.visual_interval == 0:
                     self.visual('visual_train')
+                step_start_time = time.time()
 
             self.logger.info('train one epoch time: {}'.format(time.time() - start_time))
             if epoch % self.weight_interval == 0:
                 self.save(epoch, 'weight', keep=-1)
@@ -98,6 +116,12 @@ class Trainer:
         for k, v in losses.items():
             message += '%s: %.3f ' % (k, v)
 
+        if hasattr(self, 'data_time'):
+            message += 'reader cost: %.5fs ' % self.data_time
+
+        if hasattr(self, 'step_time'):
+            message += 'batch cost: %.5fs' % self.step_time
+
         # print the message
         self.logger.info(message)
...
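The new timestamps split every step into the standard reader-cost/batch-cost pair: `data_time` is stamped the moment the loader yields a batch, so `data_time - step_start_time` is the time spent waiting on data, while `time.time() - step_start_time` is the whole step including `optimize_parameters()`. When reader cost stays close to batch cost, the loader is the bottleneck, which is exactly what the new `num_workers` option addresses. A self-contained sketch of the same pattern:

    import time

    def timed_loop(dataloader, run_step):
        # Sketch of the reader-cost / batch-cost split used in Trainer.train.
        step_start_time = time.time()
        for data in dataloader:
            data_time = time.time()            # the batch has just arrived
            run_step(data)                     # forward / backward / update
            reader_cost = data_time - step_start_time
            batch_cost = time.time() - step_start_time
            print('reader cost: %.5fs batch cost: %.5fs' % (reader_cost, batch_cost))
            step_start_time = time.time()      # restart the clock for the next step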
 import paddle
+from paddle.imperative import ParallelEnv
 
 from .base_model import BaseModel
 from .builder import MODELS
@@ -137,6 +138,12 @@ class CycleGANModel(BaseModel):
         loss_D_fake = self.criterionGAN(pred_fake, False)
         # Combined loss and calculate gradients
         loss_D = (loss_D_real + loss_D_fake) * 0.5
-        loss_D.backward()
+        # loss_D.backward()
+        if ParallelEnv().nranks > 1:
+            loss_D = netD.scale_loss(loss_D)
+            loss_D.backward()
+            netD.apply_collective_grads()
+        else:
+            loss_D.backward()
         return loss_D
@@ -177,6 +184,13 @@ class CycleGANModel(BaseModel):
         self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
         # combined loss and calculate gradients
         self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
-        self.loss_G.backward()
+        if ParallelEnv().nranks > 1:
+            self.loss_G = self.netG_A.scale_loss(self.loss_G)
+            self.loss_G.backward()
+            self.netG_A.apply_collective_grads()
+            self.netG_B.apply_collective_grads()
+        else:
+            self.loss_G.backward()
 
     def optimize_parameters(self):
...
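Both backward passes now branch on `ParallelEnv().nranks`. With this Paddle dygraph `DataParallel`, gradient averaging is explicit: the wrapped layer's `scale_loss()` pre-scales the loss by the number of trainers before `backward()`, and `apply_collective_grads()` then all-reduces the gradients across processes. Note the generator case: because `loss_G` mixes terms from both generators, the loss is scaled once (through `netG_A`) but collective grads are applied on both wrappers. A condensed sketch of the pattern, with `net` standing for any `DataParallel`-wrapped network:

    def backward_in_parallel(net, loss, nranks):
        # Sketch of the guarded backward used throughout this PR.
        if nranks > 1:
            loss = net.scale_loss(loss)    # pre-scale so gradients average, not sum
            loss.backward()
            net.apply_collective_grads()   # all-reduce gradients across trainers
        else:
            loss.backward()                # single-process path is unchanged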
@@ -36,11 +36,8 @@ class ResnetGenerator(paddle.fluid.dygraph.Layer):
         else:
             use_bias = norm_layer == nn.InstanceNorm
 
-        print('norm layer:', norm_layer, 'use bias:', use_bias)
-
         model = [ReflectionPad2d(3),
                  nn.Conv2D(input_nc, ngf, filter_size=7, padding=0, bias_attr=use_bias),
-                 # nn.nn.Conv2D(input_nc, ngf, filter_size=7, padding=0, bias_attr=use_bias),
                  norm_layer(ngf),
                  nn.ReLU()]
@@ -62,8 +59,7 @@ class ResnetGenerator(paddle.fluid.dygraph.Layer):
         model += [
             nn.Conv2DTranspose(ngf * mult, int(ngf * mult / 2),
                                filter_size=3, stride=2,
-                               padding=1, #output_padding=1,
-                               # padding='same', #output_padding=1,
+                               padding=1,
                                bias_attr=use_bias),
             Pad2D(paddings=[0, 1, 0, 1], mode='constant', pad_value=0.0),
             norm_layer(int(ngf * mult / 2)),
...
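The deleted comments (`output_padding=1`, `padding='same'`) record the workaround being tidied up here: this Paddle version's `Conv2DTranspose` has no `output_padding` argument, so the asymmetric `Pad2D(paddings=[0, 1, 0, 1])` that follows supplies the missing bottom row and right column. For kernel 3, stride 2, padding 1, a transposed conv maps H to (H - 1) * 2 - 2 * 1 + 3 = 2H - 1, one short of a clean 2x upsample; padding one pixel on the bottom and right restores 2H. A quick arithmetic check:

    def deconv_out(h, k=3, s=2, p=1):
        # Standard transposed-convolution output size (no output_padding).
        return (h - 1) * s - 2 * p + k

    h = 64
    assert deconv_out(h) == 2 * h - 1   # 127: one short of a 2x upsample
    assert deconv_out(h) + 1 == 2 * h   # Pad2D([0, 1, 0, 1]) adds the missing row/col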
 import paddle
+from paddle.imperative import ParallelEnv
 
 from .base_model import BaseModel
 from .builder import MODELS
@@ -43,7 +44,6 @@ class Pix2PixModel(BaseModel):
         # define networks (both generator and discriminator)
         self.netG = build_generator(opt.model.generator)
 
-
         # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
         if self.isTrain:
             self.netD = build_discriminator(opt.model.discriminator)
@@ -98,6 +98,11 @@ class Pix2PixModel(BaseModel):
         self.loss_D_real = self.criterionGAN(pred_real, True)
         # combine loss and calculate gradients
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
-        self.loss_D.backward()
+        if ParallelEnv().nranks > 1:
+            self.loss_D = self.netD.scale_loss(self.loss_D)
+            self.loss_D.backward()
+            self.netD.apply_collective_grads()
+        else:
+            self.loss_D.backward()
 
     def backward_G(self):
@@ -110,7 +115,12 @@ class Pix2PixModel(BaseModel):
         self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN + self.loss_G_L1
-        # self.loss_G = self.loss_G_L1
-        self.loss_G.backward()
+        if ParallelEnv().nranks > 1:
+            self.loss_G = self.netG.scale_loss(self.loss_G)
+            self.loss_G.backward()
+            self.netG.apply_collective_grads()
+        else:
+            self.loss_G.backward()
 
     def optimize_parameters(self):
...
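`Pix2PixModel` gets the same `nranks`-guarded backward as CycleGAN, so a single code path serves both single- and multi-GPU runs. `ParallelEnv()` reads the rank layout from environment variables set up by the process launcher; in this Paddle version that is typically `python -m paddle.distributed.launch` with one process per GPU (the repo's actual entry-point script is not shown in this diff). A minimal check of what the guard sees:

    from paddle.imperative import ParallelEnv

    # In a plain single-process run nranks is 1 and every guarded branch above
    # falls through to the ordinary backward(); under a multi-process launcher
    # each worker sees nranks equal to the number of processes.
    print(ParallelEnv().nranks, ParallelEnv().local_rank)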
@@ -11,15 +11,13 @@ def save(state_dicts, file_name):
 
     def convert(state_dict):
         model_dict = {}
-        # name_table = {}
         for k, v in state_dict.items():
             if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)):
                 model_dict[k] = v.numpy()
             else:
                 model_dict[k] = v
-                return state_dict
-            # name_table[k] = v.name
-        # model_dict["StructuredToParameterName@@"] = name_table
         return model_dict
 
     final_dict = {}
...
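Two things were wrong in `convert()`: the stray `return state_dict` aborted the loop as soon as a non-tensor value was encountered and handed back the unconverted input, and the dead `name_table` bookkeeping obscured that. After the fix the function returns the `model_dict` it actually builds. A sketch of the fixed behavior, with a generic tensor test standing in for the `Variable`/`VarBase` isinstance check above:

    def convert_sketch(state_dict):
        # Sketch of the fixed convert(): tensor-like values are copied out as
        # numpy arrays, everything else passes through; no early return.
        model_dict = {}
        for k, v in state_dict.items():
            model_dict[k] = v.numpy() if hasattr(v, 'numpy') else v
        return model_dict  # the converted dict, not the original input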