From 0036316ec68d3d1fe055d1edbc3e3aefd4dd87ab Mon Sep 17 00:00:00 2001 From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com> Date: Thu, 9 Feb 2023 11:01:10 +0800 Subject: [PATCH] add logical_and, logical_or and logical_xor for xpu (#50228) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 3 + paddle/phi/kernels/xpu/logical_kernel.cc | 140 +++++++++++++++++++++++ 2 files changed, 143 insertions(+) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index b3635652ff..2d9b06c605 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -360,7 +360,10 @@ XPUOpMap& get_kl2_ops() { {"log_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"log_softmax", XPUKernelSet({phi::DataType::FLOAT32})}, {"log_softmax_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"logical_and", XPUKernelSet({phi::DataType::BOOL})}, {"logical_not", XPUKernelSet({phi::DataType::BOOL})}, + {"logical_or", XPUKernelSet({phi::DataType::BOOL})}, + {"logical_xor", XPUKernelSet({phi::DataType::BOOL})}, {"lookup_table_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"lookup_table_v2", XPUKernelSet({phi::DataType::FLOAT32})}, {"masked_select", diff --git a/paddle/phi/kernels/xpu/logical_kernel.cc b/paddle/phi/kernels/xpu/logical_kernel.cc index e6a0ea242d..57dc8b4387 100644 --- a/paddle/phi/kernels/xpu/logical_kernel.cc +++ b/paddle/phi/kernels/xpu/logical_kernel.cc @@ -28,6 +28,146 @@ void LogicalNotKernel(const Context& ctx, xpu::logical_not(ctx.x_context(), x.data(), out->data(), x.numel()); PADDLE_ENFORCE_XDNN_SUCCESS(r, "logical_not"); } + +template +void LogicalBinaryKernel( + const XPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out, + std::function func, + std::string funcname = "unknown") { + dev_ctx.template Alloc(out); + + int r = xpu::SUCCESS; + const auto* x_data = x.data(); + const auto* y_data = y.data(); + auto* out_data = out->data(); + + if (x.numel() == out->numel() && y.numel() == out->numel()) { + r = func(dev_ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + reinterpret_cast(out_data), + out->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, funcname); + return; + } + + // x or y need to do broadcast + auto x_dims = x.dims(); + auto y_dims = y.dims(); + int max_dim = std::max(x_dims.size(), y_dims.size()); + int axis = std::abs(x_dims.size() - y_dims.size()); + + std::vector x_dims_vec(max_dim, 1); + std::vector y_dims_vec(max_dim, 1); + if (x_dims.size() == max_dim) { + for (int i = 0; i < max_dim; i++) { + x_dims_vec[i] = x_dims[i]; + } + } else { + for (int i = 0; i < x_dims.size(); i++) { + x_dims_vec[i + axis] = x_dims[i]; + } + } + if (y_dims.size() == max_dim) { + for (int i = 0; i < max_dim; i++) { + y_dims_vec[i] = y_dims[i]; + } + } else { + for (int i = 0; i < y_dims.size(); i++) { + y_dims_vec[i + axis] = y_dims[i]; + } + } + if (x_dims_vec.size() == 0) { + x_dims_vec = std::vector({1}); + } + + if (y_dims_vec.size() == 0) { + y_dims_vec = std::vector({1}); + } + + bool is_x_need_broadcast = false; + bool is_y_need_broadcast = false; + auto out_vec = phi::vectorize(out->dims()); + for (int i = 0; i < max_dim; i++) { + if (x_dims_vec[i] != out_vec[i]) { + is_x_need_broadcast = true; + break; + } + } + for (int i = 0; i < max_dim; i++) { + if (y_dims_vec[i] != out_vec[i]) { + is_y_need_broadcast = true; + break; + } + } + + auto xpu_context = dev_ctx.x_context(); + xpu::ctx_guard RAII_GUARD(xpu_context); + if (is_x_need_broadcast) { + T* x_data_broadcast = RAII_GUARD.alloc_l3_or_gm(out->numel()); + r = xpu::broadcast(xpu_context, + reinterpret_cast(x_data), + reinterpret_cast(x_data_broadcast), + x_dims_vec, + out_vec); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast"); + x_data = x_data_broadcast; + } + if (is_y_need_broadcast) { + T* y_data_broadcast = RAII_GUARD.alloc_l3_or_gm(out->numel()); + r = xpu::broadcast(xpu_context, + reinterpret_cast(y_data), + reinterpret_cast(y_data_broadcast), + y_dims_vec, + out_vec); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast"); + y_data = y_data_broadcast; + } + + r = func(xpu_context, + reinterpret_cast(x_data), + reinterpret_cast(y_data), + reinterpret_cast(out_data), + out->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, funcname); +} + +template +void LogicalAndKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + return LogicalBinaryKernel( + dev_ctx, x, y, out, xpu::logical_and, "logical_and"); +} + +template +void LogicalOrKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + return LogicalBinaryKernel( + dev_ctx, x, y, out, xpu::logical_or, "logical_or"); +} + +template +void LogicalXorKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + return LogicalBinaryKernel( + dev_ctx, x, y, out, xpu::logical_xor, "logical_xor"); +} } // namespace phi PD_REGISTER_KERNEL(logical_not, XPU, ALL_LAYOUT, phi::LogicalNotKernel, bool) {} +PD_REGISTER_KERNEL(logical_and, XPU, ALL_LAYOUT, phi::LogicalAndKernel, bool) {} +PD_REGISTER_KERNEL(logical_or, XPU, ALL_LAYOUT, phi::LogicalOrKernel, bool) {} +PD_REGISTER_KERNEL(logical_xor, XPU, ALL_LAYOUT, phi::LogicalXorKernel, bool) {} -- GitLab