From aec49361ee75a44c453ecfbfd996ad7373686864 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Tue, 7 Jun 2022 10:38:21 +0800
Subject: [PATCH] [XPU KP]Add xpu register, any, amax, amin op test (#43204)

---
 .../{reduce_amax_op.cu => reduce_amax_op.kps} | 14 ++-
 .../{reduce_amin_op.cu => reduce_amin_op.kps} | 14 ++-
 paddle/phi/kernels/funcs/reduce_function.h    | 15 +--
 .../kernels/{gpu => kps}/reduce_any_kernel.cu |  6 +-
 paddle/phi/kernels/kps/reduce_max_kernel.cu   |  1 -
 .../{gpu => kps}/reduce_prod_kernel.cu        |  7 +-
 .../primitive/compute_primitives_xpu2.h       | 15 +--
 paddle/phi/kernels/reduce_all_kernel.cc       |  4 +
 paddle/phi/kernels/reduce_any_kernel.cc       |  4 +
 paddle/phi/kernels/reduce_max_kernel.cc       |  4 +
 paddle/phi/kernels/reduce_mean_kernel.cc      |  4 +
 paddle/phi/kernels/reduce_min_kernel.cc       |  4 +
 paddle/phi/kernels/reduce_prod_kernel.cc      |  4 +
 paddle/phi/kernels/reduce_sum_kernel.cc       |  6 ++
 .../unittests/xpu/test_reduce_amax_op_xpu.py  | 67 +++++++++++++
 .../unittests/xpu/test_reduce_amin_op_xpu.py  | 67 +++++++++++++
 .../unittests/xpu/test_reduce_any_op_xpu.py   | 99 +++++++++++++++++++
 17 files changed, 315 insertions(+), 20 deletions(-)
 rename paddle/fluid/operators/reduce_ops/{reduce_amax_op.cu => reduce_amax_op.kps} (77%)
 rename paddle/fluid/operators/reduce_ops/{reduce_amin_op.cu => reduce_amin_op.kps} (77%)
 rename paddle/phi/kernels/{gpu => kps}/reduce_any_kernel.cu (87%)
 rename paddle/phi/kernels/{gpu => kps}/reduce_prod_kernel.cu (91%)
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_reduce_amax_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_reduce_amin_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_reduce_any_op_xpu.py

diff --git a/paddle/fluid/operators/reduce_ops/reduce_amax_op.cu b/paddle/fluid/operators/reduce_ops/reduce_amax_op.kps
similarity index 77%
rename from paddle/fluid/operators/reduce_ops/reduce_amax_op.cu
rename to paddle/fluid/operators/reduce_ops/reduce_amax_op.kps
index b3385915341..09987279184 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_amax_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_amax_op.kps
@@ -12,13 +12,25 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#ifndef PADDLE_WITH_XPU_KP
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#endif
+
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
-// reduce_max
+#ifdef PADDLE_WITH_XPU_KP
+REGISTER_OP_KERNEL(
+    reduce_amax, KP, plat::XPUPlace,
+    ops::ReduceCudaKernel<float, kps::MaxFunctor, kps::IdentityFunctor>);
+#else
 REGISTER_OP_CUDA_KERNEL(
     reduce_amax,
     ops::ReduceCudaKernel<float, kps::MaxFunctor, kps::IdentityFunctor>,
     ops::ReduceCudaKernel<double, kps::MaxFunctor, kps::IdentityFunctor>,
     ops::ReduceCudaKernel<int, kps::MaxFunctor, kps::IdentityFunctor>,
     ops::ReduceCudaKernel<int64_t, kps::MaxFunctor, kps::IdentityFunctor>);
+#endif
diff --git a/paddle/fluid/operators/reduce_ops/reduce_amin_op.cu b/paddle/fluid/operators/reduce_ops/reduce_amin_op.kps
similarity index 77%
rename from paddle/fluid/operators/reduce_ops/reduce_amin_op.cu
rename to paddle/fluid/operators/reduce_ops/reduce_amin_op.kps
index 037dab396c7..5e1139396d9 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_amin_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_amin_op.kps
@@ -12,13 +12,25 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#ifndef PADDLE_WITH_XPU_KP
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#endif
+
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
-// reduce_min
+#ifdef PADDLE_WITH_XPU_KP
+REGISTER_OP_KERNEL(
+    reduce_amin, KP, plat::XPUPlace,
+    ops::ReduceCudaKernel<float, kps::MinFunctor, kps::IdentityFunctor>);
+#else
 REGISTER_OP_CUDA_KERNEL(
     reduce_amin,
     ops::ReduceCudaKernel<float, kps::MinFunctor, kps::IdentityFunctor>,
     ops::ReduceCudaKernel<double, kps::MinFunctor, kps::IdentityFunctor>,
     ops::ReduceCudaKernel<int, kps::MinFunctor, kps::IdentityFunctor>,
     ops::ReduceCudaKernel<int64_t, kps::MinFunctor, kps::IdentityFunctor>);
+#endif
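Note: reduce_amax and reduce_amin register the same forward functors as reduce_max/reduce_min (kps::MaxFunctor / kps::MinFunctor above); the amax/amin ops differ only in how the backward pass spreads gradient across tied extrema, so the forward results match NumPy's amax/amin, which is what the new XPU tests at the end of this patch check. A minimal NumPy illustration (not part of the patch):

    import numpy as np

    x = np.array([[1.0, 3.0, 3.0],
                  [2.0, 0.5, 2.0]], dtype=np.float32)
    print(np.amax(x, axis=1))  # [3.0, 2.0] -- reduce_amax forward
    print(np.amin(x, axis=1))  # [1.0, 0.5] -- reduce_amin forward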
diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index 5c74751b348..4d903e01a49 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -236,8 +236,9 @@ struct IndexCalculator {
 template <bool ReduceLastDim = false>
 struct ReduceIndexMapping {
   const kps::DimConfig dim;
-  HOSTDEVICE explicit ReduceIndexMapping(const kps::DimConfig& dims)
-      : dim(dims) {}
+  int loop_size;
+  HOSTDEVICE ReduceIndexMapping(const kps::DimConfig& dims, int max_loop = 1)
+      : dim(dims), loop_size(max_loop) {}
 
 #ifdef PADDLE_WITH_XPU_KP
   __device__ __forceinline__ int BlockIdX() {
@@ -277,10 +278,10 @@ struct ReduceIndexMapping {
   }
 
   __device__ __forceinline__ int GetLoopSize() {
-    if (ReduceLastDim) {
-      return dim.deal_size_y;
-    } else {
+    if ((!ReduceLastDim) && (loop_size == 1)) {
       return dim.deal_size_x;
+    } else {
+      return loop_size;
     }
   }
 #else
@@ -670,7 +671,7 @@ __global__ void ReduceAnyKernel(const Tx* x,
   int store_offset = 0;
   int stride_left = 0;
   if (reduce_last_dim) {
-    auto block = ReduceIndexMapping<true>(dim);
+    auto block = ReduceIndexMapping<true>(dim, left_num);
     input_idx = block.BlockIdY() * block.BlockDimX();
     left_idx = block.BlockIdX() * block.BlockDimY() + THREAD_ID_Y;
    stride = block.GridDimY() * block.BlockDimX();
@@ -681,7 +682,7 @@ __global__ void ReduceAnyKernel(const Tx* x,
     stride_left = 1;
     tid = THREAD_ID_X;
   } else {
-    auto block = ReduceIndexMapping<false>(dim);
+    auto block = ReduceIndexMapping<false>(dim, left_num);
     input_idx = block.BlockIdY() * block.BlockDimY();
     left_idx = block.BlockIdX() * block.BlockDimX() + THREAD_ID_X;
     stride = block.GridDimY() * block.BlockDimY();
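Note: the reduce_function.h change threads the reduction's `left_num` through `ReduceIndexMapping` as an explicit `loop_size` (defaulting to 1), instead of deriving the loop extent from the block's `deal_size_y`. A plain-Python restatement of the new `GetLoopSize()` decision, with hypothetical argument names:

    def get_loop_size(reduce_last_dim, loop_size, deal_size_x):
        # An explicit loop_size (left_num at both call sites) wins;
        # deal_size_x remains the fallback only for the old default of 1
        # when the last dimension is not being reduced.
        if (not reduce_last_dim) and loop_size == 1:
            return deal_size_x
        return loop_size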
diff --git a/paddle/phi/kernels/gpu/reduce_any_kernel.cu b/paddle/phi/kernels/kps/reduce_any_kernel.cu
similarity index 87%
rename from paddle/phi/kernels/gpu/reduce_any_kernel.cu
rename to paddle/phi/kernels/kps/reduce_any_kernel.cu
index 25f73c64a54..480268936f4 100644
--- a/paddle/phi/kernels/gpu/reduce_any_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_any_kernel.cu
@@ -32,4 +32,8 @@ void AnyRawKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(any_raw, GPU, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
+#ifdef PADDLE_WITH_XPU_KP
+PD_REGISTER_KERNEL(any_raw, KPS, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
+#else
+PD_REGISTER_KERNEL(any_raw, KPS, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
+#endif
diff --git a/paddle/phi/kernels/kps/reduce_max_kernel.cu b/paddle/phi/kernels/kps/reduce_max_kernel.cu
index bc997c6c4e3..52644849ad8 100644
--- a/paddle/phi/kernels/kps/reduce_max_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_max_kernel.cu
@@ -37,5 +37,4 @@ PD_REGISTER_KERNEL(max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float) {}
 #else
 PD_REGISTER_KERNEL(
     max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float, double, int, int64_t) {}
-
 #endif
diff --git a/paddle/phi/kernels/gpu/reduce_prod_kernel.cu b/paddle/phi/kernels/kps/reduce_prod_kernel.cu
similarity index 91%
rename from paddle/phi/kernels/gpu/reduce_prod_kernel.cu
rename to paddle/phi/kernels/kps/reduce_prod_kernel.cu
index 4ae1dcfeba0..13d8e29b60b 100644
--- a/paddle/phi/kernels/gpu/reduce_prod_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_prod_kernel.cu
@@ -31,12 +31,15 @@ void ProdRawKernel(const Context& dev_ctx,
 }
 }  // namespace phi
 
-
+#ifdef PADDLE_WITH_XPU_KP
+PD_REGISTER_KERNEL(prod_raw, KPS, ALL_LAYOUT, phi::ProdRawKernel, float) {}
+#else
 PD_REGISTER_KERNEL(prod_raw,
-                   GPU,
+                   KPS,
                    ALL_LAYOUT,
                    phi::ProdRawKernel,
                    float,
                    double,
                    int,
                    int64_t) {}
+#endif
diff --git a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h
index 6ec05ee5054..38a8d40aee6 100644
--- a/paddle/phi/kernels/primitive/compute_primitives_xpu2.h
+++ b/paddle/phi/kernels/primitive/compute_primitives_xpu2.h
@@ -48,7 +48,7 @@ static inline __device__ void sync_all() {
 
 #define ncores 64
 template <typename T, typename OpFunc, int VecSize>
-__device__ void BlockXReduce(T* data, OpFunc reducer) {
+__device__ void BlockXReduce(T* out, const T* data, OpFunc reducer) {
   __shared__ T sum_array[ncores * VecSize];
   int core_idx = core_id() * VecSize;
   mfence();
@@ -57,21 +57,22 @@ __device__ void BlockXReduce(T* data, OpFunc reducer) {
 #pragma unroll
   for (int i = 0; i < VecSize; i++) {
     mfence();
-    sum_array[core_idx + i] = data[i];
+    sum_array[i * ncores + core_idx] = data[i];
     mfence();
-    data[i] = 0;
   }
   sync_all();
 #pragma unroll
   for (int i = 0; i < VecSize; i++) {
+    T start = data[i * ncores];
 #pragma unroll
-    for (int j = 0; j < ncores; j++) {
+    for (int j = 1; j < ncores; j++) {
       mfence();
-      T tmp = sum_array[j * VecSize + i];
+      T tmp = sum_array[i * ncores + j];
       mfence();
-      data[i] = reducer(data[i], tmp);
+      start = reducer(start, tmp);
       mfence();
     }
+    out[i] = start;
   }
   sync_all();
 }
@@ -346,7 +347,7 @@ __device__ __forceinline__ void Reduce(T* out,
     if (reduce_last_dim) {
 #pragma unroll
       for (int i = 0; i < NY * NX; i++) {  // reduce along blockDim.x
-        details::BlockXReduce<T, ReduceFunctor, 1>(&out[i], reducer);
+        details::BlockXReduce<T, ReduceFunctor, 1>(&out[i], &in[i], reducer);
       }
     }
   } else {  // else kLocalMode
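Note: the BlockXReduce rewrite changes two things at once. Shared memory now uses a strided layout (element i of every core occupies a contiguous run of ncores slots), and the accumulator is seeded from a real input value instead of zero-initializing `data[i]` first — zero seeding is only correct for additive reducers and can give wrong results for max/min/prod-style reductions. A rough NumPy emulation of the new data flow (illustrative only; `block_x_reduce` is a hypothetical stand-in for the XPU primitive):

    import numpy as np

    def block_x_reduce(per_core, reducer, ncores=64):
        # per_core: shape (ncores, vec_size), one row per simulated core.
        vec_size = per_core.shape[1]
        # New layout: sum_array[i * ncores + core] = per_core[core, i]
        sum_array = per_core.T.reshape(vec_size * ncores)
        out = np.empty(vec_size, dtype=per_core.dtype)
        for i in range(vec_size):
            acc = sum_array[i * ncores]        # seed from core 0's element i
            for j in range(1, ncores):         # fold in the remaining cores
                acc = reducer(acc, sum_array[i * ncores + j])
            out[i] = acc
        return out

    data = np.random.rand(64, 2).astype(np.float32)
    assert np.allclose(block_x_reduce(data, max), data.max(axis=0))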
diff --git a/paddle/phi/kernels/reduce_all_kernel.cc b/paddle/phi/kernels/reduce_all_kernel.cc
index 5525f0dbfa7..9b4515ee290 100644
--- a/paddle/phi/kernels/reduce_all_kernel.cc
+++ b/paddle/phi/kernels/reduce_all_kernel.cc
@@ -36,3 +36,7 @@ PD_REGISTER_KERNEL(all, CPU, ALL_LAYOUT, phi::AllKernel, bool) {}
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PD_REGISTER_KERNEL(all, GPU, ALL_LAYOUT, phi::AllKernel, bool) {}
 #endif
+
+#if defined(PADDLE_WITH_XPU_KP)
+PD_REGISTER_KERNEL(all, KPS, ALL_LAYOUT, phi::AllKernel, bool) {}
+#endif
diff --git a/paddle/phi/kernels/reduce_any_kernel.cc b/paddle/phi/kernels/reduce_any_kernel.cc
index 01cbcd4029c..642b80c3d86 100644
--- a/paddle/phi/kernels/reduce_any_kernel.cc
+++ b/paddle/phi/kernels/reduce_any_kernel.cc
@@ -36,3 +36,7 @@ PD_REGISTER_KERNEL(any, CPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PD_REGISTER_KERNEL(any, GPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
 #endif
+
+#if defined(PADDLE_WITH_XPU_KP)
+PD_REGISTER_KERNEL(any, KPS, ALL_LAYOUT, phi::AnyKernel, bool) {}
+#endif
diff --git a/paddle/phi/kernels/reduce_max_kernel.cc b/paddle/phi/kernels/reduce_max_kernel.cc
index a7458a3e0ac..26b8bc196cc 100644
--- a/paddle/phi/kernels/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/reduce_max_kernel.cc
@@ -38,3 +38,7 @@
 PD_REGISTER_KERNEL(
     max, GPU, ALL_LAYOUT, phi::MaxKernel, float, double, int, int64_t) {}
 #endif
+
+#if defined(PADDLE_WITH_XPU_KP)
+PD_REGISTER_KERNEL(max, KPS, ALL_LAYOUT, phi::MaxKernel, float) {}
+#endif
diff --git a/paddle/phi/kernels/reduce_mean_kernel.cc b/paddle/phi/kernels/reduce_mean_kernel.cc
index 812cf8702e1..599b7eca321 100644
--- a/paddle/phi/kernels/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/reduce_mean_kernel.cc
@@ -46,3 +46,7 @@ PD_REGISTER_KERNEL(mean,
                    int64_t,
                    phi::dtype::float16) {}
 #endif
+
+#if defined(PADDLE_WITH_XPU_KP)
+PD_REGISTER_KERNEL(mean, KPS, ALL_LAYOUT, phi::MeanKernel, float) {}
+#endif
diff --git a/paddle/phi/kernels/reduce_min_kernel.cc b/paddle/phi/kernels/reduce_min_kernel.cc
index 620b5167566..75d906aa4bd 100644
--- a/paddle/phi/kernels/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/reduce_min_kernel.cc
@@ -38,3 +38,7 @@
 PD_REGISTER_KERNEL(
     min, GPU, ALL_LAYOUT, phi::MinKernel, float, double, int, int64_t) {}
 #endif
+
+#if defined(PADDLE_WITH_XPU_KP)
+PD_REGISTER_KERNEL(min, KPS, ALL_LAYOUT, phi::MinKernel, float) {}
+#endif
diff --git a/paddle/phi/kernels/reduce_prod_kernel.cc b/paddle/phi/kernels/reduce_prod_kernel.cc
index 5bd410709c6..3bb1c7552b1 100644
--- a/paddle/phi/kernels/reduce_prod_kernel.cc
+++ b/paddle/phi/kernels/reduce_prod_kernel.cc
@@ -38,3 +38,7 @@
 PD_REGISTER_KERNEL(
     prod, GPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
 #endif
+
+#if defined(PADDLE_WITH_XPU_KP)
+PD_REGISTER_KERNEL(prod, KPS, ALL_LAYOUT, phi::ProdKernel, float) {}
+#endif
diff --git a/paddle/phi/kernels/reduce_sum_kernel.cc b/paddle/phi/kernels/reduce_sum_kernel.cc
index e2b13333d7f..0d79fa34bc2 100644
--- a/paddle/phi/kernels/reduce_sum_kernel.cc
+++ b/paddle/phi/kernels/reduce_sum_kernel.cc
@@ -69,3 +69,9 @@ PD_REGISTER_KERNEL(sum,
   kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
 }
 #endif
+
+#if defined(PADDLE_WITH_XPU_KP)
+PD_REGISTER_KERNEL(sum, KPS, ALL_LAYOUT, phi::SumKernel, float) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
+}
+#endif
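Note: taken together, the registrations above give the XPU KP backend float32-only coverage for max/min/mean/sum/prod (and bool for any/all), versus the wider dtype lists kept for CUDA/HIP builds. A hedged usage sketch, assuming a Paddle build with XPU KP support and an attached XPU device:

    import paddle

    paddle.set_device('xpu')
    x = paddle.to_tensor([[0.2, 1.5], [3.0, 0.7]])  # float32, the dtype registered for KP
    print(paddle.sum(x), paddle.mean(x), paddle.max(x))
    print(paddle.any(x > 1.0))                      # bool kernel registered above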
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_amax_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_amax_op_xpu.py
new file mode 100644
index 00000000000..a6a0c7b5920
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_amax_op_xpu.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+
+sys.path.append("..")
+
+import paddle
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
+
+paddle.enable_static()
+
+
+class XPUTestReduceAmaxOp(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'reduce_amax'
+
+    class XPUTestReduceAmaxBase(XPUOpTest):
+
+        def setUp(self):
+            self.place = paddle.XPUPlace(0)
+            self.set_case()
+
+        def set_case(self):
+            self.op_type = 'reduce_amax'
+            self.shape = (20, 10)
+            self.attrs = {'use_xpu': True, 'keep_dim': False, 'dim': (1, )}
+
+            self.inputs = {
+                'X': np.random.randint(0, 100, self.shape).astype("float32")
+            }
+
+            expected_input = self.inputs['X']
+            self.outputs = {
+                'Out':
+                np.amax(expected_input,
+                        axis=self.attrs['dim'],
+                        keepdims=self.attrs['keep_dim'])
+            }
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+
+
+support_types = get_xpu_op_support_types('reduce_amax')
+for stype in support_types:
+    create_test_class(globals(), XPUTestReduceAmaxOp, stype)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_amin_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_amin_op_xpu.py
new file mode 100644
index 00000000000..def6c0821f5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_amin_op_xpu.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+
+sys.path.append("..")
+
+import paddle
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
+
+paddle.enable_static()
+
+
+class XPUTestReduceAminOp(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'reduce_amin'
+
+    class XPUTestReduceAminBase(XPUOpTest):
+
+        def setUp(self):
+            self.place = paddle.XPUPlace(0)
+            self.set_case()
+
+        def set_case(self):
+            self.op_type = 'reduce_amin'
+            self.shape = (20, 10)
+            self.attrs = {'use_xpu': True, 'keep_dim': False, 'dim': (1, )}
+
+            self.inputs = {
+                'X': np.random.randint(0, 100, self.shape).astype("float32")
+            }
+
+            expected_input = self.inputs['X']
+            self.outputs = {
+                'Out':
+                np.amin(expected_input,
+                        axis=self.attrs['dim'],
+                        keepdims=self.attrs['keep_dim'])
+            }
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+
+
+support_types = get_xpu_op_support_types('reduce_amin')
+for stype in support_types:
+    create_test_class(globals(), XPUTestReduceAminOp, stype)
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_reduce_any_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_reduce_any_op_xpu.py
new file mode 100644
index 00000000000..5118c3787e6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_reduce_any_op_xpu.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+
+sys.path.append("..")
+
+import paddle
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
+
+paddle.enable_static()
+
+
+class XPUTestReduceAnyOp(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'reduce_any'
+
+    class XPUTestReduceAnyBase(XPUOpTest):
+
+        def setUp(self):
+            self.place = paddle.XPUPlace(0)
+            self.set_case()
+
+        def set_case(self):
+            self.op_type = 'reduce_any'
+            self.attrs = {
+                'use_xpu': True,
+                'reduce_all': True,
+                'keep_dim': True,
+                'dim': (3, 5, 4)
+            }
+            self.inputs = {
+                'X':
+                np.random.randint(0, 2, (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
+            }
+            self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])}
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+
+        def test_check_grad(self):
+            pass
+
+    class XPUTestReduceAnyCase1(XPUTestReduceAnyBase):
+
+        def set_case(self):
+            self.op_type = 'reduce_any'
+            self.attrs = {
+                'use_xpu': True,
+                'dim': [1]
+                # 'reduce_all': True,
+                # 'keep_dim': True,
+            }
+            self.inputs = {
+                'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")
+            }
+            self.outputs = {'Out': self.inputs['X'].any(axis=1)}
+
+    class XPUTestReduceAnyCase2(XPUTestReduceAnyBase):
+
+        def set_case(self):
+            self.op_type = 'reduce_any'
+            self.attrs = {
+                'use_xpu': True,
+                'reduce_all': True,
+                'keep_dim': False,
+                'dim': (3, 6)
+            }
+            self.inputs = {
+                'X':
+                np.random.randint(0, 2, (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
+            }
+            self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])}
+
+
+support_types = get_xpu_op_support_types('reduce_any')
+for stype in support_types:
+    create_test_class(globals(), XPUTestReduceAnyOp, stype)
+
+if __name__ == '__main__':
+    unittest.main()
--
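Note: on an XPU machine the new suites can be run directly from python/paddle/fluid/tests/unittests/xpu, e.g. `python test_reduce_amax_op_xpu.py` (likewise for the amin and any tests); `get_xpu_op_support_types` plus `create_test_class` expand each wrapper class above into one concrete unittest class per dtype the XPU build supports.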