未验证 提交 6692dc9a 编写于 作者: Z zhangyikun02 提交者: GitHub

TensorSetConstantXPU: support using xpu::constant when T is float/float16 (#55122)

* TensorSetConstantXPU: support using xpu::constant when T is float/float16

* add xpu_wait for TensorSetConstantXPU
上级 70183c4b
......@@ -21,6 +21,11 @@ limitations under the License. */
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_header.h"
#endif
namespace phi {
namespace funcs {
......@@ -112,7 +117,8 @@ struct TensorSetConstantXPU {
: tensor_(tensor), value_(value), place_(place) {}
template <typename T>
void apply() const {
auto* begin = tensor_->mutable_data<T>(place_);
auto* ctx = phi::DeviceContextPool::Instance().Get(place_);
auto begin = ctx->Alloc<T>(tensor_);
int numel = tensor_->numel();
std::unique_ptr<T[]> data_cpu(new T[numel]);
std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast<T>(value_));
......@@ -126,6 +132,41 @@ struct TensorSetConstantXPU {
U value_;
phi::Place place_;
};
// Specialization of TensorSetConstantXPU for a float fill value.
//
// apply<T>() allocates the tensor's storage through the device context
// registered for place_ and sets every element to value_ converted to the
// element type. Two paths exist:
//   * T is float or phi::dtype::float16 and place_ compares equal to
//     phi::XPUPlace(): the fill is done on the device with xpu::constant,
//     followed by a Wait() on the device context.
//   * anything else: a host buffer is filled with the value and copied to
//     place_ with memory_utils::Copy.
template <>
struct TensorSetConstantXPU<float> {
  // tensor: destination tensor (not owned); value: fill value;
  // place: where the tensor's storage is allocated and written.
  TensorSetConstantXPU(phi::DenseTensor* tensor, float value, phi::Place place)
      : tensor_(tensor), value_(value), place_(place) {}

  // Fills tensor_ with value_ interpreted as element type T.
  template <typename T>
  void apply() const {
    auto* ctx = phi::DeviceContextPool::Instance().Get(place_);
    auto begin = ctx->Alloc<T>(tensor_);
    // Keep the element count at the width numel() returns instead of
    // narrowing it to int: a narrowed count would be truncated for tensors
    // holding more than INT_MAX elements, corrupting both the device fill
    // length and the host staging buffer size below.
    const auto numel = tensor_->numel();
    if (((std::is_same<T, float>::value) ||
         (std::is_same<T, phi::dtype::float16>::value)) &&
        (place_ == phi::XPUPlace())) {
      // Device-side fill: no host staging buffer and no host-to-device copy.
      using XPUType = typename XPUTypeTrait<T>::Type;
      auto* dev_ctx = static_cast<phi::XPUContext*>(ctx);
      int r = xpu::constant<XPUType>(dev_ctx->x_context(),
                                     reinterpret_cast<XPUType*>(begin),
                                     numel,
                                     static_cast<XPUType>(value_));
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
      // Block until the device context has finished before returning.
      dev_ctx->Wait();
    } else {
      // Fallback: build the constant buffer on the host, then copy it to
      // the destination place.
      std::unique_ptr<T[]> data_cpu(new T[numel]);
      std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast<T>(value_));
      memory_utils::Copy(place_,
                         begin,
                         phi::CPUPlace(),
                         static_cast<void*>(data_cpu.get()),
                         numel * sizeof(T));
    }
  }

  phi::DenseTensor* tensor_;  // destination tensor, not owned
  float value_;               // fill value
  phi::Place place_;          // target place for allocation and write
};
#endif
template <typename Context, typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册