未验证 提交 5e4b419b 编写于 作者: L Leo Chen 提交者: GitHub

Copy beta1_pow/beta2_pow to their own place (instead of ctx.GetPlace()) when skip_update=1 (#37245)

* Copy beta1_pow/beta2_pow to their own place (instead of ctx.GetPlace()) when skip_update=1

* Fix the XPU kernel (use beta1_pow.place() / beta2_pow.place())
上级 1e9b3a3d
...@@ -198,11 +198,11 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> { ...@@ -198,11 +198,11 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
*mom2, ctx.GetPlace(), *mom2, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), mom2_out); ctx.template device_context<platform::DeviceContext>(), mom2_out);
framework::TensorCopy( framework::TensorCopy(
*beta1_pow, ctx.GetPlace(), *beta1_pow, beta1_pow->place(),
ctx.template device_context<platform::DeviceContext>(), ctx.template device_context<platform::DeviceContext>(),
beta1_pow_out); beta1_pow_out);
framework::TensorCopy( framework::TensorCopy(
*beta2_pow, ctx.GetPlace(), *beta2_pow, beta2_pow->place(),
ctx.template device_context<platform::DeviceContext>(), ctx.template device_context<platform::DeviceContext>(),
beta2_pow_out); beta2_pow_out);
return; return;
......
...@@ -84,11 +84,11 @@ class AdamNPUKernel : public framework::OpKernel<T> { ...@@ -84,11 +84,11 @@ class AdamNPUKernel : public framework::OpKernel<T> {
*mom2, ctx.GetPlace(), *mom2, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), mom2_out); ctx.template device_context<platform::DeviceContext>(), mom2_out);
framework::TensorCopy( framework::TensorCopy(
*beta1_pow, ctx.GetPlace(), *beta1_pow, beta1_pow->place(),
ctx.template device_context<platform::DeviceContext>(), ctx.template device_context<platform::DeviceContext>(),
beta1_pow_out); beta1_pow_out);
framework::TensorCopy( framework::TensorCopy(
*beta2_pow, ctx.GetPlace(), *beta2_pow, beta2_pow->place(),
ctx.template device_context<platform::DeviceContext>(), ctx.template device_context<platform::DeviceContext>(),
beta2_pow_out); beta2_pow_out);
return; return;
......
...@@ -86,11 +86,11 @@ class AdamOpXPUKernel : public framework::OpKernel<T> { ...@@ -86,11 +86,11 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
mom2, ctx.GetPlace(), mom2, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), &mom2_out); ctx.template device_context<platform::DeviceContext>(), &mom2_out);
framework::TensorCopy( framework::TensorCopy(
beta1_pow, ctx.GetPlace(), beta1_pow, beta1_pow.place(),
ctx.template device_context<platform::DeviceContext>(), ctx.template device_context<platform::DeviceContext>(),
beta1_pow_out); beta1_pow_out);
framework::TensorCopy( framework::TensorCopy(
beta2_pow, ctx.GetPlace(), beta2_pow, beta2_pow.place(),
ctx.template device_context<platform::DeviceContext>(), ctx.template device_context<platform::DeviceContext>(),
beta2_pow_out); beta2_pow_out);
return; return;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册