未验证 提交 0f578db9 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] refine FillNpuTensorWithConstant (#32682)

上级 69d237c2
...@@ -90,6 +90,9 @@ aclrtStream GetCurrentNPUStream(int device_id = -1); ...@@ -90,6 +90,9 @@ aclrtStream GetCurrentNPUStream(int device_id = -1);
template <typename T> template <typename T>
void FillNpuTensorWithConstant(Tensor *tensor, T val) { void FillNpuTensorWithConstant(Tensor *tensor, T val) {
// NOTE(zhiqiu): we found that power sometimes returns 0 when val is small
// like 1e-8.
constexpr float MIN_PRECISION_FOR_POWER = 1e-3;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
tensor->IsInitialized(), true, tensor->IsInitialized(), true,
platform::errors::InvalidArgument("The tensor should be initialized.")); platform::errors::InvalidArgument("The tensor should be initialized."));
...@@ -97,7 +100,8 @@ void FillNpuTensorWithConstant(Tensor *tensor, T val) { ...@@ -97,7 +100,8 @@ void FillNpuTensorWithConstant(Tensor *tensor, T val) {
platform::is_npu_place(tensor->place()), true, platform::is_npu_place(tensor->place()), true,
platform::errors::InvalidArgument("The tensor should be on NPUPlace.")); platform::errors::InvalidArgument("The tensor should be on NPUPlace."));
// do async for better performance // do async for better performance
if (typeid(float) == typeid(T) || typeid(platform::float16) == typeid(T)) { if ((typeid(float) == typeid(T) || typeid(platform::float16) == typeid(T)) &&
static_cast<float>(val) > MIN_PRECISION_FOR_POWER) {
Tensor tmp(tensor->type()); Tensor tmp(tensor->type());
tmp.Resize(tensor->dims()); tmp.Resize(tensor->dims());
tmp.mutable_data<T>(tensor->place()); tmp.mutable_data<T>(tensor->place());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册