Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
aec49361
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
aec49361
编写于
6月 07, 2022
作者:
N
niuliling123
提交者:
GitHub
6月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU KP]Add xpu register, any, amax, amin op test (#43204)
上级
a2020d0c
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
315 addition
and
20 deletion
+315
-20
paddle/fluid/operators/reduce_ops/reduce_amax_op.kps
paddle/fluid/operators/reduce_ops/reduce_amax_op.kps
+13
-1
paddle/fluid/operators/reduce_ops/reduce_amin_op.kps
paddle/fluid/operators/reduce_ops/reduce_amin_op.kps
+13
-1
paddle/phi/kernels/funcs/reduce_function.h
paddle/phi/kernels/funcs/reduce_function.h
+8
-7
paddle/phi/kernels/kps/reduce_any_kernel.cu
paddle/phi/kernels/kps/reduce_any_kernel.cu
+5
-1
paddle/phi/kernels/kps/reduce_max_kernel.cu
paddle/phi/kernels/kps/reduce_max_kernel.cu
+0
-1
paddle/phi/kernels/kps/reduce_prod_kernel.cu
paddle/phi/kernels/kps/reduce_prod_kernel.cu
+5
-2
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
+8
-7
paddle/phi/kernels/reduce_all_kernel.cc
paddle/phi/kernels/reduce_all_kernel.cc
+4
-0
paddle/phi/kernels/reduce_any_kernel.cc
paddle/phi/kernels/reduce_any_kernel.cc
+4
-0
paddle/phi/kernels/reduce_max_kernel.cc
paddle/phi/kernels/reduce_max_kernel.cc
+4
-0
paddle/phi/kernels/reduce_mean_kernel.cc
paddle/phi/kernels/reduce_mean_kernel.cc
+4
-0
paddle/phi/kernels/reduce_min_kernel.cc
paddle/phi/kernels/reduce_min_kernel.cc
+4
-0
paddle/phi/kernels/reduce_prod_kernel.cc
paddle/phi/kernels/reduce_prod_kernel.cc
+4
-0
paddle/phi/kernels/reduce_sum_kernel.cc
paddle/phi/kernels/reduce_sum_kernel.cc
+6
-0
python/paddle/fluid/tests/unittests/xpu/test_reduce_amax_op_xpu.py
...ddle/fluid/tests/unittests/xpu/test_reduce_amax_op_xpu.py
+67
-0
python/paddle/fluid/tests/unittests/xpu/test_reduce_amin_op_xpu.py
...ddle/fluid/tests/unittests/xpu/test_reduce_amin_op_xpu.py
+67
-0
python/paddle/fluid/tests/unittests/xpu/test_reduce_any_op_xpu.py
...addle/fluid/tests/unittests/xpu/test_reduce_any_op_xpu.py
+99
-0
未找到文件。
paddle/fluid/operators/reduce_ops/reduce_amax_op.
cu
→
paddle/fluid/operators/reduce_ops/reduce_amax_op.
kps
浏览文件 @
aec49361
...
...
@@ -12,13 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/phi/core/kernel_registry.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// reduce_max
#ifdef PADDLE_WITH_XPU_KP
REGISTER_OP_KERNEL(
reduce_amax, KP, plat::XPUPlace,
ops::ReduceCudaKernel<float, kps::MaxFunctor, kps::IdentityFunctor>);
#else
REGISTER_OP_CUDA_KERNEL(
reduce_amax,
ops::ReduceCudaKernel<float, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::MaxFunctor, kps::IdentityFunctor>);
#endif
paddle/fluid/operators/reduce_ops/reduce_amin_op.
cu
→
paddle/fluid/operators/reduce_ops/reduce_amin_op.
kps
浏览文件 @
aec49361
...
...
@@ -12,13 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/phi/core/kernel_registry.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// reduce_min
#ifdef PADDLE_WITH_XPU_KP
REGISTER_OP_KERNEL(
reduce_amin, KP, plat::XPUPlace,
ops::ReduceCudaKernel<float, kps::MinFunctor, kps::IdentityFunctor>);
#else
REGISTER_OP_CUDA_KERNEL(
reduce_amin,
ops::ReduceCudaKernel<float, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::MinFunctor, kps::IdentityFunctor>);
#endif
paddle/phi/kernels/funcs/reduce_function.h
浏览文件 @
aec49361
...
...
@@ -236,8 +236,9 @@ struct IndexCalculator {
template
<
bool
ReduceLastDim
=
false
>
struct
ReduceIndexMapping
{
const
kps
::
DimConfig
dim
;
HOSTDEVICE
explicit
ReduceIndexMapping
(
const
kps
::
DimConfig
&
dims
)
:
dim
(
dims
)
{}
int
loop_size
;
HOSTDEVICE
ReduceIndexMapping
(
const
kps
::
DimConfig
&
dims
,
int
max_loop
=
1
)
:
dim
(
dims
),
loop_size
(
max_loop
)
{}
#ifdef PADDLE_WITH_XPU_KP
__device__
__forceinline__
int
BlockIdX
()
{
...
...
@@ -277,10 +278,10 @@ struct ReduceIndexMapping {
}
__device__
__forceinline__
int
GetLoopSize
()
{
if
(
ReduceLastDim
)
{
return
dim
.
deal_size_y
;
}
else
{
if
((
!
ReduceLastDim
)
&&
(
loop_size
==
1
))
{
return
dim
.
deal_size_x
;
}
else
{
return
loop_size
;
}
}
#else
...
...
@@ -670,7 +671,7 @@ __global__ void ReduceAnyKernel(const Tx* x,
int
store_offset
=
0
;
int
stride_left
=
0
;
if
(
reduce_last_dim
)
{
auto
block
=
ReduceIndexMapping
<
true
>
(
dim
);
auto
block
=
ReduceIndexMapping
<
true
>
(
dim
,
left_num
);
input_idx
=
block
.
BlockIdY
()
*
block
.
BlockDimX
();
left_idx
=
block
.
BlockIdX
()
*
block
.
BlockDimY
()
+
THREAD_ID_Y
;
stride
=
block
.
GridDimY
()
*
block
.
BlockDimX
();
...
...
@@ -681,7 +682,7 @@ __global__ void ReduceAnyKernel(const Tx* x,
stride_left
=
1
;
tid
=
THREAD_ID_X
;
}
else
{
auto
block
=
ReduceIndexMapping
<
false
>
(
dim
);
auto
block
=
ReduceIndexMapping
<
false
>
(
dim
,
left_num
);
input_idx
=
block
.
BlockIdY
()
*
block
.
BlockDimY
();
left_idx
=
block
.
BlockIdX
()
*
block
.
BlockDimX
()
+
THREAD_ID_X
;
stride
=
block
.
GridDimY
()
*
block
.
BlockDimY
();
...
...
paddle/phi/kernels/
gpu
/reduce_any_kernel.cu
→
paddle/phi/kernels/
kps
/reduce_any_kernel.cu
浏览文件 @
aec49361
...
...
@@ -32,4 +32,8 @@ void AnyRawKernel(const Context& dev_ctx,
}
// namespace phi
PD_REGISTER_KERNEL
(
any_raw
,
GPU
,
ALL_LAYOUT
,
phi
::
AnyRawKernel
,
bool
)
{}
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL
(
any_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
AnyRawKernel
,
bool
)
{}
#else
PD_REGISTER_KERNEL
(
any_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
AnyRawKernel
,
bool
)
{}
#endif
paddle/phi/kernels/kps/reduce_max_kernel.cu
浏览文件 @
aec49361
...
...
@@ -37,5 +37,4 @@ PD_REGISTER_KERNEL(max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float) {}
#else
PD_REGISTER_KERNEL
(
max_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
MaxRawKernel
,
float
,
double
,
int
,
int64_t
)
{}
#endif
paddle/phi/kernels/
gpu
/reduce_prod_kernel.cu
→
paddle/phi/kernels/
kps
/reduce_prod_kernel.cu
浏览文件 @
aec49361
...
...
@@ -31,12 +31,15 @@ void ProdRawKernel(const Context& dev_ctx,
}
}
// namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL
(
prod_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
ProdRawKernel
,
float
)
{}
#else
PD_REGISTER_KERNEL
(
prod_raw
,
GPU
,
KPS
,
ALL_LAYOUT
,
phi
::
ProdRawKernel
,
float
,
double
,
int
,
int64_t
)
{}
#endif
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
浏览文件 @
aec49361
...
...
@@ -48,7 +48,7 @@ static inline __device__ void sync_all() {
#define ncores 64
template
<
typename
T
,
typename
OpFunc
,
int
VecSize
>
__device__
void
BlockXReduce
(
T
*
data
,
OpFunc
reducer
)
{
__device__
void
BlockXReduce
(
T
*
out
,
const
T
*
data
,
OpFunc
reducer
)
{
__shared__
T
sum_array
[
ncores
*
VecSize
];
int
core_idx
=
core_id
()
*
VecSize
;
mfence
();
...
...
@@ -57,21 +57,22 @@ __device__ void BlockXReduce(T* data, OpFunc reducer) {
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
mfence
();
sum_array
[
core_idx
+
i
]
=
data
[
i
];
sum_array
[
i
*
ncores
+
core_idx
]
=
data
[
i
];
mfence
();
data
[
i
]
=
0
;
}
sync_all
();
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
i
++
)
{
T
start
=
data
[
i
*
ncores
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncores
;
j
++
)
{
for
(
int
j
=
1
;
j
<
ncores
;
j
++
)
{
mfence
();
T
tmp
=
sum_array
[
j
*
VecSize
+
i
];
T
tmp
=
sum_array
[
i
*
ncores
+
j
];
mfence
();
data
[
i
]
=
reducer
(
data
[
i
]
,
tmp
);
start
=
reducer
(
start
,
tmp
);
mfence
();
}
out
[
i
]
=
start
;
}
sync_all
();
}
...
...
@@ -346,7 +347,7 @@ __device__ __forceinline__ void Reduce(T* out,
if
(
reduce_last_dim
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NY
*
NX
;
i
++
)
{
// reduce along blockDim.x
details
::
BlockXReduce
<
T
,
ReduceFunctor
,
1
>
(
&
out
[
i
],
reducer
);
details
::
BlockXReduce
<
T
,
ReduceFunctor
,
1
>
(
&
out
[
i
],
&
in
[
i
],
reducer
);
}
}
}
else
{
// else kLocalMode
...
...
paddle/phi/kernels/reduce_all_kernel.cc
浏览文件 @
aec49361
...
...
@@ -36,3 +36,7 @@ PD_REGISTER_KERNEL(all, CPU, ALL_LAYOUT, phi::AllKernel, bool) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL
(
all
,
GPU
,
ALL_LAYOUT
,
phi
::
AllKernel
,
bool
)
{}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL
(
all
,
KPS
,
ALL_LAYOUT
,
phi
::
AllKernel
,
bool
)
{}
#endif
paddle/phi/kernels/reduce_any_kernel.cc
浏览文件 @
aec49361
...
...
@@ -36,3 +36,7 @@ PD_REGISTER_KERNEL(any, CPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL
(
any
,
GPU
,
ALL_LAYOUT
,
phi
::
AnyKernel
,
bool
)
{}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL
(
any
,
KPS
,
ALL_LAYOUT
,
phi
::
AnyKernel
,
bool
)
{}
#endif
paddle/phi/kernels/reduce_max_kernel.cc
浏览文件 @
aec49361
...
...
@@ -38,3 +38,7 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL
(
max
,
GPU
,
ALL_LAYOUT
,
phi
::
MaxKernel
,
float
,
double
,
int
,
int64_t
)
{}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL
(
max
,
KPS
,
ALL_LAYOUT
,
phi
::
MaxKernel
,
float
)
{}
#endif
paddle/phi/kernels/reduce_mean_kernel.cc
浏览文件 @
aec49361
...
...
@@ -46,3 +46,7 @@ PD_REGISTER_KERNEL(mean,
int64_t
,
phi
::
dtype
::
float16
)
{}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL
(
mean
,
KPS
,
ALL_LAYOUT
,
phi
::
MeanKernel
,
float
)
{}
#endif
paddle/phi/kernels/reduce_min_kernel.cc
浏览文件 @
aec49361
...
...
@@ -38,3 +38,7 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL
(
min
,
GPU
,
ALL_LAYOUT
,
phi
::
MinKernel
,
float
,
double
,
int
,
int64_t
)
{}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL
(
min
,
KPS
,
ALL_LAYOUT
,
phi
::
MinKernel
,
float
)
{}
#endif
paddle/phi/kernels/reduce_prod_kernel.cc
浏览文件 @
aec49361
...
...
@@ -38,3 +38,7 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL
(
prod
,
GPU
,
ALL_LAYOUT
,
phi
::
ProdKernel
,
float
,
double
,
int
,
int64_t
)
{}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL
(
prod
,
KPS
,
ALL_LAYOUT
,
phi
::
ProdKernel
,
float
)
{}
#endif
paddle/phi/kernels/reduce_sum_kernel.cc
浏览文件 @
aec49361
...
...
@@ -69,3 +69,9 @@ PD_REGISTER_KERNEL(sum,
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL
(
sum
,
KPS
,
ALL_LAYOUT
,
phi
::
SumKernel
,
float
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
#endif
python/paddle/fluid/tests/unittests/xpu/test_reduce_amax_op_xpu.py
0 → 100644
浏览文件 @
aec49361
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
sys
.
path
.
append
(
".."
)
import
paddle
from
op_test
import
OpTest
from
op_test_xpu
import
XPUOpTest
from
xpu.get_test_cover_info
import
create_test_class
,
get_xpu_op_support_types
,
XPUOpTestWrapper
paddle
.
enable_static
()
class
XPUTestReduceAmaxOp
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'reduce_amax'
class
XPUTestReduceAmaxBase
(
XPUOpTest
):
def
setUp
(
self
):
self
.
place
=
paddle
.
XPUPlace
(
0
)
self
.
set_case
()
def
set_case
(
self
):
self
.
op_type
=
'reduce_amax'
self
.
shape
=
(
20
,
10
)
self
.
attrs
=
{
'use_xpu'
:
True
,
'keep_dim'
:
False
,
'dim'
:
(
1
,
)}
self
.
inputs
=
{
'X'
:
np
.
random
.
randint
(
0
,
100
,
self
.
shape
).
astype
(
"float32"
)
}
expect_intput
=
self
.
inputs
[
'X'
]
self
.
outputs
=
{
'Out'
:
np
.
amax
(
expect_intput
,
axis
=
self
.
attrs
[
'dim'
],
keepdims
=
self
.
attrs
[
'keep_dim'
])
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
support_types
=
get_xpu_op_support_types
(
'reduce_amax'
)
for
stype
in
support_types
:
create_test_class
(
globals
(),
XPUTestReduceAmaxOp
,
stype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/xpu/test_reduce_amin_op_xpu.py
0 → 100644
浏览文件 @
aec49361
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
sys
.
path
.
append
(
".."
)
import
paddle
from
op_test
import
OpTest
from
op_test_xpu
import
XPUOpTest
from
xpu.get_test_cover_info
import
create_test_class
,
get_xpu_op_support_types
,
XPUOpTestWrapper
paddle
.
enable_static
()
class
XPUTestReduceAmaxOp
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'reduce_amin'
class
XPUTestReduceAmaxBase
(
XPUOpTest
):
def
setUp
(
self
):
self
.
place
=
paddle
.
XPUPlace
(
0
)
self
.
set_case
()
def
set_case
(
self
):
self
.
op_type
=
'reduce_amin'
self
.
shape
=
(
20
,
10
)
self
.
attrs
=
{
'use_xpu'
:
True
,
'keep_dim'
:
False
,
'dim'
:
(
1
,
)}
self
.
inputs
=
{
'X'
:
np
.
random
.
randint
(
0
,
100
,
self
.
shape
).
astype
(
"float32"
)
}
expect_intput
=
self
.
inputs
[
'X'
]
self
.
outputs
=
{
'Out'
:
np
.
amin
(
expect_intput
,
axis
=
self
.
attrs
[
'dim'
],
keepdims
=
self
.
attrs
[
'keep_dim'
])
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
support_types
=
get_xpu_op_support_types
(
'reduce_amin'
)
for
stype
in
support_types
:
create_test_class
(
globals
(),
XPUTestReduceAmaxOp
,
stype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/xpu/test_reduce_any_op_xpu.py
0 → 100644
浏览文件 @
aec49361
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
sys
.
path
.
append
(
".."
)
import
paddle
from
op_test
import
OpTest
from
op_test_xpu
import
XPUOpTest
from
xpu.get_test_cover_info
import
create_test_class
,
get_xpu_op_support_types
,
XPUOpTestWrapper
paddle
.
enable_static
()
class
XPUTestReduceAnyOp
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'reduce_any'
class
XPUTestReduceAnyBase
(
XPUOpTest
):
def
setUp
(
self
):
self
.
place
=
paddle
.
XPUPlace
(
0
)
self
.
set_case
()
def
set_case
(
self
):
self
.
op_type
=
'reduce_any'
self
.
attrs
=
{
'use_xpu'
:
True
,
'reduce_all'
:
True
,
'keep_dim'
:
True
,
'dim'
:
(
3
,
5
,
4
)
}
self
.
inputs
=
{
'X'
:
np
.
random
.
randint
(
0
,
2
,
(
2
,
5
,
3
,
2
,
2
,
3
,
4
,
2
)).
astype
(
"bool"
)
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
any
(
axis
=
self
.
attrs
[
'dim'
])}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad
(
self
):
pass
class
XPUTestReduceAnyCase1
(
XPUTestReduceAnyBase
):
def
set_case
(
self
):
self
.
op_type
=
'reduce_any'
self
.
attrs
=
{
'use_xpu'
:
True
,
'dim'
:
[
1
]
# 'reduce_all': True,
# 'keep_dim': True,
}
self
.
inputs
=
{
'X'
:
np
.
random
.
randint
(
0
,
2
,
(
5
,
6
,
10
)).
astype
(
"bool"
)
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
any
(
axis
=
1
)}
class
XPUTestReduceAnyCase2
(
XPUTestReduceAnyBase
):
def
set_case
(
self
):
self
.
op_type
=
'reduce_any'
self
.
attrs
=
{
'use_xpu'
:
True
,
'reduce_all'
:
True
,
'keep_dim'
:
False
,
'dim'
:
(
3
,
6
)
}
self
.
inputs
=
{
'X'
:
np
.
random
.
randint
(
0
,
2
,
(
2
,
5
,
3
,
2
,
2
,
3
,
4
,
2
)).
astype
(
"bool"
)
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
any
(
axis
=
self
.
attrs
[
'dim'
])}
support_types
=
get_xpu_op_support_types
(
'reduce_any'
)
for
stype
in
support_types
:
create_test_class
(
globals
(),
XPUTestReduceAnyOp
,
stype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录