提交 de43c1ee 编写于 作者: G guohongzilong

fix group params order

上级 5359d63e
......@@ -360,16 +360,18 @@ class Optimizer(Cell):
if len(ordered_parameters) != len(self.group_params):
raise ValueError(f"The value of 'order_params' should be same with all group parameters.")
ordered_params = [None] * params_length
ordered_learning_rate = [None] * params_length
ordered_weight_decay = [None] * params_length
params_name = [param.name for param in ordered_parameters]
for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay):
index = params_name.index(param.name)
ordered_params[index] = param
ordered_learning_rate[index] = lr
ordered_weight_decay[index] = wd
self.group_params = list(ordered_parameters)
self.group_params = ordered_params
self.group_lr = ordered_learning_rate
self.group_weight_decay = ordered_weight_decay
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册