未验证 提交 efa85f8c 编写于 作者: L Leo Chen 提交者: GitHub

fix adam (#32016)

上级 6503ef56
...@@ -68,12 +68,16 @@ class AdamNPUKernel : public framework::OpKernel<T> { ...@@ -68,12 +68,16 @@ class AdamNPUKernel : public framework::OpKernel<T> {
beta1_pow_out->mutable_data<T>(ctx.GetPlace()); beta1_pow_out->mutable_data<T>(ctx.GetPlace());
TensorFromVector(std::vector<float>{beta1}, ctx.device_context(), TensorFromVector(std::vector<float>{beta1}, ctx.device_context(),
beta1_pow_out); beta1_pow_out);
} else {
beta1_pow_out->mutable_data<T>(ctx.GetPlace());
} }
if (beta2_pow->place() == platform::CPUPlace()) { if (beta2_pow->place() == platform::CPUPlace()) {
float beta2 = *beta2_pow->data<float>(); float beta2 = *beta2_pow->data<float>();
beta2_pow_out->mutable_data<T>(ctx.GetPlace()); beta2_pow_out->mutable_data<T>(ctx.GetPlace());
TensorFromVector(std::vector<float>{beta2}, ctx.device_context(), TensorFromVector(std::vector<float>{beta2}, ctx.device_context(),
beta2_pow_out); beta2_pow_out);
} else {
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
} }
T beta1 = static_cast<T>(ctx.Attr<float>("beta1")); T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册