Unverified commit 18e0e01d, authored by Leo Guo and committed by GitHub

Modify full kernel for xpu. test=kunlun (#50209)

Parent: 350cd82a
@@ -248,11 +248,8 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::INT16,
                    phi::DataType::UINT8,
                    phi::DataType::BOOL,
-                   phi::DataType::FLOAT64,
                    phi::DataType::FLOAT32,
-                   phi::DataType::FLOAT16,
-                   phi::DataType::COMPLEX64,
-                   phi::DataType::COMPLEX128})},
+                   phi::DataType::FLOAT16})},
     {"flatten2_grad",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
...
@@ -70,3 +70,17 @@ PD_REGISTER_KERNEL(full_sr,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 #endif
+
+#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
+PD_REGISTER_KERNEL(full_sr,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::sr::FullKernel,
+                   float,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   phi::dtype::float16) {}
+#endif
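For readers without the surrounding file open: phi::sr::FullKernel is the device-independent SelectedRows wrapper, so the registration above mainly requires the dense XPU FullKernel (changed below) to exist for each listed dtype. A minimal sketch of what that wrapper presumably looks like, reconstructed for illustration and not part of this diff:

    #include "paddle/phi/core/selected_rows.h"
    #include "paddle/phi/kernels/full_kernel.h"

    // Illustrative reconstruction of the shared SelectedRows wrapper that
    // the registration above binds to XPU: it forwards to the dense
    // FullKernel on the SelectedRows' underlying value tensor, so no
    // XPU-specific code is needed here.
    namespace phi {
    namespace sr {

    template <typename T, typename Context>
    void FullKernel(const Context& dev_ctx,
                    const IntArray& shape,
                    const Scalar& val,
                    DataType dtype,
                    SelectedRows* out) {
      phi::FullKernel<T>(dev_ctx, shape, val, dtype, out->mutable_value());
    }

    }  // namespace sr
    }  // namespace phi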
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/complex.h"
@@ -59,8 +60,19 @@ void FullKernel(const Context& dev_ctx,
                 const Scalar& val,
                 DataType dtype,
                 DenseTensor* out) {
+  using XPUInTDType = typename XPUTypeTrait<T>::Type;
   out->Resize(phi::make_ddim(shape.GetData()));
-  FullValueXPU<T>(dev_ctx, out, val.to<T>());
+  int numel = out->numel();
+  dev_ctx.template Alloc<T>(out);
+  auto value = val.to<double>();
+  auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
+  if (numel > 0) {
+    int r = xpu::constant(dev_ctx.x_context(),
+                          out_data,
+                          out->numel(),
+                          static_cast<XPUInTDType>(value));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+  }
 }

 template <typename T, typename Context>
@@ -103,16 +115,11 @@ void FullLikeKernel(const Context& dev_ctx,
       phi::errors::InvalidArgument("The filled value is Inf."));

   auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
-  int ret = xpu::constant(dev_ctx.x_context(),
-                          out_data,
-                          out->numel(),
-                          static_cast<XPUInTDType>(value));
-  PADDLE_ENFORCE_EQ(
-      ret,
-      XPU_SUCCESS,
-      phi::errors::External("XPU CONSTANT API return wrong value[%d %s].",
-                            ret,
-                            XPUAPIErrorMsg[ret]));
+  int r = xpu::constant(dev_ctx.x_context(),
+                        out_data,
+                        out->numel(),
+                        static_cast<XPUInTDType>(value));
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
 }

 }  // namespace phi
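After this change both kernels funnel through the same shape: allocate, skip empty tensors, issue the XDNN fill, and let PADDLE_ENFORCE_XDNN_SUCCESS turn a non-zero status code into a readable error (replacing the hand-rolled PADDLE_ENFORCE_EQ against XPU_SUCCESS). A minimal sketch of that pattern; the helper name is hypothetical, while xpu::constant and the macro are the calls used in the diff itself:

    #include "paddle/phi/backends/xpu/enforce_xpu.h"
    // xpu::Context and xpu::constant come from the XDNN headers that this
    // file already pulls in via xpu_context.h.

    // Hypothetical helper distilling the pattern adopted above.
    template <typename T>
    void FillConstantChecked(xpu::Context* ctx, T* dst, int64_t numel, T value) {
      if (numel <= 0) return;  // mirrors the `numel > 0` guard in FullKernel
      int r = xpu::constant(ctx, dst, numel, value);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
    }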
@@ -122,24 +129,23 @@ PD_REGISTER_KERNEL(full,
                    ALL_LAYOUT,
                    phi::FullKernel,
                    float,
-                   double,
                    uint8_t,
                    int16_t,
                    int,
                    int64_t,
                    bool,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16,
-                   phi::dtype::complex<float>,
-                   phi::dtype::complex<double>) {}
+                   phi::dtype::float16) {}

 PD_REGISTER_KERNEL(full_like,
                    XPU,
                    ALL_LAYOUT,
                    phi::FullLikeKernel,
                    float,
+                   uint8_t,
+                   int16_t,
                    int,
                    int64_t,
+                   bool,
                    phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
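For context, a hedged sketch of how these registrations are reached from C++. The paddle::experimental API shown here is the public C++ tensor API of this era, but treat the exact calls as illustrative rather than part of the commit; note that after this change a FLOAT64 request would no longer resolve to an XPU `full` kernel, since `double` was dropped from the registration above:

    #include "paddle/phi/api/include/api.h"

    int main() {
      // Dispatches to phi::FullKernel<float, XPUContext> registered above,
      // assuming a PADDLE_WITH_XPU build with a visible XPU device.
      auto t = paddle::experimental::full(
          {2, 3}, 1.5f, phi::DataType::FLOAT32, phi::XPUPlace(0));
      // full_like goes through phi::FullLikeKernel, which now also covers
      // uint8_t, int16_t, and bool on XPU.
      auto u = paddle::experimental::full_like(
          t, 0.0f, phi::DataType::FLOAT32, phi::XPUPlace(0));
      return 0;
    }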