未验证 提交 b2e06fc9 编写于 作者: J jiangfan06 提交者: GitHub

[XPU] Add take_along_axis xpu kernel and plugin (#56125)

上级 46f9d9b7
...@@ -815,6 +815,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -815,6 +815,8 @@ XPUOpMap& get_kl2_ops() {
{"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"take_along_axis",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"tanh_grad", {"tanh_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"tanh", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"tanh", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......
...@@ -32,6 +32,7 @@ DLL_EXPORT int fast_where(Context* ctx, ...@@ -32,6 +32,7 @@ DLL_EXPORT int fast_where(Context* ctx,
T* out, T* out,
int64_t len); int64_t len);
template <typename T, typename TID> template <typename T, typename TID>
DLL_EXPORT int fast_gather_nd(Context* ctx, DLL_EXPORT int fast_gather_nd(Context* ctx,
const T* x, const T* x,
const TID* index, const TID* index,
...@@ -56,6 +57,15 @@ static inline int fast_gather_nd(Context* ctx, ...@@ -56,6 +57,15 @@ static inline int fast_gather_nd(Context* ctx,
std::vector<int64_t>(index_shape.begin(), index_shape.end())); std::vector<int64_t>(index_shape.begin(), index_shape.end()));
} }
template <typename T, typename TID>
DLL_EXPORT int take_along_axis(Context* ctx,
const T* x,
const TID* index,
T* y,
const std::vector<int64_t>& xshape,
const std::vector<int64_t>& idxshape,
int64_t axis);
} // namespace plugin } // namespace plugin
} // namespace api } // namespace api
} // namespace xpu } // namespace xpu
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
template <typename T, typename TID>
__global__ void take_along_axis(const T* x,
const TID* indices,
T* y,
const int64_t* shape,
int64_t shape_size,
int64_t batch,
int64_t xlen,
int64_t ylen) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
__simd__ char lm_x[5 * sizeof(int64_t)];
__simd__ char lm_y[sizeof(T)];
__simd__ char lm_idx[sizeof(TID)];
__shared__ int64_t sm_shape[512];
if (cid == 0) {
GM2SM(shape, sm_shape, shape_size * sizeof(int64_t));
}
sync_all();
for (int64_t i = tid; i < batch * ylen; i += nthreads) {
GM2LM(indices + i, lm_idx, sizeof(TID));
TID idx = ((TID*)lm_idx)[0];
if (idx < 0) {
idx += xlen;
}
if (idx < xlen) {
GM2LM(x + i / ylen * xlen + idx, lm_y, sizeof(T));
LM2GM(lm_y, y + i, sizeof(T));
}
}
return;
}
#define _XPU_DEF__TAKE_ALONG_AXIS_(DTYPE, IDTYPE) \
template __global__ void take_along_axis<DTYPE, IDTYPE>( \
const DTYPE* x, \
const IDTYPE* indices, \
DTYPE* y, \
const int64_t* shape, \
int64_t shape_size, \
int64_t batch, \
int64_t xlen, \
int64_t ylen);
_XPU_DEF__TAKE_ALONG_AXIS_(float, int);
_XPU_DEF__TAKE_ALONG_AXIS_(float16, int);
_XPU_DEF__TAKE_ALONG_AXIS_(float, int64_t);
_XPU_DEF__TAKE_ALONG_AXIS_(float16, int64_t);
} // namespace plugin
} // namespace xpu2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include "xpu/refactor/util/vector_util.h"
namespace xpu2 {
namespace plugin {
template <typename T, typename TID>
__attribute__((global)) void take_along_axis(const T* x,
const TID* indices,
T* y,
const int64_t* shape,
int64_t shape_size,
int64_t batch,
int64_t xlen,
int64_t ylen);
} // namespace plugin
} // namespace xpu2
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T, typename TID>
static int cpu_wrapper(Context* ctx,
const T* x,
const TID* index,
T* y,
const std::vector<int64_t> xshape,
const std::vector<int64_t>& idxshape,
int64_t axis) {
int64_t ylen = vector_prod(idxshape);
for (int64_t i = 0; i < ylen; i++) {
std::vector<int64_t> sp_x_id = id_to_split_id(idxshape, i);
sp_x_id[axis] = index[i];
// -xshape[axis] <= index value < xshape[axis]
WRAPPER_ASSERT_LT(ctx, sp_x_id[axis], xshape[axis]);
WRAPPER_ASSERT_GE(ctx, sp_x_id[axis], -xshape[axis]);
if (sp_x_id[axis] < 0) {
sp_x_id[axis] += xshape[axis];
}
int64_t xid = split_id_to_id(xshape, sp_x_id);
y[i] = x[xid];
}
return SUCCESS;
}
template <typename T, typename TID>
static int xpu2_wrapper(Context* ctx,
const T* x,
const TID* index,
T* y,
const std::vector<int64_t> xshape,
const std::vector<int64_t>& idxshape,
int64_t axis) {
int64_t m_idx = 1;
int64_t shape_new_size = idxshape.size() - 1;
std::vector<int64_t> shape_new = xshape;
for (int64_t i = 0; i < axis; i++) {
m_idx *= idxshape[i];
}
for (int64_t i = axis + 1; i < xshape.size(); i++) {
shape_new[i - 1] = xshape[i];
}
int64_t t_x = xshape[axis];
int64_t t_idx = idxshape[axis];
int64_t n_idx = vector_prod(idxshape) / m_idx / t_idx;
if (m_idx < 64 && n_idx == 1) {
api::ctx_guard RAII_GUARD(ctx);
int64_t* shape_xpu = RAII_GUARD.alloc_l3_or_gm<int64_t>(shape_new_size);
WRAPPER_ASSERT_WORKSPACE(ctx, shape_xpu);
int ret = do_host2device(
ctx, shape_new.data(), shape_xpu, (shape_new_size) * sizeof(int64_t));
WRAPPER_ASSERT_SUCCESS(ctx, ret);
using XPU_TID = typename XPUIndexType<TID>::type;
const XPU_TID* casted_index =
static_cast<const XPU_TID*>(static_cast<const void*>(index));
xpu2::plugin::take_along_axis<T, XPU_TID>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
x,
casted_index,
y,
reinterpret_cast<xpu2::int64_t*>(shape_xpu),
shape_new_size,
m_idx,
t_x,
t_idx);
} else {
return gather_element(ctx, x, index, y, xshape, idxshape, axis);
}
return SUCCESS;
}
template <typename T, typename TID>
int take_along_axis(Context* ctx,
const T* x,
const TID* index,
T* y,
const std::vector<int64_t>& xshape,
const std::vector<int64_t>& idxshape,
int64_t axis) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T2(ctx, "take_along_axis", T, TID);
WRAPPER_DUMP_PARAM6(ctx, x, index, y, xshape, idxshape, axis);
WRAPPER_DUMP(ctx);
int64_t xlen = -1;
WRAPPER_CHECK_SHAPE(ctx, &xlen, xshape);
WRAPPER_CHECK_PTR(ctx, T, xlen, x);
int64_t idxlen = -1;
WRAPPER_CHECK_SHAPE(ctx, &idxlen, idxshape);
WRAPPER_CHECK_PTR(ctx, TID, idxlen, index);
WRAPPER_CHECK_PTR(ctx, T, idxlen, y);
WRAPPER_ASSERT_EQ(
ctx,
xshape.size(),
idxshape.size()); // x and index tensor should have same rank
int64_t neg_rank = -xshape.size();
WRAPPER_ASSERT_GE(ctx, axis, neg_rank);
WRAPPER_ASSERT_LT(ctx, axis, xshape.size());
axis = (axis < 0) ? (axis + xshape.size()) : axis;
for (int64_t i = 0; i < xshape.size(); i++) {
if (i != axis) {
WRAPPER_ASSERT_EQ(ctx, xshape[i], idxshape[i]);
}
}
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T, TID>(ctx, x, index, y, xshape, idxshape, axis);
}
if (ctx->dev().type() == api::kXPU2) {
return xpu2_wrapper<T, TID>(ctx, x, index, y, xshape, idxshape, axis);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int take_along_axis(Context*,
const float*,
const int*,
float*,
const std::vector<int64_t>&,
const std::vector<int64_t>&,
int64_t);
template int take_along_axis(Context*,
const float*,
const int64_t*,
float*,
const std::vector<int64_t>&,
const std::vector<int64_t>&,
int64_t);
template int take_along_axis(Context*,
const float16*,
const int*,
float16*,
const std::vector<int64_t>&,
const std::vector<int64_t>&,
int64_t);
template int take_along_axis(Context*,
const float16*,
const int64_t*,
float16*,
const std::vector<int64_t>&,
const std::vector<int64_t>&,
int64_t);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void TakeAlongAxisKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
int axis,
DenseTensor* out) {
out->Resize(index.dims());
dev_ctx.template Alloc<T>(out);
if (x.numel() == 0 || index.numel() == 0) return;
const auto& index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
DataTypeToString(index_type),
DataTypeToString(DataType::INT32),
DataTypeToString(DataType::INT64)));
std::vector<int64_t> xshape(x.dims().size());
for (int i = 0; i < x.dims().size(); ++i) {
xshape[i] = x.dims()[i];
}
std::vector<int64_t> idxshape(index.dims().size());
for (int i = 0; i < index.dims().size(); ++i) {
idxshape[i] = index.dims()[i];
}
if (xshape.size() <= 1 && idxshape.size() <= 1) {
for (int i = xshape.size(); i < 2; ++i) {
xshape.push_back(1);
idxshape.push_back(1);
}
}
using XPUType = typename XPUTypeTrait<T>::Type;
int r = XPU_SUCCESS;
#ifndef PADDLE_WITH_XPU_PLUGIN
LOG(WARNING) << "Add -DWITH_XPU_PLUGIN=ON to build "
"xpu::plugin::take_along_axis(), or use "
"xpu::gather_element() instead, which leads low performance "
"in some cases.";
if (index_type == DataType::INT32) {
r = xpu::gather_element<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
index.data<int>(),
reinterpret_cast<XPUType*>(out->data<T>()),
xshape,
idxshape,
axis);
} else {
r = xpu::gather_element<XPUType, int64_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
index.data<int64_t>(),
reinterpret_cast<XPUType*>(out->data<T>()),
xshape,
idxshape,
axis);
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather_element");
#else
if (index_type == DataType::INT32) {
r = xpu::plugin::take_along_axis<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
index.data<int>(),
reinterpret_cast<XPUType*>(out->data<T>()),
xshape,
idxshape,
axis);
} else {
r = xpu::plugin::take_along_axis<XPUType, int64_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
index.data<int64_t>(),
reinterpret_cast<XPUType*>(out->data<T>()),
xshape,
idxshape,
axis);
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "take_along_axis");
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(take_along_axis,
XPU,
ALL_LAYOUT,
phi::TakeAlongAxisKernel,
phi::dtype::float16,
float) {}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
from op_test_xpu import XPUOpTest
import paddle
paddle.enable_static()
class XPUTestTakeAlongAxis(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'take_along_axis'
class TestXPUTakeAlongAxisOp(XPUOpTest):
def setUp(self):
self.op_type = "take_along_axis"
self.place = paddle.XPUPlace(0)
self.dtype = self.in_type
self.init_config()
xnp = np.random.random(self.x_shape).astype(self.dtype)
self.target = np.take_along_axis(xnp, self.index, self.axis)
broadcast_shape_list = list(self.x_shape)
broadcast_shape_list[self.axis] = self.index.shape[self.axis]
self.broadcast_shape = tuple(broadcast_shape_list)
self.index_broadcast = np.broadcast_to(
self.index, self.broadcast_shape
)
self.inputs = {
'Input': xnp,
'Index': self.index_broadcast,
}
self.attrs = {'Axis': self.axis}
self.outputs = {'Result': self.target}
def init_config(self):
self.in_type = np.float32
self.x_shape = (1, 4, 10)
self.index_type = np.int32
self.index = np.array([[[0, 1, 3, 5, 6]]]).astype(self.index_type)
self.axis = 2
def test_check_output(self):
if paddle.is_compiled_with_xpu():
self.check_output_with_place(self.place)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
self.check_grad_with_place(self.place, ['Input'], 'Result')
class TestCase1(TestXPUTakeAlongAxisOp):
def init_config(self):
self.in_type = np.float32
self.x_shape = (1, 10, 100)
self.index_type = np.int32
self.index = np.array([[[0, 1, 3, 5, 13]]]).astype(self.index_type)
self.axis = 2
class TestCase2(TestXPUTakeAlongAxisOp):
def init_config(self):
self.in_type = np.float32
self.x_shape = (1, 10, 100)
self.index_type = np.int64
self.index = np.array([[[0, 1, 3, 5, 13]]]).astype(self.index_type)
self.axis = 2
class TestCase3(TestXPUTakeAlongAxisOp):
def init_config(self):
self.in_type = np.float16
self.x_shape = (1, 10, 100)
self.index_type = np.int32
self.index = np.array([[[0, 1, 3, 5, 13]]]).astype(self.index_type)
self.axis = 2
class TestCase4(TestXPUTakeAlongAxisOp):
def init_config(self):
self.in_type = np.float16
self.x_shape = (1, 10, 100)
self.index_type = np.int64
self.index = np.array([[[0, 1, 3, 5, 13]]]).astype(self.index_type)
self.axis = 2
class TestCase5(TestXPUTakeAlongAxisOp):
def init_config(self):
self.in_type = np.float32
self.x_shape = (1, 10, 100)
self.index_type = np.int32
self.index = np.array([[[0], [1], [3], [5], [8]]]).astype(
self.index_type
)
self.axis = 1
class XPUTestTakeAlongAxisAPI(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.shape = [3, 3]
self.index_shape = [1, 3]
self.index_np = np.array([[0, 1, 2]]).astype('int64')
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.XPUPlace(0)]
self.axis = 0
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.shape)
index = paddle.static.data('Index', self.index_shape, "int64")
out = paddle.take_along_axis(x, index, self.axis)
exe = paddle.static.Executor(self.place[0])
res = exe.run(
feed={'X': self.x_np, 'Index': self.index_np}, fetch_list=[out]
)
out_ref = np.array(
np.take_along_axis(self.x_np, self.index_np, self.axis)
)
for out in res:
np.testing.assert_allclose(out, out_ref, rtol=0.001)
def test_api_dygraph(self):
paddle.disable_static(self.place[0])
x_tensor = paddle.to_tensor(self.x_np)
self.index = paddle.to_tensor(self.index_np)
out = paddle.take_along_axis(x_tensor, self.index, self.axis)
out_ref = np.array(
np.take_along_axis(self.x_np, self.index_np, self.axis)
)
np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001)
paddle.enable_static()
class TestTakeAlongAxisAPICase1(XPUTestTakeAlongAxisAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 2]
self.index_shape = [4, 2]
self.index_np = np.array([[0, 0], [1, 0], [0, 0], [1, 0]]).astype(
'int64'
)
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.XPUPlace(0)]
self.axis = 0
support_types = get_xpu_op_support_types('take_along_axis')
for stype in support_types:
create_test_class(globals(), XPUTestTakeAlongAxis, stype)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册