Unverified commit 3b8e36b8, authored by CC, committed by GitHub

add new feature(gradclipnorm) for invdn (#705)

Parent: 45922b0c
@@ -61,7 +61,8 @@ optimizer:
     - generator
   beta1: 0.9
   beta2: 0.99
-  epsilon: 1e-8 #TODO GRADIENT_CLIPPING
+  epsilon: 1e-8
+  clip_grad_norm: 10
 
 log_config:
   interval: 100
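For context: once this config is loaded, the model's `setup_optimizers` (changed below) turns `clip_grad_norm` into Paddle's `ClipGradByNorm` and hands it to the optimizer as `grad_clip`. A minimal standalone sketch of the resulting behavior, using a stand-in `paddle.nn.Linear` network (illustrative, not part of the commit):

    import paddle

    # Stand-in network; in the commit the parameters come from the
    # InvDN generator listed under `net_names`.
    net = paddle.nn.Linear(4, 4)

    # `clip_grad_norm: 10` becomes a ClipGradByNorm instance: any
    # gradient tensor whose L2 norm exceeds 10 is rescaled to norm 10.
    clip = paddle.nn.ClipGradByNorm(clip_norm=10.0)

    optim = paddle.optimizer.Adam(learning_rate=1e-4,
                                  beta1=0.9,
                                  beta2=0.99,
                                  epsilon=1e-8,
                                  grad_clip=clip,
                                  parameters=net.parameters())

    loss = paddle.mean(net(paddle.randn([2, 4])) ** 2)
    loss.backward()
    optim.step()        # gradients are clipped before the update
    optim.clear_grad()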
@@ -21,6 +21,7 @@ from .base_model import BaseModel
 from .generators.builder import build_generator
 from .criterions.builder import build_criterion
 from ppgan.utils.visual import tensor2img
+from ..solver import build_lr_scheduler, build_optimizer
 
 
 @MODELS.register()
@@ -71,6 +72,30 @@ class InvDNModel(BaseModel):
         optims['optim'].step()
         self.losses['loss'] = l_total.numpy()
 
+    def setup_optimizers(self, lr, cfg):
+        if cfg.get('name', None):
+            cfg_ = cfg.copy()
+            net_names = cfg_.pop('net_names')
+            parameters = []
+            for net_name in net_names:
+                parameters += self.nets[net_name].parameters()
+            cfg_['grad_clip'] = nn.ClipGradByNorm(cfg_['clip_grad_norm'])
+            cfg_.pop('clip_grad_norm')
+            self.optimizers['optim'] = build_optimizer(cfg_, lr, parameters)
+        else:
+            for opt_name, opt_cfg in cfg.items():
+                cfg_ = opt_cfg.copy()
+                net_names = cfg_.pop('net_names')
+                parameters = []
+                for net_name in net_names:
+                    parameters += self.nets[net_name].parameters()
+                self.optimizers[opt_name] = build_optimizer(
+                    cfg_, lr, parameters)
+        return self.optimizers
+
     def forward(self):
         pass
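In the single-optimizer branch, `net_names` and `clip_grad_norm` are popped from the copied config before `build_optimizer` is called, so the remaining keys map directly onto optimizer constructor arguments; note that the multi-optimizer `else` branch does not attach a clipper. A hedged usage sketch, where `model` and `lr_scheduler` are assumed stand-ins for a real training run:

    # Illustrative only: cfg mirrors the YAML fragment above.
    cfg = {
        'name': 'Adam',
        'net_names': ['generator'],
        'beta1': 0.9,
        'beta2': 0.99,
        'epsilon': 1e-8,
        'clip_grad_norm': 10,
    }
    optimizers = model.setup_optimizers(lr_scheduler, cfg)
    # optimizers['optim'] now carries grad_clip=ClipGradByNorm(10),
    # so every optimizers['optim'].step() applies the clipping.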