未验证 提交 460e4fc6 编写于 作者: H hong19860320 提交者: GitHub

[XPU] Add fast_gather_nd plugin (#56103)

上级 dfe97dc8
......@@ -87,6 +87,7 @@ void GatherNdKernel(const Context &ctx,
x_shape.data(), static_cast<int>(x_shape.size()), nullptr};
int ret = XPU_SUCCESS;
#ifndef PADDLE_WITH_XPU_PLUGIN
if (index_type == DataType::INT32) {
ret = xpu::gather_nd<XPUType, int>(
ctx.x_context(),
......@@ -105,6 +106,26 @@ void GatherNdKernel(const Context &ctx,
index_shape);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather_nd");
#else
if (index_type == DataType::INT32) {
ret = xpu::plugin::fast_gather_nd<XPUType, int>(
ctx.x_context(),
reinterpret_cast<const XPUType *>(x.data<T>()),
index.data<int>(),
reinterpret_cast<XPUType *>(out->data<T>()),
x_vec,
index_shape);
} else {
ret = xpu::plugin::fast_gather_nd<XPUType, int64_t>(
ctx.x_context(),
reinterpret_cast<const XPUType *>(x.data<T>()),
index.data<int64_t>(),
reinterpret_cast<XPUType *>(out->data<T>()),
x_vec,
index_shape);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "fast_gather_nd");
#endif
}
} // namespace phi
......
......@@ -31,6 +31,30 @@ DLL_EXPORT int fast_where(Context* ctx,
const T* y,
T* out,
int64_t len);
template <typename T, typename TID>
DLL_EXPORT int fast_gather_nd(Context* ctx,
const T* x,
const TID* index,
T* y,
const VectorParam<int64_t>& xshape,
const std::vector<int64_t>& index_shape);
template <typename T, typename TID>
static inline int fast_gather_nd(Context* ctx,
const T* x,
const TID* index,
T* y,
const VectorParam<int>& xshape,
const std::vector<int>& index_shape) {
auto deleter = [](int64_t* ptr) { delete[] ptr; };
std::shared_ptr<int64_t> xshape_i64(new int64_t[xshape.len], deleter);
return fast_gather_nd(
ctx,
x,
index,
y,
vpi32_to_vpi64(xshape, xshape_i64.get()),
std::vector<int64_t>(index_shape.begin(), index_shape.end()));
}
} // namespace plugin
} // namespace api
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
template <typename TID>
__global__ void fast_gather1d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_stride0,
int8_t* y) {
int cid = core_id();
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
const int index_len = 320 / sizeof(TID);
__simd__ TID local_index[index_len];
const int buf_len = 5824 / sizeof(int8_t);
__simd__ int8_t local_x[buf_len];
if (x_stride0 > buf_len) {
for (int64_t i = tid; i < count; i += nthreads) {
GM2LM(index + i, local_index, sizeof(TID));
int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0;
for (int64_t j = 0; j < x_stride0; j += buf_len) {
int read_len = min(static_cast<int64_t>(x_stride0), x_stride0 - j);
GM2LM(x + offset + j, local_x, read_len);
LM2GM(local_x, y + i * x_stride0 + j, read_len);
}
}
} else {
int64_t count_per_thread = min(index_len, buf_len / x_stride0);
for (int64_t i = tid * count_per_thread; i < count;
i += nthreads * count_per_thread) {
int count_in_thread =
min(static_cast<int64_t>(count_per_thread), count - i);
GM2LM(index + i, local_index, count_in_thread * sizeof(TID));
for (int64_t j = 0; j < count_in_thread; j++) {
int64_t offset = ((local_index[j] + x_dim0) % x_dim0) * x_stride0;
GM2LM_ASYNC(x + offset, local_x + j * x_stride0, x_stride0);
}
mfence_lm();
LM2GM(local_x, y + i * x_stride0, x_stride0 * count_in_thread);
}
}
}
template <typename TID>
__global__ void fast_gather2d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_dim1,
int64_t x_stride0,
int64_t x_stride1,
int8_t* y) {
int cid = core_id();
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
const int index_len = 640 / sizeof(TID);
__simd__ TID local_index[index_len];
const int buf_len = 5504 / sizeof(int8_t);
__simd__ int8_t local_x[buf_len];
if (x_stride1 > buf_len) {
for (int64_t i = tid; i < count; i += nthreads) {
GM2LM(index + i * 2, local_index, 2 * sizeof(TID));
int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0 +
((local_index[1] + x_dim1) % x_dim1) * x_stride1;
for (int64_t j = 0; j < x_stride1; j += buf_len) {
int read_len = min(static_cast<int64_t>(x_stride1), x_stride1 - j);
GM2LM(x + offset + j, local_x, read_len);
LM2GM(local_x, y + i * x_stride1 + j, read_len);
}
}
} else {
int64_t count_per_thread = min(index_len / 2, buf_len / x_stride1);
for (int64_t i = tid * count_per_thread; i < count;
i += nthreads * count_per_thread) {
int count_in_thread =
min(static_cast<int64_t>(count_per_thread), count - i);
GM2LM(index + i * 2, local_index, 2 * count_in_thread * sizeof(TID));
for (int64_t j = 0; j < count_in_thread; j++) {
int64_t offset =
((local_index[j * 2] + x_dim0) % x_dim0) * x_stride0 +
((local_index[j * 2 + 1] + x_dim1) % x_dim1) * x_stride1;
GM2LM_ASYNC(x + offset, local_x + j * x_stride1, x_stride1);
}
mfence_lm();
LM2GM(local_x, y + i * x_stride1, x_stride1 * count_in_thread);
}
}
}
template <typename TID>
__global__ void fast_gather3d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_dim1,
int64_t x_dim2,
int64_t x_stride0,
int64_t x_stride1,
int64_t x_stride2,
int8_t* y) {
int cid = core_id();
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
const int index_len = 960 / sizeof(TID);
__simd__ TID local_index[index_len];
const int buf_len = 5184 / sizeof(int8_t);
__simd__ int8_t local_x[buf_len];
if (x_stride2 > buf_len) {
for (int64_t i = tid; i < count; i += nthreads) {
GM2LM(index + i * 3, local_index, 3 * sizeof(TID));
int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0 +
((local_index[1] + x_dim1) % x_dim1) * x_stride1 +
((local_index[2] + x_dim2) % x_dim2) * x_stride2;
for (int64_t j = 0; j < x_stride2; j += buf_len) {
int read_len = min(static_cast<int64_t>(x_stride2), x_stride2 - j);
GM2LM(x + offset + j, local_x, read_len);
LM2GM(local_x, y + i * x_stride2 + j, read_len);
}
}
} else {
int64_t count_per_thread = min(index_len / 3, buf_len / x_stride2);
for (int64_t i = tid * count_per_thread; i < count;
i += nthreads * count_per_thread) {
int count_in_thread =
min(static_cast<int64_t>(count_per_thread), count - i);
GM2LM(index + i * 3, local_index, 3 * count_in_thread * sizeof(TID));
for (int64_t j = 0; j < count_in_thread; j++) {
int64_t offset =
((local_index[j * 3] + x_dim0) % x_dim0) * x_stride0 +
((local_index[j * 3 + 1] + x_dim1) % x_dim1) * x_stride1 +
((local_index[j * 3 + 2] + x_dim2) % x_dim2) * x_stride2;
GM2LM_ASYNC(x + offset, local_x + j * x_stride2, x_stride2);
}
mfence_lm();
LM2GM(local_x, y + i * x_stride2, x_stride2 * count_in_thread);
}
}
}
template <typename TID>
__global__ void fast_gather4d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_dim1,
int64_t x_dim2,
int64_t x_dim3,
int64_t x_stride0,
int64_t x_stride1,
int64_t x_stride2,
int64_t x_stride3,
int8_t* y) {
int cid = core_id();
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
const int index_len = 1280 / sizeof(TID);
__simd__ TID local_index[index_len];
const int buf_len = 4864 / sizeof(int8_t);
__simd__ int8_t local_x[buf_len];
if (x_stride3 > buf_len) {
for (int64_t i = tid; i < count; i += nthreads) {
GM2LM(index + i * 4, local_index, 4 * sizeof(TID));
int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0 +
((local_index[1] + x_dim1) % x_dim1) * x_stride1 +
((local_index[2] + x_dim2) % x_dim2) * x_stride2 +
((local_index[3] + x_dim3) % x_dim3) * x_stride3;
for (int64_t j = 0; j < x_stride3; j += buf_len) {
int read_len = min(static_cast<int64_t>(x_stride3), x_stride3 - j);
GM2LM(x + offset + j, local_x, read_len);
LM2GM(local_x, y + i * x_stride3 + j, read_len);
}
}
} else {
int64_t count_per_thread = min(index_len / 4, buf_len / x_stride3);
for (int64_t i = tid * count_per_thread; i < count;
i += nthreads * count_per_thread) {
int count_in_thread =
min(static_cast<int64_t>(count_per_thread), count - i);
GM2LM(index + i * 4, local_index, 4 * count_in_thread * sizeof(TID));
for (int64_t j = 0; j < count_in_thread; j++) {
int64_t offset =
((local_index[j * 4] + x_dim0) % x_dim0) * x_stride0 +
((local_index[j * 4 + 1] + x_dim1) % x_dim1) * x_stride1 +
((local_index[j * 4 + 2] + x_dim2) % x_dim2) * x_stride2 +
((local_index[j * 4 + 3] + x_dim3) % x_dim3) * x_stride3;
GM2LM_ASYNC(x + offset, local_x + j * x_stride3, x_stride3);
}
mfence_lm();
LM2GM(local_x, y + i * x_stride3, x_stride3 * count_in_thread);
}
}
}
#define _XPU_DEF__FAST_GATHERND_(IDTYPE) \
template __global__ void fast_gather1d<IDTYPE>(const int8_t* x, \
const IDTYPE* index, \
int64_t count, \
int64_t x_dim0, \
int64_t x_stride0, \
int8_t* y); \
template __global__ void fast_gather2d<IDTYPE>(const int8_t* x, \
const IDTYPE* index, \
int64_t count, \
int64_t x_dim0, \
int64_t x_dim1, \
int64_t x_stride0, \
int64_t x_stride1, \
int8_t* y); \
template __global__ void fast_gather3d<IDTYPE>(const int8_t* x, \
const IDTYPE* index, \
int64_t count, \
int64_t x_dim0, \
int64_t x_dim1, \
int64_t x_dim2, \
int64_t x_stride0, \
int64_t x_stride1, \
int64_t x_stride2, \
int8_t* y); \
template __global__ void fast_gather4d<IDTYPE>(const int8_t* x, \
const IDTYPE* index, \
int64_t count, \
int64_t x_dim0, \
int64_t x_dim1, \
int64_t x_dim2, \
int64_t x_dim3, \
int64_t x_stride0, \
int64_t x_stride1, \
int64_t x_stride2, \
int64_t x_stride3, \
int8_t* y);
_XPU_DEF__FAST_GATHERND_(int);
_XPU_DEF__FAST_GATHERND_(int8_t);
_XPU_DEF__FAST_GATHERND_(int64_t);
_XPU_DEF__FAST_GATHERND_(bool);
} // namespace plugin
} // namespace xpu2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu2 {
namespace plugin {
template <typename TID>
__attribute__((global)) void fast_gather1d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_stride0,
int8_t* y);
template <typename TID>
__attribute__((global)) void fast_gather2d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_dim1,
int64_t x_stride0,
int64_t x_stride1,
int8_t* y);
template <typename TID>
__attribute__((global)) void fast_gather3d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_dim1,
int64_t x_dim2,
int64_t x_stride0,
int64_t x_stride1,
int64_t x_stride2,
int8_t* y);
template <typename TID>
__attribute__((global)) void fast_gather4d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_dim1,
int64_t x_dim2,
int64_t x_dim3,
int64_t x_stride0,
int64_t x_stride1,
int64_t x_stride2,
int64_t x_stride3,
int8_t* y);
} // namespace plugin
} // namespace xpu2
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T, typename TID>
static int cpu_wrapper(Context* ctx,
const T* x,
const TID* index,
T* y,
const VectorParam<int64_t>& x_shape,
const std::vector<int64_t>& index_shape) {
int64_t x_shape_size = x_shape.len;
int64_t index_shape_size = index_shape.size();
int64_t gather_time = 1;
for (int64_t i = 0; i < index_shape_size - 1; i++) {
gather_time *= index_shape[i];
}
int64_t end_size = index_shape.back();
int64_t gather_size = 1;
for (int64_t i = end_size; i < x_shape_size; i++) {
gather_size *= x_shape.cpu[i];
}
const int64_t gather_bytes = gather_size * sizeof(T);
for (int64_t i = 0; i < gather_time; i++) {
int64_t x_index = 0;
int64_t step = 1;
for (int64_t j = end_size - 1; j >= 0; j--) {
x_index += (index[i * end_size + j] * step);
step *= x_shape.cpu[j];
}
memcpy(y, x + x_index * gather_size, gather_bytes);
y += gather_size;
}
return api::SUCCESS;
}
template <typename T, typename TID>
static int xpu2_wrapper(Context* ctx,
const T* x,
const TID* index,
T* y,
const VectorParam<int64_t>& x_shape,
const std::vector<int64_t>& index_shape) {
using XPU_TID = typename XPUIndexType<TID>::type;
int64_t x_shape_size = x_shape.len;
int64_t index_shape_size = index_shape.size();
int64_t end_size = index_shape.back();
int64_t gather_time = 1;
for (int64_t i = 0; i < index_shape_size - 1; i++) {
gather_time *= index_shape[i];
}
std::vector<int64_t> gather_strides(end_size);
gather_strides[end_size - 1] = sizeof(T);
for (int64_t i = end_size; i < x_shape_size; i++) {
gather_strides[end_size - 1] *= x_shape.cpu[i];
}
for (int64_t i = end_size - 2; i >= 0; i--) {
gather_strides[i] = gather_strides[i + 1] * x_shape.cpu[i + 1];
}
auto casted_x = static_cast<const int8_t*>(static_cast<const void*>(x));
auto casted_index =
static_cast<const XPU_TID*>(static_cast<const void*>(index));
auto casted_y = static_cast<int8_t*>(static_cast<void*>(y));
switch (end_size) {
case 1:
xpu2::plugin::fast_gather1d<XPU_TID>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(casted_x,
casted_index,
gather_time,
x_shape.cpu[0],
gather_strides[0],
casted_y);
return api::SUCCESS;
case 2:
xpu2::plugin::fast_gather2d<XPU_TID>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(casted_x,
casted_index,
gather_time,
x_shape.cpu[0],
x_shape.cpu[1],
gather_strides[0],
gather_strides[1],
casted_y);
return api::SUCCESS;
case 3:
xpu2::plugin::fast_gather3d<XPU_TID>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(casted_x,
casted_index,
gather_time,
x_shape.cpu[0],
x_shape.cpu[1],
x_shape.cpu[2],
gather_strides[0],
gather_strides[1],
gather_strides[2],
casted_y);
return api::SUCCESS;
case 4:
xpu2::plugin::fast_gather4d<XPU_TID>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(casted_x,
casted_index,
gather_time,
x_shape.cpu[0],
x_shape.cpu[1],
x_shape.cpu[2],
x_shape.cpu[3],
gather_strides[0],
gather_strides[1],
gather_strides[2],
gather_strides[3],
casted_y);
return api::SUCCESS;
defaut:
break;
}
return gather_nd(ctx, x, index, y, x_shape, index_shape);
}
template <typename T, typename TID>
int fast_gather_nd(Context* ctx,
const T* x,
const TID* index,
T* y,
const VectorParam<int64_t>& x_shape,
const std::vector<int64_t>& index_shape) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T2(ctx, "fast_gather_nd", T, TID);
WRAPPER_DUMP_PARAM6(
ctx, x, index, y, x_shape, index_shape, ctx->_l3_mgr.get_size());
WRAPPER_DUMP(ctx);
WRAPPER_ASSERT_GT(ctx, x_shape.len, 0);
WRAPPER_ASSERT_LE(ctx, x_shape.len, 32);
WRAPPER_ASSERT_GT(ctx, index_shape.size(), 0);
int64_t x_len = 1;
for (int64_t i = 0; i < x_shape.len; i++) {
x_len *= x_shape.cpu[i];
}
WRAPPER_CHECK_PTR(ctx, T, x_len, x);
int64_t index_len = -1;
WRAPPER_CHECK_SHAPE(ctx, &index_len, index_shape);
WRAPPER_CHECK_PTR(ctx, TID, index_len, index);
// index.shape[-1] <= x.rank
WRAPPER_ASSERT_LE(ctx, index_shape.back(), x_shape.len);
std::vector<int64_t> y_shape;
for (int64_t i = 0; i < index_shape.size() - 1; i++) {
y_shape.push_back(index_shape[i]);
}
for (int64_t i = index_shape.back(); i < x_shape.len; i++) {
y_shape.push_back(x_shape.cpu[i]);
}
int64_t y_len = -1;
WRAPPER_CHECK_SHAPE(ctx, &y_len, y_shape);
WRAPPER_CHECK_PTR(ctx, T, y_len, y);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T, TID>(ctx, x, index, y, x_shape, index_shape);
}
if (ctx->dev().type() == api::kXPU2) {
return xpu2_wrapper<T, TID>(ctx, x, index, y, x_shape, index_shape);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int fast_gather_nd(Context*,
const float*,
const int*,
float*,
const VectorParam<int64_t>&,
const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
const int*,
const int*,
int*,
const VectorParam<int64_t>&,
const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
const int64_t*,
const int*,
int64_t*,
const VectorParam<int64_t>&,
const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
const float16*,
const int*,
float16*,
const VectorParam<int64_t>&,
const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
const float*,
const int64_t*,
float*,
const VectorParam<int64_t>&,
const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
const int*,
const int64_t*,
int*,
const VectorParam<int64_t>&,
const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
const int64_t*,
const int64_t*,
int64_t*,
const VectorParam<int64_t>&,
const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
const float16*,
const int64_t*,
float16*,
const VectorParam<int64_t>&,
const std::vector<int64_t>&);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册