Unverified · Commit 77703f37, authored by Double_V, committed by GitHub

Merge pull request #6103 from LDOUBLEV/dygraph

fix det cml + pact + distribute training bug
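Each optimizer wrapper in the diff below now hands the underlying Paddle optimizer only trainable parameters: model.parameters() is filtered by param.trainable, so parameters frozen elsewhere (for example under the CML distillation + PACT quantization setup named in the commit title) are no longer registered with the optimizer. A minimal sketch of the filtering idea, using a hypothetical standalone Paddle model with one frozen parameter (not part of this PR):

    # Minimal sketch, not part of the PR: shows what the trainable filter excludes.
    import paddle
    from paddle import nn, optimizer as optim

    model = nn.Linear(4, 2)
    # Freeze one parameter, as a distillation/quantization setup might do.
    model.bias.trainable = False

    train_params = [p for p in model.parameters() if p.trainable is True]
    opt = optim.Momentum(
        learning_rate=0.01, momentum=0.9, parameters=train_params)
    print(len(train_params))  # 1 -- only the weight is still handed to the optimizer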
@@ -43,12 +43,15 @@ class Momentum(object):
         self.grad_clip = grad_clip

     def __call__(self, model):
+        train_params = [
+            param for param in model.parameters() if param.trainable is True
+        ]
         opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
-            parameters=model.parameters())
+            parameters=train_params)
         return opt
@@ -76,6 +79,9 @@ class Adam(object):
         self.lazy_mode = lazy_mode

     def __call__(self, model):
+        train_params = [
+            param for param in model.parameters() if param.trainable is True
+        ]
         opt = optim.Adam(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -85,7 +91,7 @@ class Adam(object):
             grad_clip=self.grad_clip,
             name=self.name,
             lazy_mode=self.lazy_mode,
-            parameters=model.parameters())
+            parameters=train_params)
         return opt
@@ -118,6 +124,9 @@ class RMSProp(object):
         self.grad_clip = grad_clip

     def __call__(self, model):
+        train_params = [
+            param for param in model.parameters() if param.trainable is True
+        ]
         opt = optim.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -125,7 +134,7 @@ class RMSProp(object):
             epsilon=self.epsilon,
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
-            parameters=model.parameters())
+            parameters=train_params)
         return opt
@@ -149,6 +158,9 @@ class Adadelta(object):
         self.name = name

     def __call__(self, model):
+        train_params = [
+            param for param in model.parameters() if param.trainable is True
+        ]
         opt = optim.Adadelta(
             learning_rate=self.learning_rate,
             epsilon=self.epsilon,
@@ -156,7 +168,7 @@ class Adadelta(object):
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             name=self.name,
-            parameters=model.parameters())
+            parameters=train_params)
         return opt
@@ -190,17 +202,20 @@ class AdamW(object):
         self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay

     def __call__(self, model):
-        parameters = model.parameters()
+        parameters = [
+            param for param in model.parameters() if param.trainable is True
+        ]

         self.no_weight_decay_param_name_list = [
-            p.name for n, p in model.named_parameters() if any(nd in n for nd in self.no_weight_decay_name_list)
+            p.name for n, p in model.named_parameters()
+            if any(nd in n for nd in self.no_weight_decay_name_list)
         ]

         if self.one_dim_param_no_weight_decay:
             self.no_weight_decay_param_name_list += [
                 p.name for n, p in model.named_parameters() if len(p.shape) == 1
             ]

         opt = optim.AdamW(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -216,4 +231,4 @@ class AdamW(object):
         return opt

     def _apply_decay_param_fun(self, name):
-        return name not in self.no_weight_decay_param_name_list
\ No newline at end of file
+        return name not in self.no_weight_decay_param_name_list