Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b2e06fc9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b2e06fc9
编写于
8月 14, 2023
作者:
J
jiangfan06
提交者:
GitHub
8月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] Add take_along_axis xpu kernel and plugin (#56125)
上级
46f9d9b7
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
579 addition
and
0 deletion
+579
-0
paddle/phi/backends/xpu/xpu2_op_list.cc
paddle/phi/backends/xpu/xpu2_op_list.cc
+2
-0
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
+10
-0
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/take_along_axis.xpu
...nels/xpu/plugin/src/kernel/kunlun2cpp/take_along_axis.xpu
+79
-0
paddle/phi/kernels/xpu/plugin/src/wrapper/take_along_axis.cpp
...le/phi/kernels/xpu/plugin/src/wrapper/take_along_axis.cpp
+196
-0
paddle/phi/kernels/xpu/take_along_axis_kernel.cc
paddle/phi/kernels/xpu/take_along_axis_kernel.cc
+122
-0
test/xpu/test_take_along_axis_op_xpu.py
test/xpu/test_take_along_axis_op_xpu.py
+170
-0
未找到文件。
paddle/phi/backends/xpu/xpu2_op_list.cc
浏览文件 @
b2e06fc9
...
@@ -815,6 +815,8 @@ XPUOpMap& get_kl2_ops() {
...
@@ -815,6 +815,8 @@ XPUOpMap& get_kl2_ops() {
{
"sum"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"sum"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"swish"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"swish"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"swish_grad"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
})},
{
"swish_grad"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
})},
{
"take_along_axis"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"tanh_grad"
,
{
"tanh_grad"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"tanh"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"tanh"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
...
...
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
浏览文件 @
b2e06fc9
...
@@ -32,6 +32,7 @@ DLL_EXPORT int fast_where(Context* ctx,
...
@@ -32,6 +32,7 @@ DLL_EXPORT int fast_where(Context* ctx,
T
*
out
,
T
*
out
,
int64_t
len
);
int64_t
len
);
template
<
typename
T
,
typename
TID
>
template
<
typename
T
,
typename
TID
>
DLL_EXPORT
int
fast_gather_nd
(
Context
*
ctx
,
DLL_EXPORT
int
fast_gather_nd
(
Context
*
ctx
,
const
T
*
x
,
const
T
*
x
,
const
TID
*
index
,
const
TID
*
index
,
...
@@ -56,6 +57,15 @@ static inline int fast_gather_nd(Context* ctx,
...
@@ -56,6 +57,15 @@ static inline int fast_gather_nd(Context* ctx,
std
::
vector
<
int64_t
>
(
index_shape
.
begin
(),
index_shape
.
end
()));
std
::
vector
<
int64_t
>
(
index_shape
.
begin
(),
index_shape
.
end
()));
}
}
template
<
typename
T
,
typename
TID
>
DLL_EXPORT
int
take_along_axis
(
Context
*
ctx
,
const
T
*
x
,
const
TID
*
index
,
T
*
y
,
const
std
::
vector
<
int64_t
>&
xshape
,
const
std
::
vector
<
int64_t
>&
idxshape
,
int64_t
axis
);
}
// namespace plugin
}
// namespace plugin
}
// namespace api
}
// namespace api
}
// namespace xpu
}
// namespace xpu
...
...
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/take_along_axis.xpu
0 → 100644
浏览文件 @
b2e06fc9
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
template <typename T, typename TID>
__global__ void take_along_axis(const T* x,
const TID* indices,
T* y,
const int64_t* shape,
int64_t shape_size,
int64_t batch,
int64_t xlen,
int64_t ylen) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
__simd__ char lm_x[5 * sizeof(int64_t)];
__simd__ char lm_y[sizeof(T)];
__simd__ char lm_idx[sizeof(TID)];
__shared__ int64_t sm_shape[512];
if (cid == 0) {
GM2SM(shape, sm_shape, shape_size * sizeof(int64_t));
}
sync_all();
for (int64_t i = tid; i < batch * ylen; i += nthreads) {
GM2LM(indices + i, lm_idx, sizeof(TID));
TID idx = ((TID*)lm_idx)[0];
if (idx < 0) {
idx += xlen;
}
if (idx < xlen) {
GM2LM(x + i / ylen * xlen + idx, lm_y, sizeof(T));
LM2GM(lm_y, y + i, sizeof(T));
}
}
return;
}
#define _XPU_DEF__TAKE_ALONG_AXIS_(DTYPE, IDTYPE) \
template __global__ void take_along_axis<DTYPE, IDTYPE>( \
const DTYPE* x, \
const IDTYPE* indices, \
DTYPE* y, \
const int64_t* shape, \
int64_t shape_size, \
int64_t batch, \
int64_t xlen, \
int64_t ylen);
_XPU_DEF__TAKE_ALONG_AXIS_(float, int);
_XPU_DEF__TAKE_ALONG_AXIS_(float16, int);
_XPU_DEF__TAKE_ALONG_AXIS_(float, int64_t);
_XPU_DEF__TAKE_ALONG_AXIS_(float16, int64_t);
} // namespace plugin
} // namespace xpu2
paddle/phi/kernels/xpu/plugin/src/wrapper/take_along_axis.cpp
0 → 100644
浏览文件 @
b2e06fc9
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include "xpu/refactor/util/vector_util.h"
namespace
xpu2
{
namespace
plugin
{
template
<
typename
T
,
typename
TID
>
__attribute__
((
global
))
void
take_along_axis
(
const
T
*
x
,
const
TID
*
indices
,
T
*
y
,
const
int64_t
*
shape
,
int64_t
shape_size
,
int64_t
batch
,
int64_t
xlen
,
int64_t
ylen
);
}
// namespace plugin
}
// namespace xpu2
namespace
baidu
{
namespace
xpu
{
namespace
api
{
namespace
plugin
{
template
<
typename
T
,
typename
TID
>
static
int
cpu_wrapper
(
Context
*
ctx
,
const
T
*
x
,
const
TID
*
index
,
T
*
y
,
const
std
::
vector
<
int64_t
>
xshape
,
const
std
::
vector
<
int64_t
>&
idxshape
,
int64_t
axis
)
{
int64_t
ylen
=
vector_prod
(
idxshape
);
for
(
int64_t
i
=
0
;
i
<
ylen
;
i
++
)
{
std
::
vector
<
int64_t
>
sp_x_id
=
id_to_split_id
(
idxshape
,
i
);
sp_x_id
[
axis
]
=
index
[
i
];
// -xshape[axis] <= index value < xshape[axis]
WRAPPER_ASSERT_LT
(
ctx
,
sp_x_id
[
axis
],
xshape
[
axis
]);
WRAPPER_ASSERT_GE
(
ctx
,
sp_x_id
[
axis
],
-
xshape
[
axis
]);
if
(
sp_x_id
[
axis
]
<
0
)
{
sp_x_id
[
axis
]
+=
xshape
[
axis
];
}
int64_t
xid
=
split_id_to_id
(
xshape
,
sp_x_id
);
y
[
i
]
=
x
[
xid
];
}
return
SUCCESS
;
}
template
<
typename
T
,
typename
TID
>
static
int
xpu2_wrapper
(
Context
*
ctx
,
const
T
*
x
,
const
TID
*
index
,
T
*
y
,
const
std
::
vector
<
int64_t
>
xshape
,
const
std
::
vector
<
int64_t
>&
idxshape
,
int64_t
axis
)
{
int64_t
m_idx
=
1
;
int64_t
shape_new_size
=
idxshape
.
size
()
-
1
;
std
::
vector
<
int64_t
>
shape_new
=
xshape
;
for
(
int64_t
i
=
0
;
i
<
axis
;
i
++
)
{
m_idx
*=
idxshape
[
i
];
}
for
(
int64_t
i
=
axis
+
1
;
i
<
xshape
.
size
();
i
++
)
{
shape_new
[
i
-
1
]
=
xshape
[
i
];
}
int64_t
t_x
=
xshape
[
axis
];
int64_t
t_idx
=
idxshape
[
axis
];
int64_t
n_idx
=
vector_prod
(
idxshape
)
/
m_idx
/
t_idx
;
if
(
m_idx
<
64
&&
n_idx
==
1
)
{
api
::
ctx_guard
RAII_GUARD
(
ctx
);
int64_t
*
shape_xpu
=
RAII_GUARD
.
alloc_l3_or_gm
<
int64_t
>
(
shape_new_size
);
WRAPPER_ASSERT_WORKSPACE
(
ctx
,
shape_xpu
);
int
ret
=
do_host2device
(
ctx
,
shape_new
.
data
(),
shape_xpu
,
(
shape_new_size
)
*
sizeof
(
int64_t
));
WRAPPER_ASSERT_SUCCESS
(
ctx
,
ret
);
using
XPU_TID
=
typename
XPUIndexType
<
TID
>::
type
;
const
XPU_TID
*
casted_index
=
static_cast
<
const
XPU_TID
*>
(
static_cast
<
const
void
*>
(
index
));
xpu2
::
plugin
::
take_along_axis
<
T
,
XPU_TID
>
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
x
,
casted_index
,
y
,
reinterpret_cast
<
xpu2
::
int64_t
*>
(
shape_xpu
),
shape_new_size
,
m_idx
,
t_x
,
t_idx
);
}
else
{
return
gather_element
(
ctx
,
x
,
index
,
y
,
xshape
,
idxshape
,
axis
);
}
return
SUCCESS
;
}
template
<
typename
T
,
typename
TID
>
int
take_along_axis
(
Context
*
ctx
,
const
T
*
x
,
const
TID
*
index
,
T
*
y
,
const
std
::
vector
<
int64_t
>&
xshape
,
const
std
::
vector
<
int64_t
>&
idxshape
,
int64_t
axis
)
{
WRAPPER_CHECK_CTX
(
ctx
);
WRAPPER_DUMP_FUNCTION_T2
(
ctx
,
"take_along_axis"
,
T
,
TID
);
WRAPPER_DUMP_PARAM6
(
ctx
,
x
,
index
,
y
,
xshape
,
idxshape
,
axis
);
WRAPPER_DUMP
(
ctx
);
int64_t
xlen
=
-
1
;
WRAPPER_CHECK_SHAPE
(
ctx
,
&
xlen
,
xshape
);
WRAPPER_CHECK_PTR
(
ctx
,
T
,
xlen
,
x
);
int64_t
idxlen
=
-
1
;
WRAPPER_CHECK_SHAPE
(
ctx
,
&
idxlen
,
idxshape
);
WRAPPER_CHECK_PTR
(
ctx
,
TID
,
idxlen
,
index
);
WRAPPER_CHECK_PTR
(
ctx
,
T
,
idxlen
,
y
);
WRAPPER_ASSERT_EQ
(
ctx
,
xshape
.
size
(),
idxshape
.
size
());
// x and index tensor should have same rank
int64_t
neg_rank
=
-
xshape
.
size
();
WRAPPER_ASSERT_GE
(
ctx
,
axis
,
neg_rank
);
WRAPPER_ASSERT_LT
(
ctx
,
axis
,
xshape
.
size
());
axis
=
(
axis
<
0
)
?
(
axis
+
xshape
.
size
())
:
axis
;
for
(
int64_t
i
=
0
;
i
<
xshape
.
size
();
i
++
)
{
if
(
i
!=
axis
)
{
WRAPPER_ASSERT_EQ
(
ctx
,
xshape
[
i
],
idxshape
[
i
]);
}
}
if
(
ctx
->
dev
().
type
()
==
api
::
kCPU
)
{
return
cpu_wrapper
<
T
,
TID
>
(
ctx
,
x
,
index
,
y
,
xshape
,
idxshape
,
axis
);
}
if
(
ctx
->
dev
().
type
()
==
api
::
kXPU2
)
{
return
xpu2_wrapper
<
T
,
TID
>
(
ctx
,
x
,
index
,
y
,
xshape
,
idxshape
,
axis
);
}
WRAPPER_UNIMPLEMENTED
(
ctx
);
}
template
int
take_along_axis
(
Context
*
,
const
float
*
,
const
int
*
,
float
*
,
const
std
::
vector
<
int64_t
>
&
,
const
std
::
vector
<
int64_t
>&
,
int64_t
);
template
int
take_along_axis
(
Context
*
,
const
float
*
,
const
int64_t
*
,
float
*
,
const
std
::
vector
<
int64_t
>
&
,
const
std
::
vector
<
int64_t
>&
,
int64_t
);
template
int
take_along_axis
(
Context
*
,
const
float16
*
,
const
int
*
,
float16
*
,
const
std
::
vector
<
int64_t
>
&
,
const
std
::
vector
<
int64_t
>&
,
int64_t
);
template
int
take_along_axis
(
Context
*
,
const
float16
*
,
const
int64_t
*
,
float16
*
,
const
std
::
vector
<
int64_t
>
&
,
const
std
::
vector
<
int64_t
>&
,
int64_t
);
}
// namespace plugin
}
// namespace api
}
// namespace xpu
}
// namespace baidu
paddle/phi/kernels/xpu/take_along_axis_kernel.cc
0 → 100644
浏览文件 @
b2e06fc9
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
TakeAlongAxisKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
index
,
int
axis
,
DenseTensor
*
out
)
{
out
->
Resize
(
index
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
out
);
if
(
x
.
numel
()
==
0
||
index
.
numel
()
==
0
)
return
;
const
auto
&
index_type
=
index
.
dtype
();
bool
index_type_match
=
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
errors
::
InvalidArgument
(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s"
,
DataTypeToString
(
index_type
),
DataTypeToString
(
DataType
::
INT32
),
DataTypeToString
(
DataType
::
INT64
)));
std
::
vector
<
int64_t
>
xshape
(
x
.
dims
().
size
());
for
(
int
i
=
0
;
i
<
x
.
dims
().
size
();
++
i
)
{
xshape
[
i
]
=
x
.
dims
()[
i
];
}
std
::
vector
<
int64_t
>
idxshape
(
index
.
dims
().
size
());
for
(
int
i
=
0
;
i
<
index
.
dims
().
size
();
++
i
)
{
idxshape
[
i
]
=
index
.
dims
()[
i
];
}
if
(
xshape
.
size
()
<=
1
&&
idxshape
.
size
()
<=
1
)
{
for
(
int
i
=
xshape
.
size
();
i
<
2
;
++
i
)
{
xshape
.
push_back
(
1
);
idxshape
.
push_back
(
1
);
}
}
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
int
r
=
XPU_SUCCESS
;
#ifndef PADDLE_WITH_XPU_PLUGIN
LOG
(
WARNING
)
<<
"Add -DWITH_XPU_PLUGIN=ON to build "
"xpu::plugin::take_along_axis(), or use "
"xpu::gather_element() instead, which leads low performance "
"in some cases."
;
if
(
index_type
==
DataType
::
INT32
)
{
r
=
xpu
::
gather_element
<
XPUType
,
int
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
()),
index
.
data
<
int
>
(),
reinterpret_cast
<
XPUType
*>
(
out
->
data
<
T
>
()),
xshape
,
idxshape
,
axis
);
}
else
{
r
=
xpu
::
gather_element
<
XPUType
,
int64_t
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
()),
index
.
data
<
int64_t
>
(),
reinterpret_cast
<
XPUType
*>
(
out
->
data
<
T
>
()),
xshape
,
idxshape
,
axis
);
}
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"gather_element"
);
#else
if
(
index_type
==
DataType
::
INT32
)
{
r
=
xpu
::
plugin
::
take_along_axis
<
XPUType
,
int
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
()),
index
.
data
<
int
>
(),
reinterpret_cast
<
XPUType
*>
(
out
->
data
<
T
>
()),
xshape
,
idxshape
,
axis
);
}
else
{
r
=
xpu
::
plugin
::
take_along_axis
<
XPUType
,
int64_t
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
()),
index
.
data
<
int64_t
>
(),
reinterpret_cast
<
XPUType
*>
(
out
->
data
<
T
>
()),
xshape
,
idxshape
,
axis
);
}
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"take_along_axis"
);
#endif
}
}
// namespace phi
PD_REGISTER_KERNEL
(
take_along_axis
,
XPU
,
ALL_LAYOUT
,
phi
::
TakeAlongAxisKernel
,
phi
::
dtype
::
float16
,
float
)
{}
test/xpu/test_take_along_axis_op_xpu.py
0 → 100644
浏览文件 @
b2e06fc9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
get_test_cover_info
import
(
XPUOpTestWrapper
,
create_test_class
,
get_xpu_op_support_types
,
)
from
op_test_xpu
import
XPUOpTest
import
paddle
paddle
.
enable_static
()
class
XPUTestTakeAlongAxis
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'take_along_axis'
class
TestXPUTakeAlongAxisOp
(
XPUOpTest
):
def
setUp
(
self
):
self
.
op_type
=
"take_along_axis"
self
.
place
=
paddle
.
XPUPlace
(
0
)
self
.
dtype
=
self
.
in_type
self
.
init_config
()
xnp
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
dtype
)
self
.
target
=
np
.
take_along_axis
(
xnp
,
self
.
index
,
self
.
axis
)
broadcast_shape_list
=
list
(
self
.
x_shape
)
broadcast_shape_list
[
self
.
axis
]
=
self
.
index
.
shape
[
self
.
axis
]
self
.
broadcast_shape
=
tuple
(
broadcast_shape_list
)
self
.
index_broadcast
=
np
.
broadcast_to
(
self
.
index
,
self
.
broadcast_shape
)
self
.
inputs
=
{
'Input'
:
xnp
,
'Index'
:
self
.
index_broadcast
,
}
self
.
attrs
=
{
'Axis'
:
self
.
axis
}
self
.
outputs
=
{
'Result'
:
self
.
target
}
def
init_config
(
self
):
self
.
in_type
=
np
.
float32
self
.
x_shape
=
(
1
,
4
,
10
)
self
.
index_type
=
np
.
int32
self
.
index
=
np
.
array
([[[
0
,
1
,
3
,
5
,
6
]]]).
astype
(
self
.
index_type
)
self
.
axis
=
2
def
test_check_output
(
self
):
if
paddle
.
is_compiled_with_xpu
():
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad
(
self
):
if
paddle
.
is_compiled_with_xpu
():
self
.
check_grad_with_place
(
self
.
place
,
[
'Input'
],
'Result'
)
class
TestCase1
(
TestXPUTakeAlongAxisOp
):
def
init_config
(
self
):
self
.
in_type
=
np
.
float32
self
.
x_shape
=
(
1
,
10
,
100
)
self
.
index_type
=
np
.
int32
self
.
index
=
np
.
array
([[[
0
,
1
,
3
,
5
,
13
]]]).
astype
(
self
.
index_type
)
self
.
axis
=
2
class
TestCase2
(
TestXPUTakeAlongAxisOp
):
def
init_config
(
self
):
self
.
in_type
=
np
.
float32
self
.
x_shape
=
(
1
,
10
,
100
)
self
.
index_type
=
np
.
int64
self
.
index
=
np
.
array
([[[
0
,
1
,
3
,
5
,
13
]]]).
astype
(
self
.
index_type
)
self
.
axis
=
2
class
TestCase3
(
TestXPUTakeAlongAxisOp
):
def
init_config
(
self
):
self
.
in_type
=
np
.
float16
self
.
x_shape
=
(
1
,
10
,
100
)
self
.
index_type
=
np
.
int32
self
.
index
=
np
.
array
([[[
0
,
1
,
3
,
5
,
13
]]]).
astype
(
self
.
index_type
)
self
.
axis
=
2
class
TestCase4
(
TestXPUTakeAlongAxisOp
):
def
init_config
(
self
):
self
.
in_type
=
np
.
float16
self
.
x_shape
=
(
1
,
10
,
100
)
self
.
index_type
=
np
.
int64
self
.
index
=
np
.
array
([[[
0
,
1
,
3
,
5
,
13
]]]).
astype
(
self
.
index_type
)
self
.
axis
=
2
class
TestCase5
(
TestXPUTakeAlongAxisOp
):
def
init_config
(
self
):
self
.
in_type
=
np
.
float32
self
.
x_shape
=
(
1
,
10
,
100
)
self
.
index_type
=
np
.
int32
self
.
index
=
np
.
array
([[[
0
],
[
1
],
[
3
],
[
5
],
[
8
]]]).
astype
(
self
.
index_type
)
self
.
axis
=
1
class
XPUTestTakeAlongAxisAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
np
.
random
.
seed
(
0
)
self
.
shape
=
[
3
,
3
]
self
.
index_shape
=
[
1
,
3
]
self
.
index_np
=
np
.
array
([[
0
,
1
,
2
]]).
astype
(
'int64'
)
self
.
x_np
=
np
.
random
.
random
(
self
.
shape
).
astype
(
np
.
float32
)
self
.
place
=
[
paddle
.
XPUPlace
(
0
)]
self
.
axis
=
0
def
test_api_static
(
self
):
paddle
.
enable_static
()
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
static
.
data
(
'X'
,
self
.
shape
)
index
=
paddle
.
static
.
data
(
'Index'
,
self
.
index_shape
,
"int64"
)
out
=
paddle
.
take_along_axis
(
x
,
index
,
self
.
axis
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
[
0
])
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
,
'Index'
:
self
.
index_np
},
fetch_list
=
[
out
]
)
out_ref
=
np
.
array
(
np
.
take_along_axis
(
self
.
x_np
,
self
.
index_np
,
self
.
axis
)
)
for
out
in
res
:
np
.
testing
.
assert_allclose
(
out
,
out_ref
,
rtol
=
0.001
)
def
test_api_dygraph
(
self
):
paddle
.
disable_static
(
self
.
place
[
0
])
x_tensor
=
paddle
.
to_tensor
(
self
.
x_np
)
self
.
index
=
paddle
.
to_tensor
(
self
.
index_np
)
out
=
paddle
.
take_along_axis
(
x_tensor
,
self
.
index
,
self
.
axis
)
out_ref
=
np
.
array
(
np
.
take_along_axis
(
self
.
x_np
,
self
.
index_np
,
self
.
axis
)
)
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
out_ref
,
rtol
=
0.001
)
paddle
.
enable_static
()
class
TestTakeAlongAxisAPICase1
(
XPUTestTakeAlongAxisAPI
):
def
setUp
(
self
):
np
.
random
.
seed
(
0
)
self
.
shape
=
[
2
,
2
]
self
.
index_shape
=
[
4
,
2
]
self
.
index_np
=
np
.
array
([[
0
,
0
],
[
1
,
0
],
[
0
,
0
],
[
1
,
0
]]).
astype
(
'int64'
)
self
.
x_np
=
np
.
random
.
random
(
self
.
shape
).
astype
(
np
.
float32
)
self
.
place
=
[
paddle
.
XPUPlace
(
0
)]
self
.
axis
=
0
support_types
=
get_xpu_op_support_types
(
'take_along_axis'
)
for
stype
in
support_types
:
create_test_class
(
globals
(),
XPUTestTakeAlongAxis
,
stype
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录