Unverified commit 49c8253f authored by zhangbo9674, committed by GitHub

modify adam to adamw in AdamW (#36028)

* adam to adamw in AdamW

* add lr_ratio in adamw

* refine logic bug in cpu adamw

* delete fix bug for cpu adamw

Parent b23b17c0
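The change registers an adamw op alongside adam in the op-function generator maps and switches AdamW's dygraph path from _C_ops.adam to _C_ops.adamw, threading a per-parameter lr_ratio through to the kernel. For orientation, here is a minimal NumPy sketch of the update rule being wired up; this is illustrative math, not Paddle's kernel, and it assumes lr_ratio simply scales the effective learning rate of one parameter:

import numpy as np

def adamw_step(p, g, m, v, t, lr, lr_ratio=1.0,
               beta1=0.9, beta2=0.999, eps=1e-8, coeff=0.01):
    # Illustrative AdamW step with decoupled weight decay (not Paddle code).
    lr_t = lr * lr_ratio                     # per-parameter effective lr
    p = p * (1.0 - lr_t * coeff)             # decoupled weight decay
    m = beta1 * m + (1.0 - beta1) * g        # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * g * g    # second-moment estimate
    m_hat = m / (1.0 - beta1 ** t)           # bias correction
    v_hat = v / (1.0 - beta2 ** t)
    p = p - lr_t * m_hat / (np.sqrt(v_hat) + eps)
    return p, m, v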
@@ -71,6 +71,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
     {"adam",
      {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
       "Beta2Pow", "MasterParam"}},
+    {"adamw",
+     {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
+      "Beta2Pow", "MasterParam"}},
 };
 
 // NOTE(zhiqiu): Like op_ins_map.
@@ -110,6 +113,9 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"adam",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
+    {"adamw",
+     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
+      "MasterParamOut"}},
 };
 
 // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
@@ -129,7 +135,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
     {"adamw",
-     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
+     {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
+      "MasterParamOut"}},
     {"average_accumulates",
      {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
       "out_old_num_accumulates", "out_num_updates"}},
...
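These three maps configure Paddle's auto-generated _C_ops bindings: op_ins_map lists the op's inputs, op_outs_map its outputs, and op_passing_outs_map the outputs the Python caller passes in so the op can update them in place, which with this commit includes MasterParamOut for multi-precision training. Roughly, the generated adamw entry point is then called as below; this is a sketch of the convention that mirrors the adamw.py hunk that follows, with placeholder values:

# Sketch of the generated call convention for adamw (placeholder values).
outs = _C_ops.adamw(
    # inputs, in op_ins_map order:
    param, grad, lr, moment1, moment2, beta1_pow, beta2_pow, master_param,
    # outputs in op_passing_outs_map, passed in and updated in place:
    param, moment1, moment2, beta1_pow, beta2_pow, master_param,
    # trailing flat list of attribute name/value pairs:
    'epsilon', 1e-8, 'beta1', 0.9, 'beta2', 0.999, 'coeff', 0.01)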
@@ -298,14 +298,14 @@ class AdamW(Adam):
         _beta2 = self._beta2 if not isinstance(
             self._beta2, Variable) else self._beta2.numpy().item(0)
-        _, _, _, _, _, _ = _C_ops.adam(
+        _, _, _, _, _, _ = _C_ops.adamw(
             param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
             beta1_pow_acc, beta2_pow_acc, master_weight, param_and_grad[0],
             moment1, moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
             'epsilon', self._epsilon, 'lazy_mode', self._lazy_mode,
             'min_row_size_to_use_multithread', 1000, 'beta1', _beta1,
             'beta2', _beta2, 'coeff', self._coeff, 'multi_precision',
-            find_master)
+            find_master, "lr_ratio", lr_ratio_)
         return None
...
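At the Python API level, the lr_ratio_ value passed as the "lr_ratio" attribute comes from a callable held by the optimizer. A hedged usage sketch, assuming the public lr_ratio argument accepts a callable mapping a parameter to a float scale on its learning rate, as the plumbing above suggests:

import paddle

linear = paddle.nn.Linear(10, 10)
# Assumed lr_ratio callable: down-weight updates to bias parameters.
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=linear.parameters(),
    weight_decay=0.01,
    lr_ratio=lambda p: 0.1 if 'bias' in p.name else 1.0)

loss = linear(paddle.rand([4, 10])).mean()
loss.backward()
opt.step()
opt.clear_grad()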