提交 0f8b5087 编写于 作者: Q qjing666
......@@ -691,7 +691,7 @@ class FLDistributeTranspiler(object):
opti_to_param = dict()
param_to_opti = dict()
for op in self.optimize_ops:
if op.type == "sgd":
if (op.type == "sgd") or (op.type == "adam"):
origin_name = op.output("ParamOut")
var = self.origin_program.global_block().var(origin_name[0])
new_var_name = "%s.opti.trainer_%d" % (origin_name[0],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册