未验证 提交 94adf2c4 编写于 作者: L LielinJiang 提交者: GitHub

add option 'find unused paramters' for trainer (#311)

上级 8dc86476
epochs: 200 epochs: 200
output_dir: output_dir output_dir: output_dir
find_unused_parameters: True
model: model:
name: CycleGANModel name: CycleGANModel
......
epochs: 200 epochs: 200
output_dir: output_dir output_dir: output_dir
find_unused_parameters: True
model: model:
name: CycleGANModel name: CycleGANModel
......
total_iters: 250000 total_iters: 250000
output_dir: output_dir output_dir: output_dir
find_unused_parameters: True
# tensor range for function tensor2img # tensor range for function tensor2img
min_max: min_max:
(0., 1.) (0., 1.)
......
epochs: 100 epochs: 100
output_dir: tmp output_dir: tmp
checkpoints_dir: checkpoints checkpoints_dir: checkpoints
find_unused_parameters: True
model: model:
name: MakeupModel name: MakeupModel
......
total_iters: 60000 total_iters: 60000
output_dir: output_dir output_dir: output_dir
find_unused_parameters: True
# tensor range for function tensor2img # tensor range for function tensor2img
min_max: min_max:
(0., 1.) (0., 1.)
......
total_iters: 60000 total_iters: 60000
output_dir: output_dir output_dir: output_dir
find_unused_parameters: True
# tensor range for function tensor2img # tensor range for function tensor2img
min_max: min_max:
(0., 1.) (0., 1.)
......
...@@ -146,8 +146,10 @@ class Trainer: ...@@ -146,8 +146,10 @@ class Trainer:
def distributed_data_parallel(self): def distributed_data_parallel(self):
paddle.distributed.init_parallel_env() paddle.distributed.init_parallel_env()
find_unused_parameters = self.cfg.get('find_unused_parameters', False)
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
self.model.nets[net_name] = paddle.DataParallel(net) self.model.nets[net_name] = paddle.DataParallel(
net, find_unused_parameters=find_unused_parameters)
def learning_rate_scheduler_step(self): def learning_rate_scheduler_step(self):
if isinstance(self.model.lr_scheduler, dict): if isinstance(self.model.lr_scheduler, dict):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册