Unverified commit 18e0e01d authored by Leo Guo and committed by GitHub

Modify full kernel for xpu. test=kunlun (#50209)

Parent 350cd82a
@@ -248,11 +248,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT16,
phi::DataType::UINT8,
phi::DataType::BOOL,
phi::DataType::FLOAT64,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::COMPLEX64,
phi::DataType::COMPLEX128})},
phi::DataType::FLOAT16})},
{"flatten2_grad",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
@@ -70,3 +70,17 @@ PD_REGISTER_KERNEL(full_sr,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(full_sr,
XPU,
ALL_LAYOUT,
phi::sr::FullKernel,
float,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16) {}
#endif
@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
@@ -59,8 +60,19 @@ void FullKernel(const Context& dev_ctx,
const Scalar& val,
DataType dtype,
DenseTensor* out) {
using XPUInTDType = typename XPUTypeTrait<T>::Type;
out->Resize(phi::make_ddim(shape.GetData()));
FullValueXPU<T>(dev_ctx, out, val.to<T>());
int numel = out->numel();
dev_ctx.template Alloc<T>(out);
auto value = val.to<double>();
auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
if (numel > 0) {
int r = xpu::constant(dev_ctx.x_context(),
out_data,
out->numel(),
static_cast<XPUInTDType>(value));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
}
}
template <typename T, typename Context>
@@ -103,16 +115,11 @@ void FullLikeKernel(const Context& dev_ctx,
phi::errors::InvalidArgument("The filled value is Inf."));
auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
int ret = xpu::constant(dev_ctx.x_context(),
out_data,
out->numel(),
static_cast<XPUInTDType>(value));
PADDLE_ENFORCE_EQ(
ret,
XPU_SUCCESS,
phi::errors::External("XPU CONSTANT API return wrong value[%d %s].",
ret,
XPUAPIErrorMsg[ret]));
int r = xpu::constant(dev_ctx.x_context(),
out_data,
out->numel(),
static_cast<XPUInTDType>(value));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
}
} // namespace phi
@@ -122,24 +129,23 @@ PD_REGISTER_KERNEL(full,
ALL_LAYOUT,
phi::FullKernel,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::float16) {}
PD_REGISTER_KERNEL(full_like,
XPU,
ALL_LAYOUT,
phi::FullLikeKernel,
float,
uint8_t,
int16_t,
int,
int64_t,
bool,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
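Note (not part of the commit): the sketch below restates the fill pattern the updated XPU FullKernel now follows — allocate the output, widen the scalar once, and fill the buffer with xpu::constant guarded by the XDNN status check. It reuses only calls visible in the diff above (Alloc, XPUTypeTrait, xpu::constant, PADDLE_ENFORCE_XDNN_SUCCESS); the free-standing function name is illustrative, and it assumes the same phi/XDNN headers that full_kernel.cc already includes.

template <typename T, typename Context>
void FillWithConstantSketch(const Context& dev_ctx,
                            const phi::Scalar& val,
                            phi::DenseTensor* out) {
  // Map the phi element type T to the corresponding XDNN device type.
  using XPUInTDType = typename XPUTypeTrait<T>::Type;
  // Allocate the output buffer on the XPU device.
  dev_ctx.template Alloc<T>(out);
  // Widen the scalar once; it is narrowed to the target dtype at the call site.
  auto value = val.to<double>();
  auto* out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
  // Skip the XDNN call for empty tensors, mirroring the numel guard in the diff.
  if (out->numel() > 0) {
    int r = xpu::constant(dev_ctx.x_context(),
                          out_data,
                          out->numel(),
                          static_cast<XPUInTDType>(value));
    // Turn a failing XDNN return code into a Paddle error with the op name.
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
  }
}

Routing the scalar through double before the per-dtype static_cast lets one code path serve every registered dtype (integral, bool, and float16 alike), which is why the diff converts val with val.to<double>() instead of val.to<T>().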