Commit 460e4fc6 (unverified)
Authored Aug 11, 2023 by hong19860320; committed via GitHub on Aug 11, 2023
[XPU] Add fast_gather_nd plugin (#56103)
Parent: dfe97dc8
Showing 4 changed files with 585 additions and 0 deletions (+585 −0)
paddle/phi/kernels/xpu/gather_nd_kernel.cc                                +21 −0
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h                        +24 −0
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_gather_nd.xpu   +259 −0
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_gather_nd.cpp             +281 −0
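Before the per-file diffs, here is a minimal host-side sketch of the gather_nd semantics this plugin accelerates; it mirrors the cpu_wrapper reference added below. gather_nd_ref and the concrete shapes are illustrative only, not Paddle APIs.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Reference gather_nd: each row of `index` (length end_size) selects one slice
// x[index[0], ..., index[end_size-1], :, ..., :], which is copied into y.
template <typename T, typename TID>
void gather_nd_ref(const T* x,
                   const TID* index,
                   T* y,
                   const std::vector<int64_t>& x_shape,
                   const std::vector<int64_t>& index_shape) {
  int64_t end_size = index_shape.back();  // number of leading dims being indexed
  int64_t gather_time = 1;                // number of index rows
  for (size_t i = 0; i + 1 < index_shape.size(); ++i) gather_time *= index_shape[i];
  int64_t gather_size = 1;                // elements per copied slice
  for (size_t i = end_size; i < x_shape.size(); ++i) gather_size *= x_shape[i];
  for (int64_t i = 0; i < gather_time; ++i) {
    int64_t x_index = 0, step = 1;
    for (int64_t j = end_size - 1; j >= 0; --j) {  // row-major offset of the slice
      x_index += index[i * end_size + j] * step;
      step *= x_shape[j];
    }
    std::memcpy(y + i * gather_size, x + x_index * gather_size,
                gather_size * sizeof(T));
  }
}

int main() {
  // x has shape [2, 3]; index has shape [2, 1], so each index row picks one row of x.
  std::vector<float> x = {0, 1, 2, 10, 11, 12};
  std::vector<int> index = {1, 0};
  std::vector<float> y(2 * 3);
  gather_nd_ref(x.data(), index.data(), y.data(), {2, 3}, {2, 1});
  for (float v : y) std::cout << v << ' ';  // prints: 10 11 12 0 1 2
  std::cout << '\n';
}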
paddle/phi/kernels/xpu/gather_nd_kernel.cc

...
@@ -87,6 +87,7 @@ void GatherNdKernel(const Context &ctx,
                                     x_shape.data(),
                                     static_cast<int>(x_shape.size()),
                                     nullptr};
  int ret = XPU_SUCCESS;
#ifndef PADDLE_WITH_XPU_PLUGIN
  if (index_type == DataType::INT32) {
    ret = xpu::gather_nd<XPUType, int>(ctx.x_context(),
...
@@ -105,6 +106,26 @@ void GatherNdKernel(const Context &ctx,
                                           index_shape);
  }
  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather_nd");
#else
  if (index_type == DataType::INT32) {
    ret = xpu::plugin::fast_gather_nd<XPUType, int>(
        ctx.x_context(),
        reinterpret_cast<const XPUType*>(x.data<T>()),
        index.data<int>(),
        reinterpret_cast<XPUType*>(out->data<T>()),
        x_vec,
        index_shape);
  } else {
    ret = xpu::plugin::fast_gather_nd<XPUType, int64_t>(
        ctx.x_context(),
        reinterpret_cast<const XPUType*>(x.data<T>()),
        index.data<int64_t>(),
        reinterpret_cast<XPUType*>(out->data<T>()),
        x_vec,
        index_shape);
  }
  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "fast_gather_nd");
#endif
}

}  // namespace phi
...
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h

...
@@ -31,6 +31,30 @@ DLL_EXPORT int fast_where(Context* ctx,
                          const T* y,
                          T* out,
                          int64_t len);

template <typename T, typename TID>
DLL_EXPORT int fast_gather_nd(Context* ctx,
                              const T* x,
                              const TID* index,
                              T* y,
                              const VectorParam<int64_t>& xshape,
                              const std::vector<int64_t>& index_shape);

template <typename T, typename TID>
static inline int fast_gather_nd(Context* ctx,
                                 const T* x,
                                 const TID* index,
                                 T* y,
                                 const VectorParam<int>& xshape,
                                 const std::vector<int>& index_shape) {
  auto deleter = [](int64_t* ptr) { delete[] ptr; };
  std::shared_ptr<int64_t> xshape_i64(new int64_t[xshape.len], deleter);
  return fast_gather_nd(ctx,
                        x,
                        index,
                        y,
                        vpi32_to_vpi64(xshape, xshape_i64.get()),
                        std::vector<int64_t>(index_shape.begin(),
                                             index_shape.end()));
}

}  // namespace plugin
}  // namespace api
...
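The second overload above only widens 32-bit shape metadata to 64-bit before forwarding to the exported symbol. A self-contained sketch of that forwarding pattern, using hypothetical stand-ins (ShapeParam, do_work) rather than the real VectorParam/vpi32_to_vpi64 API:

#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-in for VectorParam<T>: a host-side shape buffer plus length.
template <typename T>
struct ShapeParam {
  const T* cpu;
  int64_t len;
};

// Hypothetical 64-bit entry point (the real one validates shapes and launches the kernel).
int do_work(const ShapeParam<int64_t>& xshape,
            const std::vector<int64_t>& index_shape) {
  return xshape.len >= 0 && !index_shape.empty() ? 0 : -1;
}

// 32-bit convenience overload: widen the shape data, then forward.
int do_work(const ShapeParam<int>& xshape,
            const std::vector<int>& index_shape) {
  // Own the widened buffer for the duration of the call, as the header does
  // with shared_ptr<int64_t> plus a custom deleter.
  std::unique_ptr<int64_t[]> xshape_i64(new int64_t[xshape.len]);
  for (int64_t i = 0; i < xshape.len; ++i) xshape_i64[i] = xshape.cpu[i];
  return do_work(ShapeParam<int64_t>{xshape_i64.get(), xshape.len},
                 std::vector<int64_t>(index_shape.begin(), index_shape.end()));
}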
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_gather_nd.xpu
(new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
template <typename TID>
__global__ void fast_gather1d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_stride0,
int8_t* y) {
int cid = core_id();
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
const int index_len = 320 / sizeof(TID);
__simd__ TID local_index[index_len];
const int buf_len = 5824 / sizeof(int8_t);
__simd__ int8_t local_x[buf_len];
if (x_stride0 > buf_len) {
for (int64_t i = tid; i < count; i += nthreads) {
GM2LM(index + i, local_index, sizeof(TID));
int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0;
for (int64_t j = 0; j < x_stride0; j += buf_len) {
int read_len = min(static_cast<int64_t>(x_stride0), x_stride0 - j);
GM2LM(x + offset + j, local_x, read_len);
LM2GM(local_x, y + i * x_stride0 + j, read_len);
}
}
} else {
int64_t count_per_thread = min(index_len, buf_len / x_stride0);
for (int64_t i = tid * count_per_thread; i < count;
i += nthreads * count_per_thread) {
int count_in_thread =
min(static_cast<int64_t>(count_per_thread), count - i);
GM2LM(index + i, local_index, count_in_thread * sizeof(TID));
for (int64_t j = 0; j < count_in_thread; j++) {
int64_t offset = ((local_index[j] + x_dim0) % x_dim0) * x_stride0;
GM2LM_ASYNC(x + offset, local_x + j * x_stride0, x_stride0);
}
mfence_lm();
LM2GM(local_x, y + i * x_stride0, x_stride0 * count_in_thread);
}
}
}
template <typename TID>
__global__ void fast_gather2d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_dim1,
int64_t x_stride0,
int64_t x_stride1,
int8_t* y) {
int cid = core_id();
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
const int index_len = 640 / sizeof(TID);
__simd__ TID local_index[index_len];
const int buf_len = 5504 / sizeof(int8_t);
__simd__ int8_t local_x[buf_len];
if (x_stride1 > buf_len) {
for (int64_t i = tid; i < count; i += nthreads) {
GM2LM(index + i * 2, local_index, 2 * sizeof(TID));
int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0 +
((local_index[1] + x_dim1) % x_dim1) * x_stride1;
for (int64_t j = 0; j < x_stride1; j += buf_len) {
int read_len = min(static_cast<int64_t>(x_stride1), x_stride1 - j);
GM2LM(x + offset + j, local_x, read_len);
LM2GM(local_x, y + i * x_stride1 + j, read_len);
}
}
} else {
int64_t count_per_thread = min(index_len / 2, buf_len / x_stride1);
for (int64_t i = tid * count_per_thread; i < count;
i += nthreads * count_per_thread) {
int count_in_thread =
min(static_cast<int64_t>(count_per_thread), count - i);
GM2LM(index + i * 2, local_index, 2 * count_in_thread * sizeof(TID));
for (int64_t j = 0; j < count_in_thread; j++) {
int64_t offset =
((local_index[j * 2] + x_dim0) % x_dim0) * x_stride0 +
((local_index[j * 2 + 1] + x_dim1) % x_dim1) * x_stride1;
GM2LM_ASYNC(x + offset, local_x + j * x_stride1, x_stride1);
}
mfence_lm();
LM2GM(local_x, y + i * x_stride1, x_stride1 * count_in_thread);
}
}
}
template <typename TID>
__global__ void fast_gather3d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_dim1,
int64_t x_dim2,
int64_t x_stride0,
int64_t x_stride1,
int64_t x_stride2,
int8_t* y) {
int cid = core_id();
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
const int index_len = 960 / sizeof(TID);
__simd__ TID local_index[index_len];
const int buf_len = 5184 / sizeof(int8_t);
__simd__ int8_t local_x[buf_len];
if (x_stride2 > buf_len) {
for (int64_t i = tid; i < count; i += nthreads) {
GM2LM(index + i * 3, local_index, 3 * sizeof(TID));
int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0 +
((local_index[1] + x_dim1) % x_dim1) * x_stride1 +
((local_index[2] + x_dim2) % x_dim2) * x_stride2;
for (int64_t j = 0; j < x_stride2; j += buf_len) {
int read_len = min(static_cast<int64_t>(x_stride2), x_stride2 - j);
GM2LM(x + offset + j, local_x, read_len);
LM2GM(local_x, y + i * x_stride2 + j, read_len);
}
}
} else {
int64_t count_per_thread = min(index_len / 3, buf_len / x_stride2);
for (int64_t i = tid * count_per_thread; i < count;
i += nthreads * count_per_thread) {
int count_in_thread =
min(static_cast<int64_t>(count_per_thread), count - i);
GM2LM(index + i * 3, local_index, 3 * count_in_thread * sizeof(TID));
for (int64_t j = 0; j < count_in_thread; j++) {
int64_t offset =
((local_index[j * 3] + x_dim0) % x_dim0) * x_stride0 +
((local_index[j * 3 + 1] + x_dim1) % x_dim1) * x_stride1 +
((local_index[j * 3 + 2] + x_dim2) % x_dim2) * x_stride2;
GM2LM_ASYNC(x + offset, local_x + j * x_stride2, x_stride2);
}
mfence_lm();
LM2GM(local_x, y + i * x_stride2, x_stride2 * count_in_thread);
}
}
}
template <typename TID>
__global__ void fast_gather4d(const int8_t* x,
const TID* index,
int64_t count,
int64_t x_dim0,
int64_t x_dim1,
int64_t x_dim2,
int64_t x_dim3,
int64_t x_stride0,
int64_t x_stride1,
int64_t x_stride2,
int64_t x_stride3,
int8_t* y) {
int cid = core_id();
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
const int index_len = 1280 / sizeof(TID);
__simd__ TID local_index[index_len];
const int buf_len = 4864 / sizeof(int8_t);
__simd__ int8_t local_x[buf_len];
if (x_stride3 > buf_len) {
for (int64_t i = tid; i < count; i += nthreads) {
GM2LM(index + i * 4, local_index, 4 * sizeof(TID));
int64_t offset = ((local_index[0] + x_dim0) % x_dim0) * x_stride0 +
((local_index[1] + x_dim1) % x_dim1) * x_stride1 +
((local_index[2] + x_dim2) % x_dim2) * x_stride2 +
((local_index[3] + x_dim3) % x_dim3) * x_stride3;
for (int64_t j = 0; j < x_stride3; j += buf_len) {
int read_len = min(static_cast<int64_t>(x_stride3), x_stride3 - j);
GM2LM(x + offset + j, local_x, read_len);
LM2GM(local_x, y + i * x_stride3 + j, read_len);
}
}
} else {
int64_t count_per_thread = min(index_len / 4, buf_len / x_stride3);
for (int64_t i = tid * count_per_thread; i < count;
i += nthreads * count_per_thread) {
int count_in_thread =
min(static_cast<int64_t>(count_per_thread), count - i);
GM2LM(index + i * 4, local_index, 4 * count_in_thread * sizeof(TID));
for (int64_t j = 0; j < count_in_thread; j++) {
int64_t offset =
((local_index[j * 4] + x_dim0) % x_dim0) * x_stride0 +
((local_index[j * 4 + 1] + x_dim1) % x_dim1) * x_stride1 +
((local_index[j * 4 + 2] + x_dim2) % x_dim2) * x_stride2 +
((local_index[j * 4 + 3] + x_dim3) % x_dim3) * x_stride3;
GM2LM_ASYNC(x + offset, local_x + j * x_stride3, x_stride3);
}
mfence_lm();
LM2GM(local_x, y + i * x_stride3, x_stride3 * count_in_thread);
}
}
}
#define _XPU_DEF__FAST_GATHERND_(IDTYPE) \
template __global__ void fast_gather1d<IDTYPE>(const int8_t* x, \
const IDTYPE* index, \
int64_t count, \
int64_t x_dim0, \
int64_t x_stride0, \
int8_t* y); \
template __global__ void fast_gather2d<IDTYPE>(const int8_t* x, \
const IDTYPE* index, \
int64_t count, \
int64_t x_dim0, \
int64_t x_dim1, \
int64_t x_stride0, \
int64_t x_stride1, \
int8_t* y); \
template __global__ void fast_gather3d<IDTYPE>(const int8_t* x, \
const IDTYPE* index, \
int64_t count, \
int64_t x_dim0, \
int64_t x_dim1, \
int64_t x_dim2, \
int64_t x_stride0, \
int64_t x_stride1, \
int64_t x_stride2, \
int8_t* y); \
template __global__ void fast_gather4d<IDTYPE>(const int8_t* x, \
const IDTYPE* index, \
int64_t count, \
int64_t x_dim0, \
int64_t x_dim1, \
int64_t x_dim2, \
int64_t x_dim3, \
int64_t x_stride0, \
int64_t x_stride1, \
int64_t x_stride2, \
int64_t x_stride3, \
int8_t* y);
_XPU_DEF__FAST_GATHERND_(int);
_XPU_DEF__FAST_GATHERND_(int8_t);
_XPU_DEF__FAST_GATHERND_(int64_t);
_XPU_DEF__FAST_GATHERND_(bool);
} // namespace plugin
} // namespace xpu2
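Each kernel above picks one of two paths: when a single gathered slice (x_strideN bytes) does not fit in the thread-local buffer, it is streamed chunk by chunk per index row; otherwise several index rows are packed into local memory per thread and copied out in one LM2GM. A minimal host-side sketch of that packing arithmetic, with illustrative buffer sizes and thread counts rather than the actual XPU local-memory limits:

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // Illustrative sizes; the real kernels size these buffers to fit XPU local memory.
  const int64_t index_len = 320 / sizeof(int);  // index slots per thread
  const int64_t buf_len = 5824;                 // data bytes per thread
  const int64_t x_stride0 = 64;                 // bytes of one gathered slice
  const int64_t count = 1000;                   // total index rows
  const int64_t nthreads = 8 * 64;              // cluster_num() * core_num(), illustrative

  if (x_stride0 > buf_len) {
    // Large-slice path: one index row at a time, slice streamed in buf_len chunks.
    std::cout << "stream each slice in chunks of " << buf_len << " bytes\n";
  } else {
    // Small-slice path: pack as many whole slices per thread as fit in local memory.
    int64_t count_per_thread = std::min(index_len, buf_len / x_stride0);
    for (int64_t tid = 0; tid < nthreads; ++tid) {
      for (int64_t i = tid * count_per_thread; i < count;
           i += nthreads * count_per_thread) {
        int64_t batch = std::min(count_per_thread, count - i);
        if (tid == 0) {
          std::cout << "thread 0 gathers rows [" << i << ", " << i + batch
                    << ") in one local-memory batch\n";
        }
      }
    }
  }
  return 0;
}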
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_gather_nd.cpp
(new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu2 {
namespace plugin {

template <typename TID>
__attribute__((global)) void fast_gather1d(const int8_t* x,
                                           const TID* index,
                                           int64_t count,
                                           int64_t x_dim0,
                                           int64_t x_stride0,
                                           int8_t* y);

template <typename TID>
__attribute__((global)) void fast_gather2d(const int8_t* x,
                                           const TID* index,
                                           int64_t count,
                                           int64_t x_dim0,
                                           int64_t x_dim1,
                                           int64_t x_stride0,
                                           int64_t x_stride1,
                                           int8_t* y);

template <typename TID>
__attribute__((global)) void fast_gather3d(const int8_t* x,
                                           const TID* index,
                                           int64_t count,
                                           int64_t x_dim0,
                                           int64_t x_dim1,
                                           int64_t x_dim2,
                                           int64_t x_stride0,
                                           int64_t x_stride1,
                                           int64_t x_stride2,
                                           int8_t* y);

template <typename TID>
__attribute__((global)) void fast_gather4d(const int8_t* x,
                                           const TID* index,
                                           int64_t count,
                                           int64_t x_dim0,
                                           int64_t x_dim1,
                                           int64_t x_dim2,
                                           int64_t x_dim3,
                                           int64_t x_stride0,
                                           int64_t x_stride1,
                                           int64_t x_stride2,
                                           int64_t x_stride3,
                                           int8_t* y);

}  // namespace plugin
}  // namespace xpu2

namespace baidu {
namespace xpu {
namespace api {
namespace plugin {

template <typename T, typename TID>
static int cpu_wrapper(Context* ctx,
                       const T* x,
                       const TID* index,
                       T* y,
                       const VectorParam<int64_t>& x_shape,
                       const std::vector<int64_t>& index_shape) {
  int64_t x_shape_size = x_shape.len;
  int64_t index_shape_size = index_shape.size();
  int64_t gather_time = 1;
  for (int64_t i = 0; i < index_shape_size - 1; i++) {
    gather_time *= index_shape[i];
  }
  int64_t end_size = index_shape.back();
  int64_t gather_size = 1;
  for (int64_t i = end_size; i < x_shape_size; i++) {
    gather_size *= x_shape.cpu[i];
  }
  const int64_t gather_bytes = gather_size * sizeof(T);
  for (int64_t i = 0; i < gather_time; i++) {
    int64_t x_index = 0;
    int64_t step = 1;
    for (int64_t j = end_size - 1; j >= 0; j--) {
      x_index += (index[i * end_size + j] * step);
      step *= x_shape.cpu[j];
    }
    memcpy(y, x + x_index * gather_size, gather_bytes);
    y += gather_size;
  }
  return api::SUCCESS;
}

template <typename T, typename TID>
static int xpu2_wrapper(Context* ctx,
                        const T* x,
                        const TID* index,
                        T* y,
                        const VectorParam<int64_t>& x_shape,
                        const std::vector<int64_t>& index_shape) {
  using XPU_TID = typename XPUIndexType<TID>::type;
  int64_t x_shape_size = x_shape.len;
  int64_t index_shape_size = index_shape.size();
  int64_t end_size = index_shape.back();
  int64_t gather_time = 1;
  for (int64_t i = 0; i < index_shape_size - 1; i++) {
    gather_time *= index_shape[i];
  }
  std::vector<int64_t> gather_strides(end_size);
  gather_strides[end_size - 1] = sizeof(T);
  for (int64_t i = end_size; i < x_shape_size; i++) {
    gather_strides[end_size - 1] *= x_shape.cpu[i];
  }
  for (int64_t i = end_size - 2; i >= 0; i--) {
    gather_strides[i] = gather_strides[i + 1] * x_shape.cpu[i + 1];
  }
  auto casted_x = static_cast<const int8_t*>(static_cast<const void*>(x));
  auto casted_index =
      static_cast<const XPU_TID*>(static_cast<const void*>(index));
  auto casted_y = static_cast<int8_t*>(static_cast<void*>(y));
  switch (end_size) {
    case 1:
      xpu2::plugin::fast_gather1d<XPU_TID>
          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(casted_x,
                                                     casted_index,
                                                     gather_time,
                                                     x_shape.cpu[0],
                                                     gather_strides[0],
                                                     casted_y);
      return api::SUCCESS;
    case 2:
      xpu2::plugin::fast_gather2d<XPU_TID>
          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(casted_x,
                                                     casted_index,
                                                     gather_time,
                                                     x_shape.cpu[0],
                                                     x_shape.cpu[1],
                                                     gather_strides[0],
                                                     gather_strides[1],
                                                     casted_y);
      return api::SUCCESS;
    case 3:
      xpu2::plugin::fast_gather3d<XPU_TID>
          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(casted_x,
                                                     casted_index,
                                                     gather_time,
                                                     x_shape.cpu[0],
                                                     x_shape.cpu[1],
                                                     x_shape.cpu[2],
                                                     gather_strides[0],
                                                     gather_strides[1],
                                                     gather_strides[2],
                                                     casted_y);
      return api::SUCCESS;
    case 4:
      xpu2::plugin::fast_gather4d<XPU_TID>
          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(casted_x,
                                                     casted_index,
                                                     gather_time,
                                                     x_shape.cpu[0],
                                                     x_shape.cpu[1],
                                                     x_shape.cpu[2],
                                                     x_shape.cpu[3],
                                                     gather_strides[0],
                                                     gather_strides[1],
                                                     gather_strides[2],
                                                     gather_strides[3],
                                                     casted_y);
      return api::SUCCESS;
    default:
      break;
  }
  return gather_nd(ctx, x, index, y, x_shape, index_shape);
}

template <typename T, typename TID>
int fast_gather_nd(Context* ctx,
                   const T* x,
                   const TID* index,
                   T* y,
                   const VectorParam<int64_t>& x_shape,
                   const std::vector<int64_t>& index_shape) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T2(ctx, "fast_gather_nd", T, TID);
  WRAPPER_DUMP_PARAM6(
      ctx, x, index, y, x_shape, index_shape, ctx->_l3_mgr.get_size());
  WRAPPER_DUMP(ctx);
  WRAPPER_ASSERT_GT(ctx, x_shape.len, 0);
  WRAPPER_ASSERT_LE(ctx, x_shape.len, 32);
  WRAPPER_ASSERT_GT(ctx, index_shape.size(), 0);
  int64_t x_len = 1;
  for (int64_t i = 0; i < x_shape.len; i++) {
    x_len *= x_shape.cpu[i];
  }
  WRAPPER_CHECK_PTR(ctx, T, x_len, x);
  int64_t index_len = -1;
  WRAPPER_CHECK_SHAPE(ctx, &index_len, index_shape);
  WRAPPER_CHECK_PTR(ctx, TID, index_len, index);
  // index.shape[-1] <= x.rank
  WRAPPER_ASSERT_LE(ctx, index_shape.back(), x_shape.len);
  std::vector<int64_t> y_shape;
  for (int64_t i = 0; i < index_shape.size() - 1; i++) {
    y_shape.push_back(index_shape[i]);
  }
  for (int64_t i = index_shape.back(); i < x_shape.len; i++) {
    y_shape.push_back(x_shape.cpu[i]);
  }
  int64_t y_len = -1;
  WRAPPER_CHECK_SHAPE(ctx, &y_len, y_shape);
  WRAPPER_CHECK_PTR(ctx, T, y_len, y);
  if (ctx->dev().type() == api::kCPU) {
    return cpu_wrapper<T, TID>(ctx, x, index, y, x_shape, index_shape);
  }
  if (ctx->dev().type() == api::kXPU2) {
    return xpu2_wrapper<T, TID>(ctx, x, index, y, x_shape, index_shape);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}

template int fast_gather_nd(Context*,
                            const float*,
                            const int*,
                            float*,
                            const VectorParam<int64_t>&,
                            const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
                            const int*,
                            const int*,
                            int*,
                            const VectorParam<int64_t>&,
                            const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
                            const int64_t*,
                            const int*,
                            int64_t*,
                            const VectorParam<int64_t>&,
                            const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
                            const float16*,
                            const int*,
                            float16*,
                            const VectorParam<int64_t>&,
                            const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
                            const float*,
                            const int64_t*,
                            float*,
                            const VectorParam<int64_t>&,
                            const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
                            const int*,
                            const int64_t*,
                            int*,
                            const VectorParam<int64_t>&,
                            const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
                            const int64_t*,
                            const int64_t*,
                            int64_t*,
                            const VectorParam<int64_t>&,
                            const std::vector<int64_t>&);
template int fast_gather_nd(Context*,
                            const float16*,
                            const int64_t*,
                            float16*,
                            const VectorParam<int64_t>&,
                            const std::vector<int64_t>&);

}  // namespace plugin
}  // namespace api
}  // namespace xpu
}  // namespace baidu
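Note that xpu2_wrapper computes gather_strides in bytes rather than elements: gather_strides[end_size-1] starts at sizeof(T) and absorbs every non-indexed trailing dimension, which lets the kernels treat x and y as raw int8_t buffers. A standalone sketch of that computation with illustrative values:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // x.shape = [4, 5, 6], index.shape = [.., 2] => end_size = 2, element type float.
  std::vector<int64_t> x_shape = {4, 5, 6};
  int64_t end_size = 2;
  int64_t elem_bytes = sizeof(float);

  std::vector<int64_t> gather_strides(end_size);
  gather_strides[end_size - 1] = elem_bytes;
  for (int64_t i = end_size; i < static_cast<int64_t>(x_shape.size()); ++i) {
    gather_strides[end_size - 1] *= x_shape[i];  // bytes of one gathered slice
  }
  for (int64_t i = end_size - 2; i >= 0; --i) {
    gather_strides[i] = gather_strides[i + 1] * x_shape[i + 1];
  }
  // Prints "120 24": one index pair (i, j) copies 24 bytes starting at byte
  // offset i*120 + j*24, exactly the offset formula used in fast_gather2d.
  std::cout << gather_strides[0] << ' ' << gather_strides[1] << '\n';
}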