Unverified commit 374e757f — authored by mayang002, committed via GitHub.

XPU ernie3: support fp16 for XPU kernels: full_like/stack/where (#51271)

* [xpu-ernie3] support fp16 for the full_like/stack/where XPU kernels

Parent commit: f3448977
......@@ -319,6 +319,16 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::INT8,
phi::DataType::FLOAT32})},
{"full_batch_size_like",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"fill_constant_batch_size_like",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"unfold",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"unfold_grad",
......@@ -595,6 +605,7 @@ XPUOpMap& get_kl2_ops() {
{"shape",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT16})},
{"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
{"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
......@@ -668,7 +679,8 @@ XPUOpMap& get_kl2_ops() {
{"stack",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::INT32})},
phi::DataType::INT32,
phi::DataType::FLOAT16})},
{"stack_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
{"strided_slice",
......@@ -810,7 +822,8 @@ XPUOpMap& get_kl2_ops() {
{"where",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32})},
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"sin", XPUKernelSet({phi::DataType::FLOAT32})},
{"cos", XPUKernelSet({phi::DataType::FLOAT32})},
{"linspace",
......
......@@ -88,13 +88,32 @@ void FullLikeKernel(const Context& dev_ctx,
phi::errors::InvalidArgument("The filled value is Inf."));
auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
int r = xpu::constant(dev_ctx.x_context(),
out_data,
out->numel(),
static_cast<XPUInTDType>(value));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
if (out->numel() > 0) {
int r = xpu::constant(dev_ctx.x_context(),
out_data,
out->numel(),
static_cast<XPUInTDType>(value));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
}
}
// Fills `out` with the scalar `val`. The output shape follows `x`, except
// that when `x` is a LoDTensor the batch dimension is derived from the LoD
// (number of sequences) rather than from x's dense dims.
//
// Parameters:
//   x                  - reference tensor providing shape (and optionally LoD).
//   shape              - target shape from the op attrs (unused here; the
//                        output dims are already set by infermeta).
//   val                - fill value.
//   dtype              - requested output data type.
//   x_batch_size_dim   - which dim of `x` carries the batch size; the LoD
//                        path only applies when this is dim 0.
//   out_batch_size_dim - which dim of `out` receives the batch size.
//   out                - output tensor, resized and filled.
template <typename T, typename Context>
void FullBatchSizeLikeKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const std::vector<int>& shape,
                             const Scalar& val,
                             DataType dtype,
                             int x_batch_size_dim,
                             int out_batch_size_dim,
                             DenseTensor* out) {
  if (x.lod().size() && x_batch_size_dim == 0) {
    // Set the correct batch size for the LoDTensor: the number of sequences
    // is the number of LoD offsets in the last level minus one.
    auto odims = out->dims();
    odims[out_batch_size_dim] = static_cast<int>(x.lod().back().size()) - 1;
    FullKernel<T, Context>(dev_ctx, phi::vectorize(odims), val, dtype, out);
    // The LoD path has already resized and filled `out`; return here so we
    // do not immediately repeat the whole-tensor fill below.
    return;
  }
  FullLikeKernel<T, Context>(dev_ctx, x, val, dtype, out);
}
} // namespace phi
PD_REGISTER_KERNEL(full,
......@@ -122,3 +141,15 @@ PD_REGISTER_KERNEL(full_like,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
// Register the XPU full_batch_size_like kernel (fp16 added for ernie3).
// NOTE(review): `bool` is registered here, but the kl2 op map entry for
// full_batch_size_like lists only INT64/INT32/FLOAT32/FLOAT16 — confirm
// whether BOOL should be added to the op map or dropped from this list.
PD_REGISTER_KERNEL(full_batch_size_like,
XPU,
ALL_LAYOUT,
phi::FullBatchSizeLikeKernel,
float,
int,
int64_t,
bool,
phi::dtype::float16) {
// The reference input only provides shape/LoD, so it may live on any backend.
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -55,5 +55,11 @@ void StackKernel(const Context& dev_ctx,
}
} // namespace phi
PD_REGISTER_KERNEL(
stack, XPU, ALL_LAYOUT, phi::StackKernel, float, int, int64_t) {}
// Register the XPU stack kernel; float16 added for ernie3 fp16 support.
PD_REGISTER_KERNEL(stack,
XPU,
ALL_LAYOUT,
phi::StackKernel,
float,
int,
int64_t,
phi::dtype::float16) {}
......@@ -25,10 +25,11 @@ void WhereKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
const bool* cond_data = condition.data<bool>();
const T* x_data = x.data<T>();
const T* y_data = y.data<T>();
T* out_data = ctx.template Alloc<T>(out);
const XPUType* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
const XPUType* y_data = reinterpret_cast<const XPUType*>(y.data<T>());
XPUType* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
auto cond_dims = phi::vectorize<int>(condition.dims());
auto x_dims = phi::vectorize<int>(x.dims());
......@@ -44,10 +45,16 @@ void WhereKernel(const Context& ctx,
int ret = xpu::select(
ctx.x_context(), cond_data, x_data, y_data, out_data, cond_dims, x_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "select");
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "xpu::select");
}
} // namespace phi
PD_REGISTER_KERNEL(
where, XPU, ALL_LAYOUT, phi::WhereKernel, float, int, int64_t) {}
// Register the XPU where (select) kernel; float16 added for ernie3 fp16
// support.
PD_REGISTER_KERNEL(where,
XPU,
ALL_LAYOUT,
phi::WhereKernel,
float,
int,
int64_t,
phi::dtype::float16) {}
Markdown is supported.
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.