提交 73f4d8e4 编写于 作者: G gaotingquan 提交者: Tingquan Gao

to avoid causing issues for models that do not set no_weight_decay.

there seems to be a behavioral difference in the optimizer between passing a flat parameter list and passing parameter groups like [{"params": ...}, {"params": ...}].
上级 31ea33c8
...@@ -113,22 +113,26 @@ class Momentum(object): ...@@ -113,22 +113,26 @@ class Momentum(object):
# model_list is None in static graph # model_list is None in static graph
parameters = None parameters = None
if model_list: if model_list:
params_with_decay = [] # TODO(gaotingquan): to avoid cause issues for unset no_weight_decay models
params_without_decay = [] if len(self.no_weight_decay_name_list) > 0:
for m in model_list: params_with_decay = []
for n, p in m.named_parameters(): params_without_decay = []
if any(nd in n for nd in self.no_weight_decay_name_list) \ for m in model_list:
or (self.one_dim_param_no_weight_decay and len(p.shape) == 1): for n, p in m.named_parameters():
params_without_decay.append(p) if any(nd in n for nd in self.no_weight_decay_name_list) \
else: or (self.one_dim_param_no_weight_decay and len(p.shape) == 1):
params_with_decay.append(p) params_without_decay.append(p)
parameters = [{ else:
"params": params_with_decay, params_with_decay.append(p)
"weight_decay": self.weight_decay parameters = [{
}, { "params": params_with_decay,
"params": params_without_decay, "weight_decay": self.weight_decay
"weight_decay": 0.0 }, {
}] "params": params_without_decay,
"weight_decay": 0.0
}]
else:
parameters = sum([m.parameters() for m in model_list], [])
opt = optim.Momentum( opt = optim.Momentum(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
momentum=self.momentum, momentum=self.momentum,
...@@ -279,9 +283,8 @@ class AdamW(object): ...@@ -279,9 +283,8 @@ class AdamW(object):
if self.one_dim_param_no_weight_decay: if self.one_dim_param_no_weight_decay:
self.no_weight_decay_param_name_list += [ self.no_weight_decay_param_name_list += [
p.name p.name for model in model_list
for model in model_list for n, p in model.named_parameters() for n, p in model.named_parameters() if len(p.shape) == 1
if len(p.shape) == 1
] if model_list else [] ] if model_list else []
opt = optim.AdamW( opt = optim.AdamW(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册