Unverified commit aec49361 authored by niuliling123, committed by GitHub

[XPU KP]Add xpu register, any, amax, amin op test (#43204)

Parent a2020d0c
@@ -12,13 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/phi/core/kernel_registry.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// reduce_max
#ifdef PADDLE_WITH_XPU_KP
REGISTER_OP_KERNEL(
reduce_amax, KP, plat::XPUPlace,
ops::ReduceCudaKernel<float, kps::MaxFunctor, kps::IdentityFunctor>);
#else
REGISTER_OP_CUDA_KERNEL(
reduce_amax,
ops::ReduceCudaKernel<float, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MaxFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::MaxFunctor, kps::IdentityFunctor>);
#endif
@@ -12,13 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/phi/core/kernel_registry.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
// reduce_min
#ifdef PADDLE_WITH_XPU_KP
REGISTER_OP_KERNEL(
reduce_amin, KP, plat::XPUPlace,
ops::ReduceCudaKernel<float, kps::MinFunctor, kps::IdentityFunctor>);
#else
REGISTER_OP_CUDA_KERNEL(
reduce_amin,
ops::ReduceCudaKernel<float, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<double, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::MinFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::MinFunctor, kps::IdentityFunctor>);
#endif
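A note on the two registration hunks above: the kernel body, ops::ReduceCudaKernel, is evidently shared between CUDA and XPU KP builds (the KP branch registers the same template for plat::XPUPlace), since it is written against the kps:: primitives; only the registration macro and the dtype list differ. A minimal sketch of the pattern, with a hypothetical operator standing in for reduce_amax/reduce_amin:

// Illustrative only: "reduce_foo" and kps::FooFunctor are placeholder names.
#ifdef PADDLE_WITH_XPU_KP
// XPU KP build: a single float kernel, bound to the KP place.
REGISTER_OP_KERNEL(
    reduce_foo, KP, plat::XPUPlace,
    ops::ReduceCudaKernel<float, kps::FooFunctor, kps::IdentityFunctor>);
#else
// CUDA build: the same kernel template, instantiated per supported dtype.
REGISTER_OP_CUDA_KERNEL(
    reduce_foo,
    ops::ReduceCudaKernel<float, kps::FooFunctor, kps::IdentityFunctor>,
    ops::ReduceCudaKernel<double, kps::FooFunctor, kps::IdentityFunctor>,
    ops::ReduceCudaKernel<int, kps::FooFunctor, kps::IdentityFunctor>,
    ops::ReduceCudaKernel<int64_t, kps::FooFunctor, kps::IdentityFunctor>);
#endif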
@@ -236,8 +236,9 @@ struct IndexCalculator {
template <bool ReduceLastDim = false>
struct ReduceIndexMapping {
const kps::DimConfig dim;
-  HOSTDEVICE explicit ReduceIndexMapping(const kps::DimConfig& dims)
-      : dim(dims) {}
+  int loop_size;
+  HOSTDEVICE ReduceIndexMapping(const kps::DimConfig& dims, int max_loop = 1)
+      : dim(dims), loop_size(max_loop) {}
#ifdef PADDLE_WITH_XPU_KP
__device__ __forceinline__ int BlockIdX() {
@@ -277,10 +278,10 @@ struct ReduceIndexMapping {
}
__device__ __forceinline__ int GetLoopSize() {
-    if (ReduceLastDim) {
-      return dim.deal_size_y;
-    } else {
+    if ((!ReduceLastDim) && (loop_size == 1)) {
      return dim.deal_size_x;
+    } else {
+      return loop_size;
    }
}
#else
@@ -670,7 +671,7 @@ __global__ void ReduceAnyKernel(const Tx* x,
int store_offset = 0;
int stride_left = 0;
if (reduce_last_dim) {
-    auto block = ReduceIndexMapping<true>(dim);
+    auto block = ReduceIndexMapping<true>(dim, left_num);
input_idx = block.BlockIdY() * block.BlockDimX();
left_idx = block.BlockIdX() * block.BlockDimY() + THREAD_ID_Y;
stride = block.GridDimY() * block.BlockDimX();
@@ -681,7 +682,7 @@ __global__ void ReduceAnyKernel(const Tx* x,
stride_left = 1;
tid = THREAD_ID_X;
} else {
-    auto block = ReduceIndexMapping<false>(dim);
+    auto block = ReduceIndexMapping<false>(dim, left_num);
input_idx = block.BlockIdY() * block.BlockDimY();
left_idx = block.BlockIdX() * block.BlockDimX() + THREAD_ID_X;
stride = block.GridDimY() * block.BlockDimY();
......
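The two call-site edits above plumb left_num into ReduceIndexMapping's new max_loop parameter, so GetLoopSize() can report the block's actual remaining loop count on XPU instead of falling back to dim.deal_size_x. A small host-side sketch of the resulting selection logic (illustrative only; the real method is the device function shown earlier):

#include <cassert>

// Hypothetical free-function rendering of ReduceIndexMapping::GetLoopSize().
int GetLoopSize(bool reduce_last_dim, int loop_size, int deal_size_x) {
  // loop_size stays 1 when no max_loop argument was given to the constructor.
  if (!reduce_last_dim && loop_size == 1) {
    return deal_size_x;  // legacy fallback: the tile width from DimConfig
  }
  return loop_size;  // both ReduceAnyKernel branches now pass left_num
}

int main() {
  assert(GetLoopSize(false, 1, 128) == 128);  // default-constructed mapping
  assert(GetLoopSize(false, 64, 128) == 64);  // mapping built with left_num
  assert(GetLoopSize(true, 64, 128) == 64);   // reduce-last-dim path
  return 0;
}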
@@ -32,4 +32,8 @@ void AnyRawKernel(const Context& dev_ctx,
} // namespace phi
-PD_REGISTER_KERNEL(any_raw, GPU, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
+#ifdef PADDLE_WITH_XPU_KP
+PD_REGISTER_KERNEL(any_raw, KPS, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
+#else
+PD_REGISTER_KERNEL(any_raw, KPS, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
+#endif
@@ -37,5 +37,4 @@ PD_REGISTER_KERNEL(max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float) {}
#else
PD_REGISTER_KERNEL(
max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float, double, int, int64_t) {}
#endif
@@ -31,12 +31,15 @@ void ProdRawKernel(const Context& dev_ctx,
}
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(prod_raw, KPS, ALL_LAYOUT, phi::ProdRawKernel, float) {}
#else
PD_REGISTER_KERNEL(prod_raw,
-                   GPU,
+                   KPS,
ALL_LAYOUT,
phi::ProdRawKernel,
float,
double,
int,
int64_t) {}
#endif
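On the backend switch above (GPU to KPS): as I read the phi kernel registry, the KPS backend resolves to the GPU place in CUDA builds and to the XPU place under PADDLE_WITH_XPU_KP, so a single KPS registration serves both builds; only the dtype list has to differ. A sketch of the pattern with a hypothetical kernel:

// Illustrative only: my_reduce_raw / phi::MyReduceRawKernel are placeholders.
#ifdef PADDLE_WITH_XPU_KP
// KP build: KPS maps to XPU; register only the dtype the device supports.
PD_REGISTER_KERNEL(my_reduce_raw, KPS, ALL_LAYOUT, phi::MyReduceRawKernel, float) {}
#else
// CUDA build: KPS maps to GPU; keep the full dtype list.
PD_REGISTER_KERNEL(my_reduce_raw,
                   KPS,
                   ALL_LAYOUT,
                   phi::MyReduceRawKernel,
                   float,
                   double,
                   int,
                   int64_t) {}
#endif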
@@ -48,7 +48,7 @@ static inline __device__ void sync_all() {
#define ncores 64
template <typename T, typename OpFunc, int VecSize>
-__device__ void BlockXReduce(T* data, OpFunc reducer) {
+__device__ void BlockXReduce(T* out, const T* data, OpFunc reducer) {
__shared__ T sum_array[ncores * VecSize];
int core_idx = core_id() * VecSize;
mfence();
@@ -57,21 +57,22 @@ __device__ void BlockXReduce(T* data, OpFunc reducer) {
#pragma unroll
for (int i = 0; i < VecSize; i++) {
mfence();
-    sum_array[core_idx + i] = data[i];
+    sum_array[i * ncores + core_idx] = data[i];
    mfence();
-    data[i] = 0;
}
sync_all();
#pragma unroll
for (int i = 0; i < VecSize; i++) {
+    T start = data[i * ncores];
#pragma unroll
-    for (int j = 0; j < ncores; j++) {
+    for (int j = 1; j < ncores; j++) {
      mfence();
-      T tmp = sum_array[j * VecSize + i];
+      T tmp = sum_array[i * ncores + j];
      mfence();
-      data[i] = reducer(data[i], tmp);
+      start = reducer(start, tmp);
      mfence();
    }
+    out[i] = start;
}
sync_all();
}
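The substance of this hunk is a shared-memory layout change plus an in/out split: core c's lane i now lands at sum_array[i * ncores + c] (lane-major) instead of sum_array[c * VecSize + i], the accumulator is seeded from slot 0 of each lane's row (hence j starting at 1), and the result goes to a separate out buffer rather than being reduced in place into data. A self-contained host-side emulation of the new indexing (a sketch, not XPU code; max stands in for the reducer):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  const int ncores = 64;  // fixed core count, as in the XPU2 header above
  const int VecSize = 2;

  // values[c * VecSize + i]: core c's lane-i contribution.
  std::vector<float> values(ncores * VecSize);
  for (int c = 0; c < ncores; ++c)
    for (int i = 0; i < VecSize; ++i)
      values[c * VecSize + i] = static_cast<float>((c * 31 + i * 7) % 97);

  // Each "core" stores its lanes with the new lane-major indexing.
  std::vector<float> sum_array(ncores * VecSize);
  for (int c = 0; c < ncores; ++c)
    for (int i = 0; i < VecSize; ++i)
      sum_array[i * ncores + c] = values[c * VecSize + i];

  auto reducer = [](float a, float b) { return std::max(a, b); };

  // Core 0's view: seed from slot j == 0 (its own lane), fold in j = 1..63,
  // and write into a separate out[] instead of reducing in place.
  std::vector<float> out(VecSize);
  for (int i = 0; i < VecSize; ++i) {
    float start = sum_array[i * ncores];
    for (int j = 1; j < ncores; ++j)
      start = reducer(start, sum_array[i * ncores + j]);
    out[i] = start;
  }

  // Check against a direct per-lane reduction over the original values.
  for (int i = 0; i < VecSize; ++i) {
    float ref = values[i];  // core 0, lane i
    for (int c = 1; c < ncores; ++c)
      ref = std::max(ref, values[c * VecSize + i]);
    assert(out[i] == ref);
  }
  return 0;
}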
@@ -346,7 +347,7 @@ __device__ __forceinline__ void Reduce(T* out,
if (reduce_last_dim) {
#pragma unroll
for (int i = 0; i < NY * NX; i++) { // reduce along blockDim.x
-      details::BlockXReduce<T, ReduceFunctor, 1>(&out[i], reducer);
+      details::BlockXReduce<T, ReduceFunctor, 1>(&out[i], &in[i], reducer);
}
}
} else { // else kLocalMode
......
@@ -36,3 +36,7 @@ PD_REGISTER_KERNEL(all, CPU, ALL_LAYOUT, phi::AllKernel, bool) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(all, GPU, ALL_LAYOUT, phi::AllKernel, bool) {}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(all, KPS, ALL_LAYOUT, phi::AllKernel, bool) {}
#endif
@@ -36,3 +36,7 @@ PD_REGISTER_KERNEL(any, CPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(any, GPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(any, KPS, ALL_LAYOUT, phi::AnyKernel, bool) {}
#endif
@@ -38,3 +38,7 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL(
max, GPU, ALL_LAYOUT, phi::MaxKernel, float, double, int, int64_t) {}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(max, KPS, ALL_LAYOUT, phi::MaxKernel, float) {}
#endif
@@ -46,3 +46,7 @@ PD_REGISTER_KERNEL(mean,
int64_t,
phi::dtype::float16) {}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(mean, KPS, ALL_LAYOUT, phi::MeanKernel, float) {}
#endif
@@ -38,3 +38,7 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL(
min, GPU, ALL_LAYOUT, phi::MinKernel, float, double, int, int64_t) {}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(min, KPS, ALL_LAYOUT, phi::MinKernel, float) {}
#endif
@@ -38,3 +38,7 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL(
prod, GPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(prod, KPS, ALL_LAYOUT, phi::ProdKernel, float) {}
#endif
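For context on what the hunks above register: the non-raw kernels (all/any/max/mean/min/prod) are thin wrappers that forward to their *_raw counterparts with a reduce_all flag. A hedged sketch of the likely shape of one such wrapper (parameter names and the reduce_all handling are assumptions, not part of this diff; see the phi reduce kernel sources for the real ones):

// Hedged sketch only, not the verbatim phi source.
template <typename T, typename Context>
void MaxKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const std::vector<int64_t>& dims,
               bool keep_dim,
               DenseTensor* out) {
  bool reduce_all = false;  // assumption: the raw kernel resolves reduce-all from dims
  MaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}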
@@ -69,3 +69,9 @@ PD_REGISTER_KERNEL(sum,
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
#endif
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(sum, KPS, ALL_LAYOUT, phi::SumKernel, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
#endif
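Why sum's registration carries a body while the others are empty: the brace block is a registration-time constraint, and marking the output dtype UNDEFINED defers the choice to runtime, where sum's out_dtype attribute can request an accumulation or cast type. A sketch of the same pattern on a hypothetical kernel:

// Hypothetical name; the constraint body mirrors the sum registration above.
PD_REGISTER_KERNEL(my_sum, KPS, ALL_LAYOUT, phi::MySumKernel, float) {
  // Leave the output dtype open so the runtime out_dtype attribute decides it.
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}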
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
from op_test import OpTest
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestReduceAmaxOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'reduce_amax'
class XPUTestReduceAmaxBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.set_case()
def set_case(self):
self.op_type = 'reduce_amax'
self.shape = (20, 10)
self.attrs = {'use_xpu': True, 'keep_dim': False, 'dim': (1, )}
self.inputs = {
'X': np.random.randint(0, 100, self.shape).astype("float32")
}
            expected_input = self.inputs['X']
self.outputs = {
'Out':
                np.amax(expected_input,
axis=self.attrs['dim'],
keepdims=self.attrs['keep_dim'])
}
def test_check_output(self):
self.check_output_with_place(self.place)
support_types = get_xpu_op_support_types('reduce_amax')
for stype in support_types:
create_test_class(globals(), XPUTestReduceAmaxOp, stype)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
from op_test import OpTest
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestReduceAminOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'reduce_amin'
    class XPUTestReduceAminBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.set_case()
def set_case(self):
self.op_type = 'reduce_amin'
self.shape = (20, 10)
self.attrs = {'use_xpu': True, 'keep_dim': False, 'dim': (1, )}
self.inputs = {
'X': np.random.randint(0, 100, self.shape).astype("float32")
}
            expected_input = self.inputs['X']
self.outputs = {
'Out':
                np.amin(expected_input,
axis=self.attrs['dim'],
keepdims=self.attrs['keep_dim'])
}
def test_check_output(self):
self.check_output_with_place(self.place)
support_types = get_xpu_op_support_types('reduce_amin')
for stype in support_types:
    create_test_class(globals(), XPUTestReduceAminOp, stype)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
from op_test import OpTest
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestReduceAnyOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'reduce_any'
class XPUTestReduceAnyBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.set_case()
def set_case(self):
self.op_type = 'reduce_any'
self.attrs = {
'use_xpu': True,
'reduce_all': True,
'keep_dim': True,
'dim': (3, 5, 4)
}
self.inputs = {
'X':
np.random.randint(0, 2, (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
}
self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class XPUTestReduceAnyCase1(XPUTestReduceAnyBase):
def set_case(self):
self.op_type = 'reduce_any'
self.attrs = {
'use_xpu': True,
'dim': [1]
# 'reduce_all': True,
# 'keep_dim': True,
}
self.inputs = {
'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")
}
self.outputs = {'Out': self.inputs['X'].any(axis=1)}
class XPUTestReduceAnyCase2(XPUTestReduceAnyBase):
def set_case(self):
self.op_type = 'reduce_any'
self.attrs = {
'use_xpu': True,
'reduce_all': True,
'keep_dim': False,
'dim': (3, 6)
}
self.inputs = {
'X':
np.random.randint(0, 2, (2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
}
self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])}
support_types = get_xpu_op_support_types('reduce_any')
for stype in support_types:
create_test_class(globals(), XPUTestReduceAnyOp, stype)
if __name__ == '__main__':
unittest.main()