From b3959fe4e63a2bbddbce6cf238d21a709da6e0dd Mon Sep 17 00:00:00 2001 From: Lijunhui <1578034415@qq.com> Date: Mon, 18 Apr 2022 18:25:46 +0800 Subject: [PATCH] [KP] Add Reduce op registry & UT for xpu_kp compilation (#41869) --- .../new_executor/standalone_executor_test.cc | 3 +- .../platform/device/xpu/xpu_op_kpfirst_list.h | 8 + paddle/fluid/platform/fast_divmod.h | 3 + paddle/phi/kernels/funcs/aligned_vector.h | 3 + paddle/phi/kernels/funcs/reduce_function.h | 9 +- paddle/phi/kernels/gpu/reduce.h | 12 +- .../kernels/{gpu => kps}/reduce_all_kernel.cu | 6 +- .../kernels/{gpu => kps}/reduce_max_kernel.cu | 7 +- .../{gpu => kps}/reduce_mean_kernel.cu | 6 +- .../kernels/{gpu => kps}/reduce_min_kernel.cu | 6 +- .../kernels/{gpu => kps}/reduce_sum_kernel.cu | 8 +- .../primitive/compute_primitives_xpu2.h | 2 +- .../primitive/datamover_primitives_xpu2.h | 23 +-- .../primitive/functor_primitives_xpu2.h | 7 + .../unittests/xpu/test_reduce_all_op_xpu.py | 111 +++++++++++ .../unittests/xpu/test_reduce_max_op_xpu.py | 96 +++++----- .../unittests/xpu/test_reduce_mean_op_xpu.py | 8 - .../unittests/xpu/test_reduce_min_op_xpu.py | 81 ++++++++ .../unittests/xpu/test_reduce_sum_op_xpu.py | 178 +++++------------- 19 files changed, 378 insertions(+), 199 deletions(-) rename paddle/phi/kernels/{gpu => kps}/reduce_all_kernel.cu (87%) rename paddle/phi/kernels/{gpu => kps}/reduce_max_kernel.cu (87%) rename paddle/phi/kernels/{gpu => kps}/reduce_mean_kernel.cu (91%) rename paddle/phi/kernels/{gpu => kps}/reduce_min_kernel.cu (87%) rename paddle/phi/kernels/{gpu => kps}/reduce_sum_kernel.cu (89%) create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_reduce_all_op_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_reduce_min_op_xpu.py diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc index fe4b47cba62..5efd0fb4207 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc @@ -71,8 +71,10 @@ PD_DECLARE_KERNEL(concat_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT); #ifdef PADDLE_WITH_XPU_KP PD_DECLARE_KERNEL(add_raw, GPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT); #else PD_DECLARE_KERNEL(add_raw, KPS, ALL_LAYOUT); +PD_DECLARE_KERNEL(max_raw, KPS, ALL_LAYOUT); #endif PD_DECLARE_KERNEL(add, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT); @@ -85,7 +87,6 @@ PD_DECLARE_KERNEL(matmul_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(transpose_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sum, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sum_grad, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sgd, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(slice, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT); diff --git a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h index 9afde00a98b..99a1eb97de5 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h +++ b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h @@ -97,6 +97,14 @@ XPUOpMap& get_kp_ops() { XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, {"equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, {"not_equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})}, + // reduce op + {"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_min", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_prod", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"reduce_all", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace())})}, + {"reduce_any", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace())})}, }; return s_xpu_kp_kernels; diff --git a/paddle/fluid/platform/fast_divmod.h b/paddle/fluid/platform/fast_divmod.h index 39eefab774d..bef551078b3 100644 --- a/paddle/fluid/platform/fast_divmod.h +++ b/paddle/fluid/platform/fast_divmod.h @@ -18,6 +18,9 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/aligned_vector.h" #define INT_BITS 32 +#if defined(__xpu__) +#define __forceinline__ __inline__ +#endif namespace paddle { namespace platform { diff --git a/paddle/phi/kernels/funcs/aligned_vector.h b/paddle/phi/kernels/funcs/aligned_vector.h index d71a61f107a..14a9560b841 100644 --- a/paddle/phi/kernels/funcs/aligned_vector.h +++ b/paddle/phi/kernels/funcs/aligned_vector.h @@ -15,6 +15,9 @@ limitations under the License. */ #pragma once #include #include "paddle/phi/core/hostdevice.h" +#if defined(__xpu__) +#define CHAR_BIT 8 +#endif namespace phi { diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h index b414dfc5d6e..42fee144883 100644 --- a/paddle/phi/kernels/funcs/reduce_function.h +++ b/paddle/phi/kernels/funcs/reduce_function.h @@ -33,10 +33,14 @@ namespace cub = hipcub; #endif +#ifndef PADDLE_WITH_XPU_KP #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" +#endif + +#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/utils/array.h" @@ -183,7 +187,7 @@ struct IndexCalculator { strides = details::VectorToArray(full_strides); reduce_strides = details::VectorToArray(cal_strides); #ifndef PADDLE_WITH_XPU_KP - std::vector cal_divmoders; // namespace + std::vector cal_divmoders; // fast divmod for (auto i : cal_strides) { cal_divmoders.push_back(kps::details::FastDivMod(i)); @@ -325,9 +329,10 @@ struct ReduceConfig { // step4: set the block and grid for launch kernel SetBlockDim(); - +#ifndef PADDLE_WITH_XPU_KP // step5: limit the grid to prevent thead overflow paddle::platform::LimitGridDim(dev_ctx, &grid); +#endif } // when should_reduce_again is true, we need malloc temp space for temp data diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h index 6fb81edd6bf..7f6ecef8087 100644 --- a/paddle/phi/kernels/gpu/reduce.h +++ b/paddle/phi/kernels/gpu/reduce.h @@ -41,7 +41,7 @@ void Reduce(const KPDevice& dev_ctx, for (auto i : reduce_dims) { reduce_num *= (x.dims())[i]; } - +#ifndef PADDLE_WITH_XPU_KP if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x.dtype()) { auto tmp_tensor = phi::Cast(dev_ctx, x, out_dtype); PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES( @@ -73,6 +73,16 @@ void Reduce(const KPDevice& dev_ctx, reduce_dims, is_mean); } +#else + using MPType = typename kps::details::MPTypeTrait::Type; + phi::funcs::ReduceKernel>( + dev_ctx, + x, + out, + TransformOp(reduce_num), + reduce_dims, + is_mean); +#endif } } // namespace phi diff --git a/paddle/phi/kernels/gpu/reduce_all_kernel.cu b/paddle/phi/kernels/kps/reduce_all_kernel.cu similarity index 87% rename from paddle/phi/kernels/gpu/reduce_all_kernel.cu rename to paddle/phi/kernels/kps/reduce_all_kernel.cu index 2963d3f206c..dc6355a213f 100644 --- a/paddle/phi/kernels/gpu/reduce_all_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_all_kernel.cu @@ -33,4 +33,8 @@ void AllRawKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(all_raw, GPU, ALL_LAYOUT, phi::AllRawKernel, bool) {} +#ifdef PADDLE_WITH_XPU_KP +PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) {} +#else +PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) {} +#endif diff --git a/paddle/phi/kernels/gpu/reduce_max_kernel.cu b/paddle/phi/kernels/kps/reduce_max_kernel.cu similarity index 87% rename from paddle/phi/kernels/gpu/reduce_max_kernel.cu rename to paddle/phi/kernels/kps/reduce_max_kernel.cu index 98c3986c51d..dd63b05bda1 100644 --- a/paddle/phi/kernels/gpu/reduce_max_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_max_kernel.cu @@ -33,5 +33,10 @@ void MaxRawKernel(const Context& dev_ctx, } // namespace phi +#ifdef PADDLE_WITH_XPU_KP +PD_REGISTER_KERNEL(max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float) {} +#else PD_REGISTER_KERNEL( - max_raw, GPU, ALL_LAYOUT, phi::MaxRawKernel, float, double, int, int64_t) {} + max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float, double, int, int64_t) {} + +#endif diff --git a/paddle/phi/kernels/gpu/reduce_mean_kernel.cu b/paddle/phi/kernels/kps/reduce_mean_kernel.cu similarity index 91% rename from paddle/phi/kernels/gpu/reduce_mean_kernel.cu rename to paddle/phi/kernels/kps/reduce_mean_kernel.cu index 5a2cc8036a1..8e4a65df122 100644 --- a/paddle/phi/kernels/gpu/reduce_mean_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_mean_kernel.cu @@ -33,10 +33,13 @@ void MeanRawKernel(const Context& dev_ctx, } // namespace phi +#ifdef PADDLE_WITH_XPU_KP +PD_REGISTER_KERNEL(mean_raw, KPS, ALL_LAYOUT, phi::MeanRawKernel, float) {} +#else using float16 = phi::dtype::float16; PD_REGISTER_KERNEL(mean_raw, - GPU, + KPS, ALL_LAYOUT, phi::MeanRawKernel, float, @@ -45,3 +48,4 @@ PD_REGISTER_KERNEL(mean_raw, float16, int, int64_t) {} +#endif diff --git a/paddle/phi/kernels/gpu/reduce_min_kernel.cu b/paddle/phi/kernels/kps/reduce_min_kernel.cu similarity index 87% rename from paddle/phi/kernels/gpu/reduce_min_kernel.cu rename to paddle/phi/kernels/kps/reduce_min_kernel.cu index ba37d54895d..59d69c29dec 100644 --- a/paddle/phi/kernels/gpu/reduce_min_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_min_kernel.cu @@ -33,5 +33,9 @@ void MinRawKernel(const Context& dev_ctx, } // namespace phi +#ifdef PADDLE_WITH_XPU_KP +PD_REGISTER_KERNEL(min_raw, KPS, ALL_LAYOUT, phi::MinRawKernel, float) {} +#else PD_REGISTER_KERNEL( - min_raw, GPU, ALL_LAYOUT, phi::MinRawKernel, float, double, int, int64_t) {} + min_raw, KPS, ALL_LAYOUT, phi::MinRawKernel, float, double, int, int64_t) {} +#endif diff --git a/paddle/phi/kernels/gpu/reduce_sum_kernel.cu b/paddle/phi/kernels/kps/reduce_sum_kernel.cu similarity index 89% rename from paddle/phi/kernels/gpu/reduce_sum_kernel.cu rename to paddle/phi/kernels/kps/reduce_sum_kernel.cu index 28bdbd009bd..6c039897ddd 100644 --- a/paddle/phi/kernels/gpu/reduce_sum_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_sum_kernel.cu @@ -33,13 +33,18 @@ void SumRawKernel(const Context& dev_ctx, } // namespace phi +#ifdef PADDLE_WITH_XPU_KP +PD_REGISTER_KERNEL(sum_raw, KPS, ALL_LAYOUT, phi::SumRawKernel, float) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); +} +#else using float16 = phi::dtype::float16; using bfloat16 = phi::dtype::bfloat16; using complex64 = ::phi::dtype::complex; using complex128 = ::phi::dtype::complex; PD_REGISTER_KERNEL(sum_raw, - GPU, + KPS, ALL_LAYOUT, phi::SumRawKernel, bool, @@ -54,3 +59,4 @@ PD_REGISTER_KERNEL(sum_raw, complex128) { kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } +#endif diff --git a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h index 1f4ef2ed932..4d65dd6dd5d 100644 --- a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h @@ -336,7 +336,7 @@ __device__ __forceinline__ void Reduce(T* out, out[i] = reducer(out[i], in[i * NX + j]); } } - BlockXReduce(out, reducer); + details::BlockXReduce(out, reducer); } else { // else kLocalMode #pragma unroll for (int i = 0; i < NY; ++i) { diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h index d2cfdbdec30..a18fc7cbb31 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h @@ -77,7 +77,7 @@ struct BroadcastConfig { #pragma pack() template -__device__ __forceinline__ void WriteData(T* _global_ptr_ dst, +__device__ __forceinline__ void WriteData(T _global_ptr_* dst, T* src, int num) { if (num > 0) { @@ -403,16 +403,17 @@ template -__device__ __forceinline__ void ReadDataReduce(Ty* dst, - const Tx* __restrict__ src, - int block_offset, - const IndexCal& index_cal, - int size_nx, - int size_ny, - int stride_nx, - int stride_ny, - Functor func, - bool reduce_last_dim) { +__device__ __forceinline__ void ReadDataReduce( + Ty* dst, + const Tx _global_ptr_* __restrict__ src, + int block_offset, + const IndexCal& index_cal, + int size_nx, + int size_ny, + int stride_nx, + int stride_ny, + Functor func, + bool reduce_last_dim) { __local__ Tx in_temp[1]; int thread_offset = 0; int left_idx = 0; diff --git a/paddle/phi/kernels/primitive/functor_primitives_xpu2.h b/paddle/phi/kernels/primitive/functor_primitives_xpu2.h index 8a21e61eaa7..b01e0474f2d 100755 --- a/paddle/phi/kernels/primitive/functor_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/functor_primitives_xpu2.h @@ -25,6 +25,12 @@ namespace kps { */ template struct IdentityFunctor { +#ifdef PADDLE_WITH_XPU_KP + HOSTDEVICE inline IdentityFunctor() {} + HOSTDEVICE explicit inline IdentityFunctor(int n) {} + HOSTDEVICE Ty operator()(const Tx x) const { return static_cast(x); } + HOSTDEVICE inline void SetDiv(int n) {} +#else inline IdentityFunctor() {} explicit inline IdentityFunctor(int n) {} @@ -38,6 +44,7 @@ struct IdentityFunctor { return static_cast(x); } __device__ inline void SetDiv(int n) {} +#endif }; /** diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_all_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_all_op_xpu.py new file mode 100644 index 00000000000..b4dc8e7b7cf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_all_op_xpu.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") + +import paddle +from op_test import OpTest +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + +paddle.enable_static() + + +class XPUTestReduceAllOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'reduce_all' + + class XPUTestReduceAllBase(XPUOpTest): + def setUp(self): + self.place = paddle.XPUPlace(0) + self.set_case() + + def set_case(self): + self.op_type = 'reduce_all' + self.attrs = { + 'use_xpu': True, + 'reduce_all': True, + 'keep_dim': True, + 'dim': (3, 5, 4) + } + self.inputs = { + 'X': np.random.randint(0, 2, + (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") + } + self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + pass + + class XPUTestReduceAllCase1(XPUTestReduceAllBase): + def set_case(self): + self.op_type = 'reduce_all' + self.attrs = { + 'use_xpu': True, + 'reduce_all': True, + 'keep_dim': True, + 'dim': [1] + } + self.inputs = { + 'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool") + } + self.outputs = {'Out': self.inputs['X'].all()} + + class XPUTestReduceAllCase2(XPUTestReduceAllBase): + def set_case(self): + self.op_type = 'reduce_all' + self.attrs = { + 'use_xpu': True, + 'reduce_all': True, + 'keep_dim': False, + 'dim': (3, 6) + } + self.inputs = { + 'X': np.random.randint(0, 2, + (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool") + } + self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} + + class XPUTestReduceAllCase3(XPUTestReduceAllBase): + def set_case(self): + self.op_type = 'reduce_all' + self.attrs = { + 'use_xpu': True, + 'keep_dim': True, + 'dim': [1] + # 'reduce_all': True, + } + self.inputs = { + 'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool") + } + self.outputs = { + 'Out': np.expand_dims( + self.inputs['X'].all(axis=1), axis=1) + } + + +support_types = get_xpu_op_support_types('reduce_all') +for stype in support_types: + create_test_class(globals(), XPUTestReduceAllOp, stype) + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_max_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_max_op_xpu.py index 6ea55f5ba93..1dd7b42e5eb 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_reduce_max_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_max_op_xpu.py @@ -18,56 +18,64 @@ import unittest import numpy as np import sys sys.path.append("..") -from op_test_xpu import OpTest, XPUOpTest -from op_test import skip_check_grad_ci + import paddle -import paddle.fluid.core as core -import paddle.fluid as fluid -from paddle.fluid import compiler, Program, program_guard -from paddle.fluid.framework import convert_np_dtype_to_dtype_ -""" -class TestXPUReduceMaxOp(XPUOpTest): - def setUp(self): - self.init_op_type() - self.initTestCase() - self.use_xpu = True - self.use_mkldnn = False - self.attrs = { - 'dim': self.axis, - 'keep_dim': self.keep_dim, - 'reduce_all': self.reduce_all - } - self.inputs = {'X': np.random.random(self.shape).astype('float32')} - if self.attrs['reduce_all']: - self.outputs = {'Out': self.inputs['X'].max()} - else: - self.outputs = { - 'Out': self.inputs['X'].max(axis=self.axis, - keepdims=self.attrs['keep_dim']) +from op_test import OpTest +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + +paddle.enable_static() + + +class XPUTestReduceMaxOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'reduce_max' + + class XPUTestReduceMaxBase(XPUOpTest): + def setUp(self): + self.place = paddle.XPUPlace(0) + self.init_case() + self.set_case() + + def set_case(self): + self.op_type = 'reduce_max' + self.attrs = { + 'use_xpu': True, + 'reduce_all': self.reduce_all, + 'keep_dim': self.keep_dim } + self.inputs = {'X': np.random.random(self.shape).astype("float32")} + if self.attrs['reduce_all']: + self.outputs = {'Out': self.inputs['X'].max()} + else: + self.outputs = { + 'Out': self.inputs['X'].max(axis=self.axis, + keepdims=self.attrs['keep_dim']) + } + + def init_case(self): + self.shape = (5, 6, 10) + self.axis = (0, ) + self.reduce_all = False + self.keep_dim = False + + def test_check_output(self): + self.check_output_with_place(self.place) - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_output_with_place(place) + def test_check_grad(self): + pass - def test_check_grad(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_grad_with_place(place, ['X'], 'Out') + class XPUTestReduceMaxCase1(XPUTestReduceMaxBase): + def init_case(self): + self.shape = (5, 6, 10) + self.axis = (0, ) + self.reduce_all = False + self.keep_dim = True - def init_op_type(self): - self.op_type = 'reduce_max' - self.use_mkldnn = False - self.keep_dim = False - self.reduce_all = False - def initTestCase(self): - self.shape = (5, 6, 10) - self.axis = (-1, ) -""" +support_types = get_xpu_op_support_types('reduce_max') +for stype in support_types: + create_test_class(globals(), XPUTestReduceMaxOp, stype) if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py index 5e866dddbe2..18a588b1b88 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_mean_op_xpu.py @@ -194,13 +194,5 @@ class TestKeepDim8DReduce(Test1DReduce): } -class TestReduceAll(Test1DReduce): - def setUp(self): - self.op_type = "reduce_mean" - self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")} - self.attrs = {'reduce_all': True, 'use_xpu': True} - self.outputs = {'Out': self.inputs['X'].mean()} - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_min_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_min_op_xpu.py new file mode 100644 index 00000000000..cf77ea09a58 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_min_op_xpu.py @@ -0,0 +1,81 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") + +import paddle +from op_test import OpTest +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + +paddle.enable_static() + + +class XPUTestReduceMinOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'reduce_min' + + class XPUTestReduceMinBase(XPUOpTest): + def setUp(self): + self.place = paddle.XPUPlace(0) + self.init_case() + self.set_case() + + def set_case(self): + self.op_type = 'reduce_min' + self.attrs = { + 'use_xpu': True, + 'reduce_all': self.reduce_all, + 'keep_dim': self.keep_dim + } + self.inputs = {'X': np.random.random(self.shape).astype("float32")} + if self.attrs['reduce_all']: + self.outputs = {'Out': self.inputs['X'].min()} + else: + self.outputs = { + 'Out': self.inputs['X'].min(axis=self.axis, + keepdims=self.attrs['keep_dim']) + } + + def init_case(self): + self.shape = (5, 6, 10) + self.axis = (0, ) + self.reduce_all = False + self.keep_dim = False + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + pass + + class XPUTestReduceMinCase1(XPUTestReduceMinBase): + def init_case(self): + self.shape = (5, 6, 10) + self.axis = (0, ) + self.reduce_all = False + self.keep_dim = True + + +support_types = get_xpu_op_support_types('reduce_min') +for stype in support_types: + create_test_class(globals(), XPUTestReduceMinOp, stype) + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py index 638da601a3d..9f42a509624 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_sum_op_xpu.py @@ -18,138 +18,64 @@ import unittest import numpy as np import sys sys.path.append("..") -from op_test_xpu import OpTest, XPUOpTest -from op_test import skip_check_grad_ci -import paddle -import paddle.fluid.core as core -import paddle.fluid as fluid -from paddle.fluid import compiler, Program, program_guard -from paddle.fluid.framework import convert_np_dtype_to_dtype_ - - -class TestXPUReduceSumOp(XPUOpTest): - def setUp(self): - self.init_op_type() - self.initTestCase() - self.use_xpu = True - self.use_mkldnn = False - self.attrs = { - 'dim': self.axis, - 'keep_dim': self.keep_dim, - 'reduce_all': self.reduce_all - } - self.inputs = {'X': np.random.random(self.shape).astype("float32")} - if self.attrs['reduce_all']: - self.outputs = {'Out': self.inputs['X'].sum()} - else: - self.outputs = { - 'Out': self.inputs['X'].sum(axis=self.axis, - keepdims=self.attrs['keep_dim']) - } - - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - def test_check_grad(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_grad_with_place(place, ['X'], 'Out') - - def init_op_type(self): - self.op_type = "reduce_sum" - self.use_mkldnn = False - self.keep_dim = False - self.reduce_all = False - - def initTestCase(self): - self.shape = (5, 6, 10) - self.axis = (0, ) - - -class TestSumOp5D(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (1, 2, 5, 6, 10) - self.axis = (0, ) - - -class TestSumOp6D(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (1, 1, 2, 5, 6, 10) - self.axis = (0, ) - - -class TestSumOp8D(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (1, 3, 1, 2, 1, 4, 3, 10) - self.axis = (0, 3) - - -class Test1DReduce(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = 120 - self.axis = (0, ) +import paddle +from op_test import OpTest +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper -class Test2DReduce0(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (20, 10) - self.axis = (0, ) - - -class Test2DReduce1(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (20, 10) - self.axis = (1, ) - - -class Test3DReduce0(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (5, 6, 7) - self.axis = (1, ) - - -class Test3DReduce1(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (5, 6, 7) - self.axis = (2, ) - - -class Test3DReduce2(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (5, 6, 7) - self.axis = (-2, ) - - -class Test3DReduce3(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (5, 6, 7) - self.axis = (1, 2) - - -class TestKeepDimReduce(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (5, 6, 10) - self.axis = (1, ) - self.keep_dim = True - +paddle.enable_static() -class TestKeepDim8DReduce(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (2, 5, 3, 2, 2, 3, 4, 2) - self.axis = (3, 4, 5) - self.keep_dim = True +class XPUTestReduceSumOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'reduce_sum' -class TestReduceAll(TestXPUReduceSumOp): - def initTestCase(self): - self.shape = (5, 6, 2, 10) - self.axis = (0, ) - self.reduce_all = True + class XPUTestReduceSumBase(XPUOpTest): + def setUp(self): + self.place = paddle.XPUPlace(0) + self.init_case() + self.set_case() + def set_case(self): + self.op_type = 'reduce_sum' + self.attrs = { + 'use_xpu': True, + 'reduce_all': self.reduce_all, + 'keep_dim': self.keep_dim + } + self.inputs = {'X': np.random.random(self.shape).astype("float32")} + if self.attrs['reduce_all']: + self.outputs = {'Out': self.inputs['X'].sum()} + else: + self.outputs = { + 'Out': self.inputs['X'].sum(axis=self.axis, + keepdims=self.attrs['keep_dim']) + } + + def init_case(self): + self.shape = (5, 6, 10) + self.axis = (0, ) + self.reduce_all = False + self.keep_dim = False + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + pass + + class XPUTestReduceSumCase1(XPUTestReduceSumBase): + def init_case(self): + self.shape = (5, 6, 10) + self.axis = (0, ) + self.reduce_all = False + self.keep_dim = True + + +support_types = get_xpu_op_support_types('reduce_sum') +for stype in support_types: + create_test_class(globals(), XPUTestReduceSumOp, stype) if __name__ == '__main__': unittest.main() -- GitLab