未验证 提交 605b34ba 编写于 作者: L LielinJiang 提交者: GitHub

fix resume and multiple gpu train bug (#57)

上级 2d17703b
...@@ -55,11 +55,8 @@ class Trainer: ...@@ -55,11 +55,8 @@ class Trainer:
def distributed_data_parallel(self): def distributed_data_parallel(self):
strategy = paddle.distributed.prepare_context() strategy = paddle.distributed.prepare_context()
for name in self.model.model_names: for net_name, net in self.model.nets.items():
if isinstance(name, str): self.model.nets[net_name] = paddle.DataParallel(net, strategy)
net = getattr(self.model, 'net' + name)
setattr(self.model, 'net' + name,
paddle.DataParallel(net, strategy))
def train(self): def train(self):
reader_cost_averager = TimeAverager() reader_cost_averager = TimeAverager()
...@@ -77,9 +74,9 @@ class Trainer: ...@@ -77,9 +74,9 @@ class Trainer:
self.model.set_input(data) self.model.set_input(data)
self.model.optimize_parameters() self.model.optimize_parameters()
batch_cost_averager.record( batch_cost_averager.record(time.time() - step_start_time,
time.time() - step_start_time, num_samples=self.cfg.get(
num_samples=self.cfg.get('batch_size', 1)) 'batch_size', 1))
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.data_time = reader_cost_averager.get_average() self.data_time = reader_cost_averager.get_average()
self.step_time = batch_cost_averager.get_average() self.step_time = batch_cost_averager.get_average()
...@@ -94,8 +91,8 @@ class Trainer: ...@@ -94,8 +91,8 @@ class Trainer:
step_start_time = time.time() step_start_time = time.time()
self.logger.info( self.logger.info('train one epoch time: {}'.format(time.time() -
'train one epoch time: {}'.format(time.time() - start_time)) start_time))
if self.validate_interval > -1 and epoch % self.validate_interval: if self.validate_interval > -1 and epoch % self.validate_interval:
self.validate() self.validate()
self.model.lr_scheduler.step() self.model.lr_scheduler.step()
...@@ -105,8 +102,8 @@ class Trainer: ...@@ -105,8 +102,8 @@ class Trainer:
def validate(self): def validate(self):
if not hasattr(self, 'val_dataloader'): if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader( self.val_dataloader = build_dataloader(self.cfg.dataset.val,
self.cfg.dataset.val, is_train=False) is_train=False)
metric_result = {} metric_result = {}
...@@ -152,8 +149,8 @@ class Trainer: ...@@ -152,8 +149,8 @@ class Trainer:
self.visual('visual_val', visual_results=visual_results) self.visual('visual_val', visual_results=visual_results)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info( self.logger.info('val iter: [%d/%d]' %
'val iter: [%d/%d]' % (i, len(self.val_dataloader))) (i, len(self.val_dataloader)))
for metric_name in metric_result.keys(): for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset) metric_result[metric_name] /= len(self.val_dataloader.dataset)
...@@ -163,8 +160,8 @@ class Trainer: ...@@ -163,8 +160,8 @@ class Trainer:
def test(self): def test(self):
if not hasattr(self, 'test_dataloader'): if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader( self.test_dataloader = build_dataloader(self.cfg.dataset.test,
self.cfg.dataset.test, is_train=False) is_train=False)
# data[0]: img, data[1]: img path index # data[0]: img, data[1]: img path index
# test batch size must be 1 # test batch size must be 1
...@@ -188,8 +185,8 @@ class Trainer: ...@@ -188,8 +185,8 @@ class Trainer:
self.visual('visual_test', visual_results=visual_results) self.visual('visual_test', visual_results=visual_results)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info( self.logger.info('Test iter: [%d/%d]' %
'Test iter: [%d/%d]' % (i, len(self.test_dataloader))) (i, len(self.test_dataloader)))
def print_log(self): def print_log(self):
losses = self.model.get_current_losses() losses = self.model.get_current_losses()
...@@ -277,13 +274,13 @@ class Trainer: ...@@ -277,13 +274,13 @@ class Trainer:
self.start_epoch = state_dicts['epoch'] + 1 self.start_epoch = state_dicts['epoch'] + 1
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
net.set_dict(state_dicts[net_name]) net.set_state_dict(state_dicts[net_name])
for opt_name, opt in self.model.optimizers.items(): for opt_name, opt in self.model.optimizers.items():
opt.set_dict(state_dicts[opt_name]) opt.set_state_dict(state_dicts[opt_name])
def load(self, weight_path): def load(self, weight_path):
state_dicts = load(weight_path) state_dicts = load(weight_path)
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
net.set_dict(state_dicts[net_name]) net.set_state_dict(state_dicts[net_name])
...@@ -6,7 +6,11 @@ import paddle ...@@ -6,7 +6,11 @@ import paddle
def makedirs(dir): def makedirs(dir):
if not os.path.exists(dir): if not os.path.exists(dir):
# avoid error when train with multiple gpus
try:
os.makedirs(dir) os.makedirs(dir)
except:
pass
def save(state_dicts, file_name): def save(state_dicts, file_name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册