Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b3959fe4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b3959fe4
编写于
4月 18, 2022
作者:
L
Lijunhui
提交者:
GitHub
4月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[KP] Add Reduce op registry & UT for xpu_kp compilation (#41869)
上级
14c35a58
变更
19
显示空白变更内容
内联
并排
Showing
19 changed file
with
378 addition
and
199 deletion
+378
-199
paddle/fluid/framework/new_executor/standalone_executor_test.cc
.../fluid/framework/new_executor/standalone_executor_test.cc
+2
-1
paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h
paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h
+8
-0
paddle/fluid/platform/fast_divmod.h
paddle/fluid/platform/fast_divmod.h
+3
-0
paddle/phi/kernels/funcs/aligned_vector.h
paddle/phi/kernels/funcs/aligned_vector.h
+3
-0
paddle/phi/kernels/funcs/reduce_function.h
paddle/phi/kernels/funcs/reduce_function.h
+7
-2
paddle/phi/kernels/gpu/reduce.h
paddle/phi/kernels/gpu/reduce.h
+11
-1
paddle/phi/kernels/kps/reduce_all_kernel.cu
paddle/phi/kernels/kps/reduce_all_kernel.cu
+5
-1
paddle/phi/kernels/kps/reduce_max_kernel.cu
paddle/phi/kernels/kps/reduce_max_kernel.cu
+6
-1
paddle/phi/kernels/kps/reduce_mean_kernel.cu
paddle/phi/kernels/kps/reduce_mean_kernel.cu
+5
-1
paddle/phi/kernels/kps/reduce_min_kernel.cu
paddle/phi/kernels/kps/reduce_min_kernel.cu
+5
-1
paddle/phi/kernels/kps/reduce_sum_kernel.cu
paddle/phi/kernels/kps/reduce_sum_kernel.cu
+7
-1
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
+1
-1
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
+12
-11
paddle/phi/kernels/primitive/functor_primitives_xpu2.h
paddle/phi/kernels/primitive/functor_primitives_xpu2.h
+7
-0
python/paddle/fluid/tests/unittests/xpu/test_reduce_all_op_xpu.py
...addle/fluid/tests/unittests/xpu/test_reduce_all_op_xpu.py
+111
-0
python/paddle/fluid/tests/unittests/xpu/test_reduce_max_op_xpu.py
...addle/fluid/tests/unittests/xpu/test_reduce_max_op_xpu.py
+52
-44
python/paddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py
...ddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py
+0
-8
python/paddle/fluid/tests/unittests/xpu/test_reduce_min_op_xpu.py
...addle/fluid/tests/unittests/xpu/test_reduce_min_op_xpu.py
+81
-0
python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py
...addle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py
+52
-126
未找到文件。
paddle/fluid/framework/new_executor/standalone_executor_test.cc
浏览文件 @
b3959fe4
...
...
@@ -71,8 +71,10 @@ PD_DECLARE_KERNEL(concat_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL
(
matmul
,
GPU
,
ALL_LAYOUT
);
#ifdef PADDLE_WITH_XPU_KP
PD_DECLARE_KERNEL
(
add_raw
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
max_raw
,
GPU
,
ALL_LAYOUT
);
#else
PD_DECLARE_KERNEL
(
add_raw
,
KPS
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
max_raw
,
KPS
,
ALL_LAYOUT
);
#endif
PD_DECLARE_KERNEL
(
add
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
mean
,
GPU
,
ALL_LAYOUT
);
...
...
@@ -85,7 +87,6 @@ PD_DECLARE_KERNEL(matmul_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL
(
transpose_grad
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
sum
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
sum_grad
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
max_raw
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
sgd
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
slice
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
slice_grad
,
GPU
,
ALL_LAYOUT
);
...
...
paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h
浏览文件 @
b3959fe4
...
...
@@ -97,6 +97,14 @@ XPUOpMap& get_kp_ops() {
XPUKernelSet
({
pOpKernelType
(
vartype
::
INT32
,
XPUPlace
())})},
{
"equal"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
INT32
,
XPUPlace
())})},
{
"not_equal"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
INT32
,
XPUPlace
())})},
// reduce op
{
"reduce_mean"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"reduce_max"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"reduce_min"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"reduce_sum"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"reduce_prod"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"reduce_all"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
BOOL
,
XPUPlace
())})},
{
"reduce_any"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
BOOL
,
XPUPlace
())})},
};
return
s_xpu_kp_kernels
;
...
...
paddle/fluid/platform/fast_divmod.h
浏览文件 @
b3959fe4
...
...
@@ -18,6 +18,9 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#define INT_BITS 32
#if defined(__xpu__)
#define __forceinline__ __inline__
#endif
namespace
paddle
{
namespace
platform
{
...
...
paddle/phi/kernels/funcs/aligned_vector.h
浏览文件 @
b3959fe4
...
...
@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/phi/core/hostdevice.h"
#if defined(__xpu__)
#define CHAR_BIT 8
#endif
namespace
phi
{
...
...
paddle/phi/kernels/funcs/reduce_function.h
浏览文件 @
b3959fe4
...
...
@@ -33,10 +33,14 @@
namespace
cub
=
hipcub
;
#endif
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#endif
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/array.h"
...
...
@@ -183,7 +187,7 @@ struct IndexCalculator {
strides
=
details
::
VectorToArray
<
int
,
kMaxRank
>
(
full_strides
);
reduce_strides
=
details
::
VectorToArray
<
int
,
kMaxRank
>
(
cal_strides
);
#ifndef PADDLE_WITH_XPU_KP
std
::
vector
<
kps
::
details
::
FastDivMod
>
cal_divmoders
;
// namespace
std
::
vector
<
kps
::
details
::
FastDivMod
>
cal_divmoders
;
// fast divmod
for
(
auto
i
:
cal_strides
)
{
cal_divmoders
.
push_back
(
kps
::
details
::
FastDivMod
(
i
));
...
...
@@ -325,9 +329,10 @@ struct ReduceConfig {
// step4: set the block and grid for launch kernel
SetBlockDim
();
#ifndef PADDLE_WITH_XPU_KP
// step5: limit the grid to prevent thead overflow
paddle
::
platform
::
LimitGridDim
(
dev_ctx
,
&
grid
);
#endif
}
// when should_reduce_again is true, we need malloc temp space for temp data
...
...
paddle/phi/kernels/gpu/reduce.h
浏览文件 @
b3959fe4
...
...
@@ -41,7 +41,7 @@ void Reduce(const KPDevice& dev_ctx,
for
(
auto
i
:
reduce_dims
)
{
reduce_num
*=
(
x
.
dims
())[
i
];
}
#ifndef PADDLE_WITH_XPU_KP
if
(
out_dtype
!=
phi
::
DataType
::
UNDEFINED
&&
out_dtype
!=
x
.
dtype
())
{
auto
tmp_tensor
=
phi
::
Cast
<
T
>
(
dev_ctx
,
x
,
out_dtype
);
PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES
(
...
...
@@ -73,6 +73,16 @@ void Reduce(const KPDevice& dev_ctx,
reduce_dims
,
is_mean
);
}
#else
using
MPType
=
typename
kps
::
details
::
MPTypeTrait
<
T
>::
Type
;
phi
::
funcs
::
ReduceKernel
<
T
,
T
,
ReduceOp
,
TransformOp
<
T
,
MPType
>>
(
dev_ctx
,
x
,
out
,
TransformOp
<
T
,
MPType
>
(
reduce_num
),
reduce_dims
,
is_mean
);
#endif
}
}
// namespace phi
...
...
paddle/phi/kernels/
gpu
/reduce_all_kernel.cu
→
paddle/phi/kernels/
kps
/reduce_all_kernel.cu
浏览文件 @
b3959fe4
...
...
@@ -33,4 +33,8 @@ void AllRawKernel(const Context& dev_ctx,
}
// namespace phi
PD_REGISTER_KERNEL
(
all_raw
,
GPU
,
ALL_LAYOUT
,
phi
::
AllRawKernel
,
bool
)
{}
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL
(
all_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
AllRawKernel
,
bool
)
{}
#else
PD_REGISTER_KERNEL
(
all_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
AllRawKernel
,
bool
)
{}
#endif
paddle/phi/kernels/
gpu
/reduce_max_kernel.cu
→
paddle/phi/kernels/
kps
/reduce_max_kernel.cu
浏览文件 @
b3959fe4
...
...
@@ -33,5 +33,10 @@ void MaxRawKernel(const Context& dev_ctx,
}
// namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL
(
max_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
MaxRawKernel
,
float
)
{}
#else
PD_REGISTER_KERNEL
(
max_raw
,
GPU
,
ALL_LAYOUT
,
phi
::
MaxRawKernel
,
float
,
double
,
int
,
int64_t
)
{}
max_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
MaxRawKernel
,
float
,
double
,
int
,
int64_t
)
{}
#endif
paddle/phi/kernels/
gpu
/reduce_mean_kernel.cu
→
paddle/phi/kernels/
kps
/reduce_mean_kernel.cu
浏览文件 @
b3959fe4
...
...
@@ -33,10 +33,13 @@ void MeanRawKernel(const Context& dev_ctx,
}
// namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL
(
mean_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
MeanRawKernel
,
float
)
{}
#else
using
float16
=
phi
::
dtype
::
float16
;
PD_REGISTER_KERNEL
(
mean_raw
,
GPU
,
KPS
,
ALL_LAYOUT
,
phi
::
MeanRawKernel
,
float
,
...
...
@@ -45,3 +48,4 @@ PD_REGISTER_KERNEL(mean_raw,
float16
,
int
,
int64_t
)
{}
#endif
paddle/phi/kernels/
gpu
/reduce_min_kernel.cu
→
paddle/phi/kernels/
kps
/reduce_min_kernel.cu
浏览文件 @
b3959fe4
...
...
@@ -33,5 +33,9 @@ void MinRawKernel(const Context& dev_ctx,
}
// namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL
(
min_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
MinRawKernel
,
float
)
{}
#else
PD_REGISTER_KERNEL
(
min_raw
,
GPU
,
ALL_LAYOUT
,
phi
::
MinRawKernel
,
float
,
double
,
int
,
int64_t
)
{}
min_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
MinRawKernel
,
float
,
double
,
int
,
int64_t
)
{}
#endif
paddle/phi/kernels/
gpu
/reduce_sum_kernel.cu
→
paddle/phi/kernels/
kps
/reduce_sum_kernel.cu
浏览文件 @
b3959fe4
...
...
@@ -33,13 +33,18 @@ void SumRawKernel(const Context& dev_ctx,
}
// namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL
(
sum_raw
,
KPS
,
ALL_LAYOUT
,
phi
::
SumRawKernel
,
float
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
#else
using
float16
=
phi
::
dtype
::
float16
;
using
bfloat16
=
phi
::
dtype
::
bfloat16
;
using
complex64
=
::
phi
::
dtype
::
complex
<
float
>
;
using
complex128
=
::
phi
::
dtype
::
complex
<
double
>
;
PD_REGISTER_KERNEL
(
sum_raw
,
GPU
,
KPS
,
ALL_LAYOUT
,
phi
::
SumRawKernel
,
bool
,
...
...
@@ -54,3 +59,4 @@ PD_REGISTER_KERNEL(sum_raw,
complex128
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
paddle
::
experimental
::
DataType
::
UNDEFINED
);
}
#endif
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
浏览文件 @
b3959fe4
...
...
@@ -336,7 +336,7 @@ __device__ __forceinline__ void Reduce(T* out,
out
[
i
]
=
reducer
(
out
[
i
],
in
[
i
*
NX
+
j
]);
}
}
BlockXReduce
<
T
,
ReduceFunctor
,
NY
>
(
out
,
reducer
);
details
::
BlockXReduce
<
T
,
ReduceFunctor
,
NY
>
(
out
,
reducer
);
}
else
{
// else kLocalMode
#pragma unroll
for
(
int
i
=
0
;
i
<
NY
;
++
i
)
{
...
...
paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
浏览文件 @
b3959fe4
...
...
@@ -77,7 +77,7 @@ struct BroadcastConfig {
#pragma pack()
template
<
typename
T
>
__device__
__forceinline__
void
WriteData
(
T
*
_global_ptr_
dst
,
__device__
__forceinline__
void
WriteData
(
T
_global_ptr_
*
dst
,
T
*
src
,
int
num
)
{
if
(
num
>
0
)
{
...
...
@@ -403,8 +403,9 @@ template <typename Tx,
typename
IndexCal
,
typename
Functor
,
bool
IsBoundary
=
false
>
__device__
__forceinline__
void
ReadDataReduce
(
Ty
*
dst
,
const
Tx
*
__restrict__
src
,
__device__
__forceinline__
void
ReadDataReduce
(
Ty
*
dst
,
const
Tx
_global_ptr_
*
__restrict__
src
,
int
block_offset
,
const
IndexCal
&
index_cal
,
int
size_nx
,
...
...
paddle/phi/kernels/primitive/functor_primitives_xpu2.h
浏览文件 @
b3959fe4
...
...
@@ -25,6 +25,12 @@ namespace kps {
*/
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
IdentityFunctor
{
#ifdef PADDLE_WITH_XPU_KP
HOSTDEVICE
inline
IdentityFunctor
()
{}
HOSTDEVICE
explicit
inline
IdentityFunctor
(
int
n
)
{}
HOSTDEVICE
Ty
operator
()(
const
Tx
x
)
const
{
return
static_cast
<
Ty
>
(
x
);
}
HOSTDEVICE
inline
void
SetDiv
(
int
n
)
{}
#else
inline
IdentityFunctor
()
{}
explicit
inline
IdentityFunctor
(
int
n
)
{}
...
...
@@ -38,6 +44,7 @@ struct IdentityFunctor {
return
static_cast
<
Ty
>
(
x
);
}
__device__
inline
void
SetDiv
(
int
n
)
{}
#endif
};
/**
...
...
python/paddle/fluid/tests/unittests/xpu/test_reduce_all_op_xpu.py
0 → 100644
浏览文件 @
b3959fe4
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
sys
.
path
.
append
(
".."
)
import
paddle
from
op_test
import
OpTest
from
op_test_xpu
import
XPUOpTest
from
xpu.get_test_cover_info
import
create_test_class
,
get_xpu_op_support_types
,
XPUOpTestWrapper
paddle
.
enable_static
()
class
XPUTestReduceAllOp
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'reduce_all'
class
XPUTestReduceAllBase
(
XPUOpTest
):
def
setUp
(
self
):
self
.
place
=
paddle
.
XPUPlace
(
0
)
self
.
set_case
()
def
set_case
(
self
):
self
.
op_type
=
'reduce_all'
self
.
attrs
=
{
'use_xpu'
:
True
,
'reduce_all'
:
True
,
'keep_dim'
:
True
,
'dim'
:
(
3
,
5
,
4
)
}
self
.
inputs
=
{
'X'
:
np
.
random
.
randint
(
0
,
2
,
(
2
,
5
,
3
,
2
,
2
,
3
,
4
,
2
)).
astype
(
"bool"
)
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
all
(
axis
=
self
.
attrs
[
'dim'
])}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad
(
self
):
pass
class
XPUTestReduceAllCase1
(
XPUTestReduceAllBase
):
def
set_case
(
self
):
self
.
op_type
=
'reduce_all'
self
.
attrs
=
{
'use_xpu'
:
True
,
'reduce_all'
:
True
,
'keep_dim'
:
True
,
'dim'
:
[
1
]
}
self
.
inputs
=
{
'X'
:
np
.
random
.
randint
(
0
,
2
,
(
5
,
6
,
10
)).
astype
(
"bool"
)
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
all
()}
class
XPUTestReduceAllCase2
(
XPUTestReduceAllBase
):
def
set_case
(
self
):
self
.
op_type
=
'reduce_all'
self
.
attrs
=
{
'use_xpu'
:
True
,
'reduce_all'
:
True
,
'keep_dim'
:
False
,
'dim'
:
(
3
,
6
)
}
self
.
inputs
=
{
'X'
:
np
.
random
.
randint
(
0
,
2
,
(
2
,
5
,
3
,
2
,
2
,
3
,
4
,
2
)).
astype
(
"bool"
)
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
all
(
axis
=
self
.
attrs
[
'dim'
])}
class
XPUTestReduceAllCase3
(
XPUTestReduceAllBase
):
def
set_case
(
self
):
self
.
op_type
=
'reduce_all'
self
.
attrs
=
{
'use_xpu'
:
True
,
'keep_dim'
:
True
,
'dim'
:
[
1
]
# 'reduce_all': True,
}
self
.
inputs
=
{
'X'
:
np
.
random
.
randint
(
0
,
2
,
(
5
,
6
,
10
)).
astype
(
"bool"
)
}
self
.
outputs
=
{
'Out'
:
np
.
expand_dims
(
self
.
inputs
[
'X'
].
all
(
axis
=
1
),
axis
=
1
)
}
support_types
=
get_xpu_op_support_types
(
'reduce_all'
)
for
stype
in
support_types
:
create_test_class
(
globals
(),
XPUTestReduceAllOp
,
stype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/xpu/test_reduce_max_op_xpu.py
浏览文件 @
b3959fe4
...
...
@@ -18,26 +18,33 @@ import unittest
import
numpy
as
np
import
sys
sys
.
path
.
append
(
".."
)
from
op_test_xpu
import
OpTest
,
XPUOpTest
from
op_test
import
skip_check_grad_ci
import
paddle
import
paddle.fluid.core
as
core
import
paddle.fluid
as
fluid
from
paddle.fluid
import
compiler
,
Program
,
program_guard
from
paddle.fluid.framework
import
convert_np_dtype_to_dtype_
"""
class TestXPUReduceMaxOp(XPUOpTest):
from
op_test
import
OpTest
from
op_test_xpu
import
XPUOpTest
from
xpu.get_test_cover_info
import
create_test_class
,
get_xpu_op_support_types
,
XPUOpTestWrapper
paddle
.
enable_static
()
class
XPUTestReduceMaxOp
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'reduce_max'
class
XPUTestReduceMaxBase
(
XPUOpTest
):
def
setUp
(
self
):
self.init_op_type()
self.initTestCase()
self.use_xpu = True
self.use_mkldnn = False
self
.
place
=
paddle
.
XPUPlace
(
0
)
self
.
init_case
()
self
.
set_case
()
def
set_case
(
self
):
self
.
op_type
=
'reduce_max'
self
.
attrs
=
{
'dim': self.axis
,
'keep_dim': self.keep_dim
,
'reduce_all': self.reduce_all
'use_xpu'
:
True
,
'reduce_all'
:
self
.
reduce_all
,
'keep_dim'
:
self
.
keep_dim
}
self.inputs = {'X': np.random.random(self.shape).astype('float32'
)}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)}
if
self
.
attrs
[
'reduce_all'
]:
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
max
()}
else
:
...
...
@@ -46,28 +53,29 @@ class TestXPUReduceMaxOp(XPUOpTest):
keepdims
=
self
.
attrs
[
'keep_dim'
])
}
def
init_case
(
self
):
self
.
shape
=
(
5
,
6
,
10
)
self
.
axis
=
(
0
,
)
self
.
reduce_all
=
False
self
.
keep_dim
=
False
def
test_check_output
(
self
):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad
(
self
):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
pass
def init_op_type(self
):
self.op_type = 'reduce_max'
self.use_mkldnn = False
self.keep_dim = False
class
XPUTestReduceMaxCase1
(
XPUTestReduceMaxBase
):
def
init_case
(
self
):
self
.
shape
=
(
5
,
6
,
10
)
self
.
axis
=
(
0
,
)
self
.
reduce_all
=
False
self
.
keep_dim
=
True
def initTestCase(self):
self.shape = (5, 6, 10
)
self.axis = (-1, )
"""
support_types
=
get_xpu_op_support_types
(
'reduce_max'
)
for
stype
in
support_types
:
create_test_class
(
globals
(),
XPUTestReduceMaxOp
,
stype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py
浏览文件 @
b3959fe4
...
...
@@ -194,13 +194,5 @@ class TestKeepDim8DReduce(Test1DReduce):
}
class
TestReduceAll
(
Test1DReduce
):
def
setUp
(
self
):
self
.
op_type
=
"reduce_mean"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
5
,
6
,
2
,
10
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'reduce_all'
:
True
,
'use_xpu'
:
True
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
mean
()}
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/xpu/test_reduce_min_op_xpu.py
0 → 100644
浏览文件 @
b3959fe4
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
sys
.
path
.
append
(
".."
)
import
paddle
from
op_test
import
OpTest
from
op_test_xpu
import
XPUOpTest
from
xpu.get_test_cover_info
import
create_test_class
,
get_xpu_op_support_types
,
XPUOpTestWrapper
paddle
.
enable_static
()
class
XPUTestReduceMinOp
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'reduce_min'
class
XPUTestReduceMinBase
(
XPUOpTest
):
def
setUp
(
self
):
self
.
place
=
paddle
.
XPUPlace
(
0
)
self
.
init_case
()
self
.
set_case
()
def
set_case
(
self
):
self
.
op_type
=
'reduce_min'
self
.
attrs
=
{
'use_xpu'
:
True
,
'reduce_all'
:
self
.
reduce_all
,
'keep_dim'
:
self
.
keep_dim
}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)}
if
self
.
attrs
[
'reduce_all'
]:
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
min
()}
else
:
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
min
(
axis
=
self
.
axis
,
keepdims
=
self
.
attrs
[
'keep_dim'
])
}
def
init_case
(
self
):
self
.
shape
=
(
5
,
6
,
10
)
self
.
axis
=
(
0
,
)
self
.
reduce_all
=
False
self
.
keep_dim
=
False
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad
(
self
):
pass
class
XPUTestReduceMinCase1
(
XPUTestReduceMinBase
):
def
init_case
(
self
):
self
.
shape
=
(
5
,
6
,
10
)
self
.
axis
=
(
0
,
)
self
.
reduce_all
=
False
self
.
keep_dim
=
True
support_types
=
get_xpu_op_support_types
(
'reduce_min'
)
for
stype
in
support_types
:
create_test_class
(
globals
(),
XPUTestReduceMinOp
,
stype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py
浏览文件 @
b3959fe4
...
...
@@ -18,25 +18,31 @@ import unittest
import
numpy
as
np
import
sys
sys
.
path
.
append
(
".."
)
from
op_test_xpu
import
OpTest
,
XPUOpTest
from
op_test
import
skip_check_grad_ci
import
paddle
import
paddle.fluid.core
as
core
import
paddle.fluid
as
fluid
from
paddle.fluid
import
compiler
,
Program
,
program_guard
from
paddle.fluid.framework
import
convert_np_dtype_to_dtype_
from
op_test
import
OpTest
from
op_test_xpu
import
XPUOpTest
from
xpu.get_test_cover_info
import
create_test_class
,
get_xpu_op_support_types
,
XPUOpTestWrapper
paddle
.
enable_static
()
class
XPUTestReduceSumOp
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'reduce_sum'
class
TestXPUReduceSumOp
(
XPUOpTest
):
class
XPUTestReduceSumBase
(
XPUOpTest
):
def
setUp
(
self
):
self
.
init_op_type
()
self
.
initTestCase
()
self
.
use_xpu
=
True
self
.
use_mkldnn
=
False
self
.
place
=
paddle
.
XPUPlace
(
0
)
self
.
init_case
()
self
.
set_case
()
def
set_case
(
self
):
self
.
op_type
=
'reduce_sum'
self
.
attrs
=
{
'dim'
:
self
.
axis
,
'keep_dim'
:
self
.
keep_dim
,
'reduce_all'
:
self
.
reduce_all
'use_xpu'
:
True
,
'reduce_all'
:
self
.
reduce_all
,
'keep_dim'
:
self
.
keep_dim
}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)}
if
self
.
attrs
[
'reduce_all'
]:
...
...
@@ -47,109 +53,29 @@ class TestXPUReduceSumOp(XPUOpTest):
keepdims
=
self
.
attrs
[
'keep_dim'
])
}
def
test_check_output
(
self
):
if
paddle
.
is_compiled_with_xpu
():
paddle
.
enable_static
()
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_output_with_place
(
place
)
def
test_check_grad
(
self
):
if
paddle
.
is_compiled_with_xpu
():
paddle
.
enable_static
()
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_grad_with_place
(
place
,
[
'X'
],
'Out'
)
def
init_op_type
(
self
):
self
.
op_type
=
"reduce_sum"
self
.
use_mkldnn
=
False
self
.
keep_dim
=
False
self
.
reduce_all
=
False
def
initTestCase
(
self
):
def
init_case
(
self
):
self
.
shape
=
(
5
,
6
,
10
)
self
.
axis
=
(
0
,
)
self
.
reduce_all
=
False
self
.
keep_dim
=
False
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
class
TestSumOp5D
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
1
,
2
,
5
,
6
,
10
)
self
.
axis
=
(
0
,
)
class
TestSumOp6D
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
1
,
1
,
2
,
5
,
6
,
10
)
self
.
axis
=
(
0
,
)
class
TestSumOp8D
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
1
,
3
,
1
,
2
,
1
,
4
,
3
,
10
)
self
.
axis
=
(
0
,
3
)
class
Test1DReduce
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
120
self
.
axis
=
(
0
,
)
class
Test2DReduce0
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
20
,
10
)
self
.
axis
=
(
0
,
)
class
Test2DReduce1
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
20
,
10
)
self
.
axis
=
(
1
,
)
class
Test3DReduce0
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
5
,
6
,
7
)
self
.
axis
=
(
1
,
)
class
Test3DReduce1
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
5
,
6
,
7
)
self
.
axis
=
(
2
,
)
class
Test3DReduce2
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
5
,
6
,
7
)
self
.
axis
=
(
-
2
,
)
class
Test3DReduce3
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
5
,
6
,
7
)
self
.
axis
=
(
1
,
2
)
def
test_check_grad
(
self
):
pass
class
TestKeepDimReduce
(
TestXPUReduceSumOp
):
def
initTestC
ase
(
self
):
class
XPUTestReduceSumCase1
(
XPUTestReduceSumBase
):
def
init_c
ase
(
self
):
self
.
shape
=
(
5
,
6
,
10
)
self
.
axis
=
(
1
,
)
self
.
keep_dim
=
True
class
TestKeepDim8DReduce
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
2
,
5
,
3
,
2
,
2
,
3
,
4
,
2
)
self
.
axis
=
(
3
,
4
,
5
)
self
.
axis
=
(
0
,
)
self
.
reduce_all
=
False
self
.
keep_dim
=
True
class
TestReduceAll
(
TestXPUReduceSumOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
5
,
6
,
2
,
10
)
self
.
axis
=
(
0
,
)
self
.
reduce_all
=
True
support_types
=
get_xpu_op_support_types
(
'reduce_sum'
)
for
stype
in
support_types
:
create_test_class
(
globals
(),
XPUTestReduceSumOp
,
stype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录