未验证 提交 49773f36 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] Fix bug that epsilon become 0 using power (#32469)

上级 203ac4f3
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/npu_op_runner.h" #include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/optimizers/adam_op.h" #include "paddle/fluid/operators/optimizers/adam_op.h"
...@@ -122,8 +123,9 @@ class AdamNPUKernel : public framework::OpKernel<T> { ...@@ -122,8 +123,9 @@ class AdamNPUKernel : public framework::OpKernel<T> {
FillNpuTensorWithConstant<T>(&beta2_tensor, beta2); FillNpuTensorWithConstant<T>(&beta2_tensor, beta2);
Tensor epsilon_tensor(framework::proto::VarType::FP32); Tensor epsilon_tensor(framework::proto::VarType::FP32);
epsilon_tensor.mutable_data<T>({1}, ctx.GetPlace()); TensorFromVector(std::vector<T>{epsilon},
FillNpuTensorWithConstant<T>(&epsilon_tensor, epsilon); ctx.template device_context<platform::DeviceContext>(),
&epsilon_tensor);
auto stream = auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册