From 460e4fc6e82221be4639e061d4e11bd9332a5f21 Mon Sep 17 00:00:00 2001 From: hong19860320 <9973393+hong19860320@users.noreply.github.com> Date: Fri, 11 Aug 2023 10:13:25 +0800 Subject: [PATCH] [XPU] Add fast_gather_nd plugin (#56103) --- paddle/phi/kernels/xpu/gather_nd_kernel.cc | 21 ++ .../kernels/xpu/plugin/include/xpu/plugin.h | 24 ++ .../src/kernel/kunlun2cpp/fast_gather_nd.xpu | 259 ++++++++++++++++ .../xpu/plugin/src/wrapper/fast_gather_nd.cpp | 281 ++++++++++++++++++ 4 files changed, 585 insertions(+) create mode 100644 paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_gather_nd.xpu create mode 100644 paddle/phi/kernels/xpu/plugin/src/wrapper/fast_gather_nd.cpp diff --git a/paddle/phi/kernels/xpu/gather_nd_kernel.cc b/paddle/phi/kernels/xpu/gather_nd_kernel.cc index 9966d3795d5..43581963987 100644 --- a/paddle/phi/kernels/xpu/gather_nd_kernel.cc +++ b/paddle/phi/kernels/xpu/gather_nd_kernel.cc @@ -87,6 +87,7 @@ void GatherNdKernel(const Context &ctx, x_shape.data(), static_cast(x_shape.size()), nullptr}; int ret = XPU_SUCCESS; +#ifndef PADDLE_WITH_XPU_PLUGIN if (index_type == DataType::INT32) { ret = xpu::gather_nd( ctx.x_context(), @@ -105,6 +106,26 @@ void GatherNdKernel(const Context &ctx, index_shape); } PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather_nd"); +#else + if (index_type == DataType::INT32) { + ret = xpu::plugin::fast_gather_nd( + ctx.x_context(), + reinterpret_cast(x.data()), + index.data(), + reinterpret_cast(out->data()), + x_vec, + index_shape); + } else { + ret = xpu::plugin::fast_gather_nd( + ctx.x_context(), + reinterpret_cast(x.data()), + index.data(), + reinterpret_cast(out->data()), + x_vec, + index_shape); + } + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "fast_gather_nd"); +#endif } } // namespace phi diff --git a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h index 712c02977de..eb7588252a6 100644 --- a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h +++ b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h @@ -31,6 +31,30 @@ DLL_EXPORT int fast_where(Context* ctx, const T* y, T* out, int64_t len); +template +DLL_EXPORT int fast_gather_nd(Context* ctx, + const T* x, + const TID* index, + T* y, + const VectorParam& xshape, + const std::vector& index_shape); +template +static inline int fast_gather_nd(Context* ctx, + const T* x, + const TID* index, + T* y, + const VectorParam& xshape, + const std::vector& index_shape) { + auto deleter = [](int64_t* ptr) { delete[] ptr; }; + std::shared_ptr xshape_i64(new int64_t[xshape.len], deleter); + return fast_gather_nd( + ctx, + x, + index, + y, + vpi32_to_vpi64(xshape, xshape_i64.get()), + std::vector(index_shape.begin(), index_shape.end())); +} } // namespace plugin } // namespace api diff --git a/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_gather_nd.xpu b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_gather_nd.xpu new file mode 100644 index 00000000000..09c69561d73 --- /dev/null +++ b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_gather_nd.xpu @@ -0,0 +1,259 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
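Editor's note on the host-side hunks above: with PADDLE_WITH_XPU_PLUGIN defined, GatherNdKernel calls xpu::plugin::fast_gather_nd with either int or int64_t indices instead of the stock xpu::gather_nd, and the inline overload added to plugin.h widens the int-typed VectorParam describing x's shape to int64_t before forwarding to the exported entry point. Below is a minimal standalone sketch of that widening step; the ShapeParam32/ShapeParam64 structs and WidenShape name are illustrative stand-ins, not the XDNN VectorParam API.

#include <cstdint>
#include <memory>
#include <vector>

// Illustrative "pointer + length" shape descriptors standing in for the XDNN
// VectorParam<int> / VectorParam<int64_t> types used by the real overload.
struct ShapeParam32 { const int* cpu; int64_t len; };
struct ShapeParam64 { const int64_t* cpu; int64_t len; };

// Copy the int32 dims into caller-provided int64 storage and return a 64-bit
// descriptor, mirroring what the inline overload does before dispatching.
ShapeParam64 WidenShape(const ShapeParam32& in, int64_t* storage) {
  for (int64_t i = 0; i < in.len; ++i) {
    storage[i] = static_cast<int64_t>(in.cpu[i]);
  }
  return ShapeParam64{storage, in.len};
}

int main() {
  std::vector<int> dims = {2, 3, 4};
  ShapeParam32 shape32{dims.data(), static_cast<int64_t>(dims.size())};
  std::unique_ptr<int64_t[]> storage(new int64_t[dims.size()]);
  ShapeParam64 shape64 = WidenShape(shape32, storage.get());
  return shape64.len == 3 ? 0 : 1;  // the widened descriptor is what gets forwarded
}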
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu2 { +namespace plugin { + +template +__global__ void fast_gather1d(const int8_t* x, + const TID* index, + int64_t count, + int64_t x_dim0, + int64_t x_stride0, + int8_t* y) { + int cid = core_id(); + int tid = core_id() * cluster_num() + cluster_id(); + int nthreads = core_num() * cluster_num(); + const int index_len = 320 / sizeof(TID); + __simd__ TID local_index[index_len]; + const int buf_len = 5824 / sizeof(int8_t); + __simd__ int8_t local_x[buf_len]; + if (x_stride0 > buf_len) { + for (int64_t i = tid; i < count; i += nthreads) { + GM2LM(index + i, local_index, sizeof(TID)); + int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0; + for (int64_t j = 0; j < x_stride0; j += buf_len) { + int read_len = min(static_cast(x_stride0), x_stride0 - j); + GM2LM(x + offset + j, local_x, read_len); + LM2GM(local_x, y + i * x_stride0 + j, read_len); + } + } + } else { + int64_t count_per_thread = min(index_len, buf_len / x_stride0); + for (int64_t i = tid * count_per_thread; i < count; + i += nthreads * count_per_thread) { + int count_in_thread = + min(static_cast(count_per_thread), count - i); + GM2LM(index + i, local_index, count_in_thread * sizeof(TID)); + for (int64_t j = 0; j < count_in_thread; j++) { + int64_t offset = ((local_index[j] + x_dim0) % x_dim0) * x_stride0; + GM2LM_ASYNC(x + offset, local_x + j * x_stride0, x_stride0); + } + mfence_lm(); + LM2GM(local_x, y + i * x_stride0, x_stride0 * count_in_thread); + } + } +} + +template +__global__ void fast_gather2d(const int8_t* x, + const TID* index, + int64_t count, + int64_t x_dim0, + int64_t x_dim1, + int64_t x_stride0, + int64_t x_stride1, + int8_t* y) { + int cid = core_id(); + int tid = core_id() * cluster_num() + cluster_id(); + int nthreads = core_num() * cluster_num(); + const int index_len = 640 / sizeof(TID); + __simd__ TID local_index[index_len]; + const int buf_len = 5504 / sizeof(int8_t); + __simd__ int8_t local_x[buf_len]; + if (x_stride1 > buf_len) { + for (int64_t i = tid; i < count; i += nthreads) { + GM2LM(index + i * 2, local_index, 2 * sizeof(TID)); + int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0 + + ((local_index[1] + x_dim1) % x_dim1) * x_stride1; + for (int64_t j = 0; j < x_stride1; j += buf_len) { + int read_len = min(static_cast(x_stride1), x_stride1 - j); + GM2LM(x + offset + j, local_x, read_len); + LM2GM(local_x, y + i * x_stride1 + j, read_len); + } + } + } else { + int64_t count_per_thread = min(index_len / 2, buf_len / x_stride1); + for (int64_t i = tid * count_per_thread; i < count; + i += nthreads * count_per_thread) { + int count_in_thread = + min(static_cast(count_per_thread), count - i); + GM2LM(index + i * 2, local_index, 2 * count_in_thread * sizeof(TID)); + for (int64_t j = 0; j < count_in_thread; j++) { + int64_t offset = + ((local_index[j * 2] + x_dim0) % x_dim0) * x_stride0 + + ((local_index[j * 2 + 1] + x_dim1) % x_dim1) * x_stride1; + GM2LM_ASYNC(x + offset, local_x + 
j * x_stride1, x_stride1); + } + mfence_lm(); + LM2GM(local_x, y + i * x_stride1, x_stride1 * count_in_thread); + } + } +} + +template +__global__ void fast_gather3d(const int8_t* x, + const TID* index, + int64_t count, + int64_t x_dim0, + int64_t x_dim1, + int64_t x_dim2, + int64_t x_stride0, + int64_t x_stride1, + int64_t x_stride2, + int8_t* y) { + int cid = core_id(); + int tid = core_id() * cluster_num() + cluster_id(); + int nthreads = core_num() * cluster_num(); + const int index_len = 960 / sizeof(TID); + __simd__ TID local_index[index_len]; + const int buf_len = 5184 / sizeof(int8_t); + __simd__ int8_t local_x[buf_len]; + if (x_stride2 > buf_len) { + for (int64_t i = tid; i < count; i += nthreads) { + GM2LM(index + i * 3, local_index, 3 * sizeof(TID)); + int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0 + + ((local_index[1] + x_dim1) % x_dim1) * x_stride1 + + ((local_index[2] + x_dim2) % x_dim2) * x_stride2; + for (int64_t j = 0; j < x_stride2; j += buf_len) { + int read_len = min(static_cast(x_stride2), x_stride2 - j); + GM2LM(x + offset + j, local_x, read_len); + LM2GM(local_x, y + i * x_stride2 + j, read_len); + } + } + } else { + int64_t count_per_thread = min(index_len / 3, buf_len / x_stride2); + for (int64_t i = tid * count_per_thread; i < count; + i += nthreads * count_per_thread) { + int count_in_thread = + min(static_cast(count_per_thread), count - i); + GM2LM(index + i * 3, local_index, 3 * count_in_thread * sizeof(TID)); + for (int64_t j = 0; j < count_in_thread; j++) { + int64_t offset = + ((local_index[j * 3] + x_dim0) % x_dim0) * x_stride0 + + ((local_index[j * 3 + 1] + x_dim1) % x_dim1) * x_stride1 + + ((local_index[j * 3 + 2] + x_dim2) % x_dim2) * x_stride2; + GM2LM_ASYNC(x + offset, local_x + j * x_stride2, x_stride2); + } + mfence_lm(); + LM2GM(local_x, y + i * x_stride2, x_stride2 * count_in_thread); + } + } +} + +template +__global__ void fast_gather4d(const int8_t* x, + const TID* index, + int64_t count, + int64_t x_dim0, + int64_t x_dim1, + int64_t x_dim2, + int64_t x_dim3, + int64_t x_stride0, + int64_t x_stride1, + int64_t x_stride2, + int64_t x_stride3, + int8_t* y) { + int cid = core_id(); + int tid = core_id() * cluster_num() + cluster_id(); + int nthreads = core_num() * cluster_num(); + const int index_len = 1280 / sizeof(TID); + __simd__ TID local_index[index_len]; + const int buf_len = 4864 / sizeof(int8_t); + __simd__ int8_t local_x[buf_len]; + if (x_stride3 > buf_len) { + for (int64_t i = tid; i < count; i += nthreads) { + GM2LM(index + i * 4, local_index, 4 * sizeof(TID)); + int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0 + + ((local_index[1] + x_dim1) % x_dim1) * x_stride1 + + ((local_index[2] + x_dim2) % x_dim2) * x_stride2 + + ((local_index[3] + x_dim3) % x_dim3) * x_stride3; + for (int64_t j = 0; j < x_stride3; j += buf_len) { + int read_len = min(static_cast(x_stride3), x_stride3 - j); + GM2LM(x + offset + j, local_x, read_len); + LM2GM(local_x, y + i * x_stride3 + j, read_len); + } + } + } else { + int64_t count_per_thread = min(index_len / 4, buf_len / x_stride3); + for (int64_t i = tid * count_per_thread; i < count; + i += nthreads * count_per_thread) { + int count_in_thread = + min(static_cast(count_per_thread), count - i); + GM2LM(index + i * 4, local_index, 4 * count_in_thread * sizeof(TID)); + for (int64_t j = 0; j < count_in_thread; j++) { + int64_t offset = + ((local_index[j * 4] + x_dim0) % x_dim0) * x_stride0 + + ((local_index[j * 4 + 1] + x_dim1) % x_dim1) * x_stride1 + + ((local_index[j * 4 + 2] 
+ x_dim2) % x_dim2) * x_stride2 + + ((local_index[j * 4 + 3] + x_dim3) % x_dim3) * x_stride3; + GM2LM_ASYNC(x + offset, local_x + j * x_stride3, x_stride3); + } + mfence_lm(); + LM2GM(local_x, y + i * x_stride3, x_stride3 * count_in_thread); + } + } +} + +#define _XPU_DEF__FAST_GATHERND_(IDTYPE) \ + template __global__ void fast_gather1d(const int8_t* x, \ + const IDTYPE* index, \ + int64_t count, \ + int64_t x_dim0, \ + int64_t x_stride0, \ + int8_t* y); \ + template __global__ void fast_gather2d(const int8_t* x, \ + const IDTYPE* index, \ + int64_t count, \ + int64_t x_dim0, \ + int64_t x_dim1, \ + int64_t x_stride0, \ + int64_t x_stride1, \ + int8_t* y); \ + template __global__ void fast_gather3d(const int8_t* x, \ + const IDTYPE* index, \ + int64_t count, \ + int64_t x_dim0, \ + int64_t x_dim1, \ + int64_t x_dim2, \ + int64_t x_stride0, \ + int64_t x_stride1, \ + int64_t x_stride2, \ + int8_t* y); \ + template __global__ void fast_gather4d(const int8_t* x, \ + const IDTYPE* index, \ + int64_t count, \ + int64_t x_dim0, \ + int64_t x_dim1, \ + int64_t x_dim2, \ + int64_t x_dim3, \ + int64_t x_stride0, \ + int64_t x_stride1, \ + int64_t x_stride2, \ + int64_t x_stride3, \ + int8_t* y); +_XPU_DEF__FAST_GATHERND_(int); +_XPU_DEF__FAST_GATHERND_(int8_t); +_XPU_DEF__FAST_GATHERND_(int64_t); +_XPU_DEF__FAST_GATHERND_(bool); + +} // namespace plugin +} // namespace xpu2 diff --git a/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_gather_nd.cpp b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_gather_nd.cpp new file mode 100644 index 00000000000..24215092768 --- /dev/null +++ b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_gather_nd.cpp @@ -0,0 +1,281 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
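Editor's note on the device kernels defined above: fast_gather1d through fast_gather4d share one addressing scheme. For every gathered slice a thread computes a byte offset of the form sum_k ((index_k + dim_k) % dim_k) * stride_k, where the strides are supplied by the host wrapper in bytes and the innermost stride equals the size of one gathered slice; the (+ dim) % dim term wraps negative indices. When a slice fits in local memory the kernels batch several indices per thread, otherwise they stream the slice through the local buffer in buf_len-sized pieces. The sketch below reproduces only the offset arithmetic on the host; the function name and toy shapes are illustrative assumptions, not part of the plugin.

#include <cstdint>
#include <vector>

// Host-side illustration of the offset arithmetic used by fast_gather{1..4}d:
// each gathered slice starts at sum_k ((idx[k] + dim[k]) % dim[k]) * stride[k],
// where stride[k] is measured in bytes and the last stride is the slice size.
int64_t SliceByteOffset(const std::vector<int64_t>& idx,
                        const std::vector<int64_t>& dims,
                        const std::vector<int64_t>& byte_strides) {
  int64_t offset = 0;
  for (size_t k = 0; k < idx.size(); ++k) {
    offset += ((idx[k] + dims[k]) % dims[k]) * byte_strides[k];
  }
  return offset;
}

int main() {
  // x has shape [4, 5, 6] with float (4-byte) elements; gather along the first
  // two dims, so each slice is 6 floats = 24 bytes.
  std::vector<int64_t> dims = {4, 5};
  std::vector<int64_t> byte_strides = {5 * 6 * 4, 6 * 4};
  // index {-1, 2} wraps to {3, 2}: offset = 3*120 + 2*24 = 408 bytes.
  return SliceByteOffset({-1, 2}, dims, byte_strides) == 408 ? 0 : 1;
}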
+/* + * copyright (C) 2022 KUNLUNXIN, Inc + */ + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu2 { +namespace plugin { +template +__attribute__((global)) void fast_gather1d(const int8_t* x, + const TID* index, + int64_t count, + int64_t x_dim0, + int64_t x_stride0, + int8_t* y); +template +__attribute__((global)) void fast_gather2d(const int8_t* x, + const TID* index, + int64_t count, + int64_t x_dim0, + int64_t x_dim1, + int64_t x_stride0, + int64_t x_stride1, + int8_t* y); +template +__attribute__((global)) void fast_gather3d(const int8_t* x, + const TID* index, + int64_t count, + int64_t x_dim0, + int64_t x_dim1, + int64_t x_dim2, + int64_t x_stride0, + int64_t x_stride1, + int64_t x_stride2, + int8_t* y); +template +__attribute__((global)) void fast_gather4d(const int8_t* x, + const TID* index, + int64_t count, + int64_t x_dim0, + int64_t x_dim1, + int64_t x_dim2, + int64_t x_dim3, + int64_t x_stride0, + int64_t x_stride1, + int64_t x_stride2, + int64_t x_stride3, + int8_t* y); +} // namespace plugin +} // namespace xpu2 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +template +static int cpu_wrapper(Context* ctx, + const T* x, + const TID* index, + T* y, + const VectorParam& x_shape, + const std::vector& index_shape) { + int64_t x_shape_size = x_shape.len; + int64_t index_shape_size = index_shape.size(); + int64_t gather_time = 1; + for (int64_t i = 0; i < index_shape_size - 1; i++) { + gather_time *= index_shape[i]; + } + int64_t end_size = index_shape.back(); + int64_t gather_size = 1; + for (int64_t i = end_size; i < x_shape_size; i++) { + gather_size *= x_shape.cpu[i]; + } + const int64_t gather_bytes = gather_size * sizeof(T); + for (int64_t i = 0; i < gather_time; i++) { + int64_t x_index = 0; + int64_t step = 1; + for (int64_t j = end_size - 1; j >= 0; j--) { + x_index += (index[i * end_size + j] * step); + step *= x_shape.cpu[j]; + } + memcpy(y, x + x_index * gather_size, gather_bytes); + y += gather_size; + } + return api::SUCCESS; +} + +template +static int xpu2_wrapper(Context* ctx, + const T* x, + const TID* index, + T* y, + const VectorParam& x_shape, + const std::vector& index_shape) { + using XPU_TID = typename XPUIndexType::type; + int64_t x_shape_size = x_shape.len; + int64_t index_shape_size = index_shape.size(); + int64_t end_size = index_shape.back(); + int64_t gather_time = 1; + for (int64_t i = 0; i < index_shape_size - 1; i++) { + gather_time *= index_shape[i]; + } + std::vector gather_strides(end_size); + gather_strides[end_size - 1] = sizeof(T); + for (int64_t i = end_size; i < x_shape_size; i++) { + gather_strides[end_size - 1] *= x_shape.cpu[i]; + } + for (int64_t i = end_size - 2; i >= 0; i--) { + gather_strides[i] = gather_strides[i + 1] * x_shape.cpu[i + 1]; + } + auto casted_x = static_cast(static_cast(x)); + auto casted_index = + static_cast(static_cast(index)); + auto casted_y = static_cast(static_cast(y)); + switch (end_size) { + case 1: + xpu2::plugin::fast_gather1d + <<ncluster(), 64, ctx->xpu_stream>>>(casted_x, + casted_index, + gather_time, + x_shape.cpu[0], + gather_strides[0], + casted_y); + return api::SUCCESS; + case 2: + xpu2::plugin::fast_gather2d + <<ncluster(), 64, ctx->xpu_stream>>>(casted_x, + casted_index, + gather_time, + x_shape.cpu[0], + x_shape.cpu[1], + gather_strides[0], + gather_strides[1], + casted_y); + return api::SUCCESS; + case 3: + xpu2::plugin::fast_gather3d + <<ncluster(), 64, ctx->xpu_stream>>>(casted_x, + casted_index, + gather_time, + 
x_shape.cpu[0], + x_shape.cpu[1], + x_shape.cpu[2], + gather_strides[0], + gather_strides[1], + gather_strides[2], + casted_y); + return api::SUCCESS; + case 4: + xpu2::plugin::fast_gather4d + <<ncluster(), 64, ctx->xpu_stream>>>(casted_x, + casted_index, + gather_time, + x_shape.cpu[0], + x_shape.cpu[1], + x_shape.cpu[2], + x_shape.cpu[3], + gather_strides[0], + gather_strides[1], + gather_strides[2], + gather_strides[3], + casted_y); + return api::SUCCESS; + defaut: + break; + } + return gather_nd(ctx, x, index, y, x_shape, index_shape); +} + +template +int fast_gather_nd(Context* ctx, + const T* x, + const TID* index, + T* y, + const VectorParam& x_shape, + const std::vector& index_shape) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T2(ctx, "fast_gather_nd", T, TID); + WRAPPER_DUMP_PARAM6( + ctx, x, index, y, x_shape, index_shape, ctx->_l3_mgr.get_size()); + WRAPPER_DUMP(ctx); + WRAPPER_ASSERT_GT(ctx, x_shape.len, 0); + WRAPPER_ASSERT_LE(ctx, x_shape.len, 32); + WRAPPER_ASSERT_GT(ctx, index_shape.size(), 0); + int64_t x_len = 1; + for (int64_t i = 0; i < x_shape.len; i++) { + x_len *= x_shape.cpu[i]; + } + WRAPPER_CHECK_PTR(ctx, T, x_len, x); + int64_t index_len = -1; + WRAPPER_CHECK_SHAPE(ctx, &index_len, index_shape); + WRAPPER_CHECK_PTR(ctx, TID, index_len, index); + // index.shape[-1] <= x.rank + WRAPPER_ASSERT_LE(ctx, index_shape.back(), x_shape.len); + std::vector y_shape; + for (int64_t i = 0; i < index_shape.size() - 1; i++) { + y_shape.push_back(index_shape[i]); + } + for (int64_t i = index_shape.back(); i < x_shape.len; i++) { + y_shape.push_back(x_shape.cpu[i]); + } + int64_t y_len = -1; + WRAPPER_CHECK_SHAPE(ctx, &y_len, y_shape); + WRAPPER_CHECK_PTR(ctx, T, y_len, y); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, x, index, y, x_shape, index_shape); + } + if (ctx->dev().type() == api::kXPU2) { + return xpu2_wrapper(ctx, x, index, y, x_shape, index_shape); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +template int fast_gather_nd(Context*, + const float*, + const int*, + float*, + const VectorParam&, + const std::vector&); +template int fast_gather_nd(Context*, + const int*, + const int*, + int*, + const VectorParam&, + const std::vector&); +template int fast_gather_nd(Context*, + const int64_t*, + const int*, + int64_t*, + const VectorParam&, + const std::vector&); +template int fast_gather_nd(Context*, + const float16*, + const int*, + float16*, + const VectorParam&, + const std::vector&); +template int fast_gather_nd(Context*, + const float*, + const int64_t*, + float*, + const VectorParam&, + const std::vector&); +template int fast_gather_nd(Context*, + const int*, + const int64_t*, + int*, + const VectorParam&, + const std::vector&); +template int fast_gather_nd(Context*, + const int64_t*, + const int64_t*, + int64_t*, + const VectorParam&, + const std::vector&); +template int fast_gather_nd(Context*, + const float16*, + const int64_t*, + float16*, + const VectorParam&, + const std::vector&); + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu -- GitLab
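Editor's summary of the wrapper above: fast_gather_nd validates the shapes, derives the output shape as index.shape[:-1] concatenated with the trailing, un-indexed dims of x, and then launches fast_gather{1,2,3,4}d based on index.shape[-1], falling back to the generic gather_nd for deeper indexing. On the device path x and y are reinterpreted as int8_t and sizeof(T) is folded into the innermost stride, so one set of byte-copy kernels serves every element type. For reference, the self-contained sketch below mirrors the cpu_wrapper semantics on plain host memory; the GatherNdRef name and the fixed float/int64_t types are illustrative assumptions, not part of the patch.

#include <cstdint>
#include <cstring>
#include <vector>

// Standalone reference of the gather_nd semantics that cpu_wrapper implements:
// each row of index (its last dim has end_size entries) picks a slice covering
// the trailing dims of x, and the slices are copied to y in row order.
void GatherNdRef(const float* x,
                 const std::vector<int64_t>& x_shape,
                 const int64_t* index,
                 const std::vector<int64_t>& index_shape,
                 float* y) {
  const int64_t end_size = index_shape.back();
  int64_t gather_time = 1;  // number of index rows == number of slices to copy
  for (size_t i = 0; i + 1 < index_shape.size(); ++i) gather_time *= index_shape[i];
  int64_t gather_size = 1;  // elements per gathered slice
  for (int64_t i = end_size; i < static_cast<int64_t>(x_shape.size()); ++i) {
    gather_size *= x_shape[i];
  }
  for (int64_t i = 0; i < gather_time; ++i) {
    int64_t x_index = 0;
    int64_t step = 1;
    for (int64_t j = end_size - 1; j >= 0; --j) {  // row-major linearization
      x_index += index[i * end_size + j] * step;
      step *= x_shape[j];
    }
    std::memcpy(y + i * gather_size, x + x_index * gather_size,
                gather_size * sizeof(float));
  }
}

int main() {
  // x: shape [2, 3]; index: shape [2, 1]; y: shape [2, 3] (two whole rows).
  std::vector<float> x = {0, 1, 2, 3, 4, 5};
  std::vector<int64_t> index = {1, 0};
  std::vector<float> y(6);
  GatherNdRef(x.data(), {2, 3}, index.data(), {2, 1}, y.data());
  return (y[0] == 3 && y[3] == 0) ? 0 : 1;  // row 1 of x first, then row 0
}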