Unverified commit 374e757f, authored by mayang002, committed by GitHub

Xpu ernie3: support fp16 for xpu kernels: full_like/stack/where (#51271)

* [xpu-ernie3] support fp16 for full_like/stack/where xpu kernels

Parent f3448977
@@ -319,6 +319,16 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::INT32,
                    phi::DataType::INT8,
                    phi::DataType::FLOAT32})},
+    {"full_batch_size_like",
+     XPUKernelSet({phi::DataType::INT64,
+                   phi::DataType::INT32,
+                   phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16})},
+    {"fill_constant_batch_size_like",
+     XPUKernelSet({phi::DataType::INT64,
+                   phi::DataType::INT32,
+                   phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16})},
     {"unfold",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"unfold_grad",
@@ -595,6 +605,7 @@ XPUOpMap& get_kl2_ops() {
     {"shape",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT64,
+                   phi::DataType::INT32,
                    phi::DataType::FLOAT16})},
     {"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
     {"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -668,7 +679,8 @@ XPUOpMap& get_kl2_ops() {
     {"stack",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT64,
-                   phi::DataType::INT32})},
+                   phi::DataType::INT32,
+                   phi::DataType::FLOAT16})},
     {"stack_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
     {"strided_slice",
@@ -810,7 +822,8 @@ XPUOpMap& get_kl2_ops() {
     {"where",
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::INT64,
-                   phi::DataType::FLOAT32})},
+                   phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16})},
     {"sin", XPUKernelSet({phi::DataType::FLOAT32})},
     {"cos", XPUKernelSet({phi::DataType::FLOAT32})},
     {"linspace",
@@ -88,13 +88,32 @@ void FullLikeKernel(const Context& dev_ctx,
       phi::errors::InvalidArgument("The filled value is Inf."));
   auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
-  int r = xpu::constant(dev_ctx.x_context(),
-                        out_data,
-                        out->numel(),
-                        static_cast<XPUInTDType>(value));
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+  if (out->numel() > 0) {
+    int r = xpu::constant(dev_ctx.x_context(),
+                          out_data,
+                          out->numel(),
+                          static_cast<XPUInTDType>(value));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+  }
 }
+
+template <typename T, typename Context>
+void FullBatchSizeLikeKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const std::vector<int>& shape,
+                             const Scalar& val,
+                             DataType dtype,
+                             int x_batch_size_dim,
+                             int out_batch_size_dim,
+                             DenseTensor* out) {
+  if (x.lod().size() && x_batch_size_dim == 0) {
+    // set the correct batch size for the LoDTensor.
+    auto odims = out->dims();
+    odims[out_batch_size_dim] = static_cast<int>(x.lod().back().size()) - 1;
+    FullKernel<T, Context>(dev_ctx, phi::vectorize(odims), val, dtype, out);
+  }
+  FullLikeKernel<T, Context>(dev_ctx, x, val, dtype, out);
+}
+
 }  // namespace phi

 PD_REGISTER_KERNEL(full,
@@ -122,3 +141,15 @@ PD_REGISTER_KERNEL(full_like,
                    phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
+
+PD_REGISTER_KERNEL(full_batch_size_like,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::FullBatchSizeLikeKernel,
+                   float,
+                   int,
+                   int64_t,
+                   bool,
+                   phi::dtype::float16) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}
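One detail in the new FullBatchSizeLikeKernel worth spelling out: for a LoDTensor, the batch size is taken from the last LoD level, whose N+1 offsets delimit N sequences, hence the `- 1`. A self-contained sketch of just that computation (the helper name is invented for illustration):

#include <cstddef>
#include <iostream>
#include <vector>

// A LoD level stores sequence offsets, so {0, 2, 5, 9} delimits
// 3 sequences -> batch size 3.
int BatchSizeFromLoD(const std::vector<std::vector<size_t>>& lod) {
  // lod.back() is the finest offset level; N+1 offsets mean N sequences.
  return static_cast<int>(lod.back().size()) - 1;
}

int main() {
  std::vector<std::vector<size_t>> lod = {{0, 2, 5, 9}};
  std::cout << BatchSizeFromLoD(lod) << "\n";  // prints 3
}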
@@ -55,5 +55,11 @@ void StackKernel(const Context& dev_ctx,
 }

 }  // namespace phi
-PD_REGISTER_KERNEL(
-    stack, XPU, ALL_LAYOUT, phi::StackKernel, float, int, int64_t) {}
+PD_REGISTER_KERNEL(stack,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::StackKernel,
+                   float,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
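The stack change is registration-only: PD_REGISTER_KERNEL instantiates the same kernel template once per listed dtype, so when the underlying xdnn op already handles half precision, adding phi::dtype::float16 to the list is the whole change. A rough C++17 sketch of that kind of type-list expansion (hypothetical names, not the real macro):

#include <cstdio>

template <typename T>
void StackKernelFor() {
  // Stand-in for instantiating phi::StackKernel<T> and recording it
  // in the kernel registry.
  std::printf("registered stack for a %zu-byte type\n", sizeof(T));
}

template <typename... Ts>
void RegisterStackKernel() {
  (StackKernelFor<Ts>(), ...);  // one instantiation per dtype in the list
}

int main() {
  // Mirrors the registration above: float, int, int64_t, and a 16-bit
  // stand-in for float16.
  RegisterStackKernel<float, int, long long, unsigned short>();
}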
@@ -25,10 +25,11 @@ void WhereKernel(const Context& ctx,
                  const DenseTensor& x,
                  const DenseTensor& y,
                  DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   const bool* cond_data = condition.data<bool>();
-  const T* x_data = x.data<T>();
-  const T* y_data = y.data<T>();
-  T* out_data = ctx.template Alloc<T>(out);
+  const XPUType* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
+  const XPUType* y_data = reinterpret_cast<const XPUType*>(y.data<T>());
+  XPUType* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));

   auto cond_dims = phi::vectorize<int>(condition.dims());
   auto x_dims = phi::vectorize<int>(x.dims());
@@ -44,10 +45,16 @@ void WhereKernel(const Context& ctx,
   int ret = xpu::select(
       ctx.x_context(), cond_data, x_data, y_data, out_data, cond_dims, x_dims);
-  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "select");
+  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu::select");
 }

 }  // namespace phi
-PD_REGISTER_KERNEL(
-    where, XPU, ALL_LAYOUT, phi::WhereKernel, float, int, int64_t) {}
+PD_REGISTER_KERNEL(where,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::WhereKernel,
+                   float,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
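The WhereKernel edit follows the common phi/XPU pattern for fp16: the kernel is instantiated with T = phi::dtype::float16, but the xdnn API expects its own half type, so XPUTypeTrait maps one onto the other and the data pointers are reinterpret_cast between the two identically laid-out 16-bit types. A minimal sketch with stand-in types (PhiFloat16/XpuFloat16 are invented here; the real trait lives in Paddle's XPU headers):

#include <cstdint>
#include <type_traits>

struct PhiFloat16 { uint16_t bits; };  // stand-in for phi::dtype::float16
struct XpuFloat16 { uint16_t bits; };  // stand-in for the device half type

template <typename T> struct XPUTypeTrait { using Type = T; };
template <> struct XPUTypeTrait<PhiFloat16> { using Type = XpuFloat16; };

template <typename T>
const typename XPUTypeTrait<T>::Type* AsDevicePtr(const T* framework_typed) {
  // Same bit layout, different C++ type: this mirrors the casts in the
  // kernel and relies on both halves sharing one 16-bit representation.
  return reinterpret_cast<const typename XPUTypeTrait<T>::Type*>(
      framework_typed);
}

int main() {
  // For types the device consumes directly, the trait is an identity map.
  static_assert(std::is_same<XPUTypeTrait<float>::Type, float>::value, "");
  PhiFloat16 v{0x3c00};  // 1.0 in IEEE fp16
  const XpuFloat16* p = AsDevicePtr(&v);
  return p->bits == 0x3c00 ? 0 : 1;
}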