未验证 提交 4ae76d21 编写于 作者: A Aurelius84 提交者: GitHub

[Op]Fix adam/adamw beta1_pow/beta2_pow place while copying (#41732)

上级 de2a3942
......@@ -190,8 +190,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out);
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
return;
}
......
......@@ -139,8 +139,8 @@ void AdamDenseParamSparseGradKernel(
phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out);
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
return;
}
......
......@@ -156,8 +156,8 @@ void AdamwDenseParamSparseGradKernel(
phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out);
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
return;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册