Unverified commit b3959fe4, authored by Lijunhui, committed by GitHub

[KP] Add Reduce op registry & UT for xpu_kp compilation (#41869)

Parent 14c35a58
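The pattern behind the kernel changes below: each reduce kernel registration is wrapped in a PADDLE_WITH_XPU_KP switch, and the backend tag is renamed from GPU to KPS so the same macro serves both builds. The XPU KP build registers only the dtypes the XPU reduce primitives support (float, or bool for the logical reduces), while the CUDA/HIP branch keeps the full dtype list. A minimal sketch of that pattern, assembled from the max_raw hunk further down (not a complete source file):

// Registration pattern applied throughout this commit (sketch, taken from
// the reduce_max change below; the other reduce kernels follow suit).
#ifdef PADDLE_WITH_XPU_KP
// XPU KP build: KPS backend, float only.
PD_REGISTER_KERNEL(max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float) {}
#else
// CUDA/HIP build: same KPS tag, full dtype list.
PD_REGISTER_KERNEL(
    max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float, double, int, int64_t) {}
#endif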
@@ -71,8 +71,10 @@ PD_DECLARE_KERNEL(concat_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT);
#ifdef PADDLE_WITH_XPU_KP
PD_DECLARE_KERNEL(add_raw, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT);
#else
PD_DECLARE_KERNEL(add_raw, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(max_raw, KPS, ALL_LAYOUT);
#endif
PD_DECLARE_KERNEL(add, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT);
@@ -85,7 +87,6 @@ PD_DECLARE_KERNEL(matmul_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(transpose_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sum, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sum_grad, GPU, ALL_LAYOUT);
-PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sgd, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(slice, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT);
......
@@ -97,6 +97,14 @@ XPUOpMap& get_kp_ops() {
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"not_equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
// reduce op
{"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_min", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_prod", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_all", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace())})},
{"reduce_any", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace())})},
};
return s_xpu_kp_kernels;
......
@@ -18,6 +18,9 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#define INT_BITS 32
#if defined(__xpu__)
#define __forceinline__ __inline__
#endif
namespace paddle {
namespace platform {
......
@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/phi/core/hostdevice.h"
#if defined(__xpu__)
#define CHAR_BIT 8
#endif
namespace phi {
......
@@ -33,10 +33,14 @@
namespace cub = hipcub;
#endif
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#endif
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/array.h"
@@ -183,7 +187,7 @@ struct IndexCalculator {
strides = details::VectorToArray<int, kMaxRank>(full_strides);
reduce_strides = details::VectorToArray<int, kMaxRank>(cal_strides);
#ifndef PADDLE_WITH_XPU_KP
-    std::vector<kps::details::FastDivMod> cal_divmoders;  // namespace
+    std::vector<kps::details::FastDivMod> cal_divmoders;
// fast divmod
for (auto i : cal_strides) {
cal_divmoders.push_back(kps::details::FastDivMod(i));
@@ -325,9 +329,10 @@ struct ReduceConfig {
// step4: set the block and grid for launch kernel
SetBlockDim();
#ifndef PADDLE_WITH_XPU_KP
// step5: limit the grid to prevent thread overflow
paddle::platform::LimitGridDim(dev_ctx, &grid);
#endif
}
// when should_reduce_again is true, we need malloc temp space for temp data
......
@@ -41,7 +41,7 @@ void Reduce(const KPDevice& dev_ctx,
for (auto i : reduce_dims) {
reduce_num *= (x.dims())[i];
}
#ifndef PADDLE_WITH_XPU_KP
if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x.dtype()) {
auto tmp_tensor = phi::Cast<T>(dev_ctx, x, out_dtype);
PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES(
@@ -73,6 +73,16 @@
reduce_dims,
is_mean);
}
#else
using MPType = typename kps::details::MPTypeTrait<T>::Type;
phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>>(
dev_ctx,
x,
out,
TransformOp<T, MPType>(reduce_num),
reduce_dims,
is_mean);
#endif
}
} // namespace phi
......
@@ -33,4 +33,8 @@ void AllRawKernel(const Context& dev_ctx,
} // namespace phi
-PD_REGISTER_KERNEL(all_raw, GPU, ALL_LAYOUT, phi::AllRawKernel, bool) {}
+#ifdef PADDLE_WITH_XPU_KP
+PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) {}
+#else
+PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) {}
+#endif
@@ -33,5 +33,10 @@ void MaxRawKernel(const Context& dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float) {}
#else
PD_REGISTER_KERNEL(
-    max_raw, GPU, ALL_LAYOUT, phi::MaxRawKernel, float, double, int, int64_t) {}
+    max_raw, KPS, ALL_LAYOUT, phi::MaxRawKernel, float, double, int, int64_t) {}
#endif
@@ -33,10 +33,13 @@ void MeanRawKernel(const Context& dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(mean_raw, KPS, ALL_LAYOUT, phi::MeanRawKernel, float) {}
#else
using float16 = phi::dtype::float16;
PD_REGISTER_KERNEL(mean_raw,
-                   GPU,
+                   KPS,
ALL_LAYOUT,
phi::MeanRawKernel,
float,
@@ -45,3 +48,4 @@ PD_REGISTER_KERNEL(mean_raw,
float16,
int,
int64_t) {}
#endif
@@ -33,5 +33,9 @@ void MinRawKernel(const Context& dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(min_raw, KPS, ALL_LAYOUT, phi::MinRawKernel, float) {}
#else
PD_REGISTER_KERNEL(
-    min_raw, GPU, ALL_LAYOUT, phi::MinRawKernel, float, double, int, int64_t) {}
+    min_raw, KPS, ALL_LAYOUT, phi::MinRawKernel, float, double, int, int64_t) {}
#endif
@@ -33,13 +33,18 @@ void SumRawKernel(const Context& dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(sum_raw, KPS, ALL_LAYOUT, phi::SumRawKernel, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
#else
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(sum_raw,
-                   GPU,
+                   KPS,
ALL_LAYOUT,
phi::SumRawKernel,
bool,
@@ -54,3 +59,4 @@ PD_REGISTER_KERNEL(sum_raw,
complex128) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
#endif
@@ -336,7 +336,7 @@ __device__ __forceinline__ void Reduce(T* out,
out[i] = reducer(out[i], in[i * NX + j]);
}
}
-    BlockXReduce<T, ReduceFunctor, NY>(out, reducer);
+    details::BlockXReduce<T, ReduceFunctor, NY>(out, reducer);
} else { // else kLocalMode
#pragma unroll
for (int i = 0; i < NY; ++i) {
......
@@ -77,7 +77,7 @@ struct BroadcastConfig {
#pragma pack()
template <typename T>
-__device__ __forceinline__ void WriteData(T* _global_ptr_ dst,
+__device__ __forceinline__ void WriteData(T _global_ptr_* dst,
T* src,
int num) {
if (num > 0) {
@@ -403,16 +403,17 @@ template <typename Tx,
typename IndexCal,
typename Functor,
bool IsBoundary = false>
-__device__ __forceinline__ void ReadDataReduce(Ty* dst,
-                                               const Tx* __restrict__ src,
-                                               int block_offset,
-                                               const IndexCal& index_cal,
-                                               int size_nx,
-                                               int size_ny,
-                                               int stride_nx,
-                                               int stride_ny,
-                                               Functor func,
-                                               bool reduce_last_dim) {
+__device__ __forceinline__ void ReadDataReduce(
+    Ty* dst,
+    const Tx _global_ptr_* __restrict__ src,
+    int block_offset,
+    const IndexCal& index_cal,
+    int size_nx,
+    int size_ny,
+    int stride_nx,
+    int stride_ny,
+    Functor func,
+    bool reduce_last_dim) {
__local__ Tx in_temp[1];
int thread_offset = 0;
int left_idx = 0;
......
@@ -25,6 +25,12 @@ namespace kps {
*/
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
#ifdef PADDLE_WITH_XPU_KP
HOSTDEVICE inline IdentityFunctor() {}
HOSTDEVICE explicit inline IdentityFunctor(int n) {}
HOSTDEVICE Ty operator()(const Tx x) const { return static_cast<Ty>(x); }
HOSTDEVICE inline void SetDiv(int n) {}
#else
inline IdentityFunctor() {}
explicit inline IdentityFunctor(int n) {}
@@ -38,6 +44,7 @@ struct IdentityFunctor {
return static_cast<Ty>(x);
}
__device__ inline void SetDiv(int n) {}
#endif
};
/**
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
from op_test import OpTest
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestReduceAllOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'reduce_all'
class XPUTestReduceAllBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.set_case()
def set_case(self):
self.op_type = 'reduce_all'
self.attrs = {
'use_xpu': True,
'reduce_all': True,
'keep_dim': True,
'dim': (3, 5, 4)
}
self.inputs = {
'X': np.random.randint(0, 2,
(2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
}
self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class XPUTestReduceAllCase1(XPUTestReduceAllBase):
def set_case(self):
self.op_type = 'reduce_all'
self.attrs = {
'use_xpu': True,
'reduce_all': True,
'keep_dim': True,
'dim': [1]
}
self.inputs = {
'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")
}
self.outputs = {'Out': self.inputs['X'].all()}
class XPUTestReduceAllCase2(XPUTestReduceAllBase):
def set_case(self):
self.op_type = 'reduce_all'
self.attrs = {
'use_xpu': True,
'reduce_all': True,
'keep_dim': False,
'dim': (3, 6)
}
self.inputs = {
'X': np.random.randint(0, 2,
(2, 5, 3, 2, 2, 3, 4, 2)).astype("bool")
}
self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])}
class XPUTestReduceAllCase3(XPUTestReduceAllBase):
def set_case(self):
self.op_type = 'reduce_all'
self.attrs = {
'use_xpu': True,
'keep_dim': True,
'dim': [1]
# 'reduce_all': True,
}
self.inputs = {
'X': np.random.randint(0, 2, (5, 6, 10)).astype("bool")
}
self.outputs = {
'Out': np.expand_dims(
self.inputs['X'].all(axis=1), axis=1)
}
support_types = get_xpu_op_support_types('reduce_all')
for stype in support_types:
create_test_class(globals(), XPUTestReduceAllOp, stype)
if __name__ == '__main__':
unittest.main()
@@ -18,56 +18,64 @@ import unittest
import numpy as np
import sys
sys.path.append("..")
-from op_test_xpu import OpTest, XPUOpTest
-from op_test import skip_check_grad_ci
 import paddle
-import paddle.fluid.core as core
-import paddle.fluid as fluid
-from paddle.fluid import compiler, Program, program_guard
-from paddle.fluid.framework import convert_np_dtype_to_dtype_
-"""
-class TestXPUReduceMaxOp(XPUOpTest):
-    def setUp(self):
-        self.init_op_type()
-        self.initTestCase()
-        self.use_xpu = True
-        self.use_mkldnn = False
-        self.attrs = {
-            'dim': self.axis,
-            'keep_dim': self.keep_dim,
-            'reduce_all': self.reduce_all
-        }
-        self.inputs = {'X': np.random.random(self.shape).astype('float32')}
-        if self.attrs['reduce_all']:
-            self.outputs = {'Out': self.inputs['X'].max()}
-        else:
-            self.outputs = {
-                'Out': self.inputs['X'].max(axis=self.axis,
-                                            keepdims=self.attrs['keep_dim'])
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
+paddle.enable_static()
+class XPUTestReduceMaxOp(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'reduce_max'
+    class XPUTestReduceMaxBase(XPUOpTest):
+        def setUp(self):
+            self.place = paddle.XPUPlace(0)
+            self.init_case()
+            self.set_case()
+        def set_case(self):
+            self.op_type = 'reduce_max'
+            self.attrs = {
+                'use_xpu': True,
+                'reduce_all': self.reduce_all,
+                'keep_dim': self.keep_dim
+            }
+            self.inputs = {'X': np.random.random(self.shape).astype("float32")}
+            if self.attrs['reduce_all']:
+                self.outputs = {'Out': self.inputs['X'].max()}
+            else:
+                self.outputs = {
+                    'Out': self.inputs['X'].max(axis=self.axis,
+                                                keepdims=self.attrs['keep_dim'])
                 }
+        def init_case(self):
+            self.shape = (5, 6, 10)
+            self.axis = (0, )
+            self.reduce_all = False
+            self.keep_dim = False
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
-    def test_check_output(self):
-        if paddle.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
+        def test_check_grad(self):
+            pass
-    def test_check_grad(self):
-        if paddle.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(place, ['X'], 'Out')
+    class XPUTestReduceMaxCase1(XPUTestReduceMaxBase):
+        def init_case(self):
+            self.shape = (5, 6, 10)
+            self.axis = (0, )
+            self.reduce_all = False
+            self.keep_dim = True
-    def init_op_type(self):
-        self.op_type = 'reduce_max'
-        self.use_mkldnn = False
-        self.keep_dim = False
-        self.reduce_all = False
-    def initTestCase(self):
-        self.shape = (5, 6, 10)
-        self.axis = (-1, )
-"""
+support_types = get_xpu_op_support_types('reduce_max')
+for stype in support_types:
+    create_test_class(globals(), XPUTestReduceMaxOp, stype)
if __name__ == '__main__':
unittest.main()
@@ -194,13 +194,5 @@ class TestKeepDim8DReduce(Test1DReduce):
}
-class TestReduceAll(Test1DReduce):
-    def setUp(self):
-        self.op_type = "reduce_mean"
-        self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")}
-        self.attrs = {'reduce_all': True, 'use_xpu': True}
-        self.outputs = {'Out': self.inputs['X'].mean()}
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
from op_test import OpTest
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestReduceMinOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'reduce_min'
class XPUTestReduceMinBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'reduce_min'
self.attrs = {
'use_xpu': True,
'reduce_all': self.reduce_all,
'keep_dim': self.keep_dim
}
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
if self.attrs['reduce_all']:
self.outputs = {'Out': self.inputs['X'].min()}
else:
self.outputs = {
'Out': self.inputs['X'].min(axis=self.axis,
keepdims=self.attrs['keep_dim'])
}
def init_case(self):
self.shape = (5, 6, 10)
self.axis = (0, )
self.reduce_all = False
self.keep_dim = False
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class XPUTestReduceMinCase1(XPUTestReduceMinBase):
def init_case(self):
self.shape = (5, 6, 10)
self.axis = (0, )
self.reduce_all = False
self.keep_dim = True
support_types = get_xpu_op_support_types('reduce_min')
for stype in support_types:
create_test_class(globals(), XPUTestReduceMinOp, stype)
if __name__ == '__main__':
unittest.main()
@@ -18,138 +18,64 @@ import unittest
import numpy as np
import sys
sys.path.append("..")
-from op_test_xpu import OpTest, XPUOpTest
-from op_test import skip_check_grad_ci
-import paddle
-import paddle.fluid.core as core
-import paddle.fluid as fluid
-from paddle.fluid import compiler, Program, program_guard
-from paddle.fluid.framework import convert_np_dtype_to_dtype_
-class TestXPUReduceSumOp(XPUOpTest):
-    def setUp(self):
-        self.init_op_type()
-        self.initTestCase()
-        self.use_xpu = True
-        self.use_mkldnn = False
-        self.attrs = {
-            'dim': self.axis,
-            'keep_dim': self.keep_dim,
-            'reduce_all': self.reduce_all
-        }
-        self.inputs = {'X': np.random.random(self.shape).astype("float32")}
-        if self.attrs['reduce_all']:
-            self.outputs = {'Out': self.inputs['X'].sum()}
-        else:
-            self.outputs = {
-                'Out': self.inputs['X'].sum(axis=self.axis,
-                                            keepdims=self.attrs['keep_dim'])
-            }
-    def test_check_output(self):
-        if paddle.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
-    def test_check_grad(self):
-        if paddle.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(place, ['X'], 'Out')
-    def init_op_type(self):
-        self.op_type = "reduce_sum"
-        self.use_mkldnn = False
-        self.keep_dim = False
-        self.reduce_all = False
-    def initTestCase(self):
-        self.shape = (5, 6, 10)
-        self.axis = (0, )
-class TestSumOp5D(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (1, 2, 5, 6, 10)
-        self.axis = (0, )
-class TestSumOp6D(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (1, 1, 2, 5, 6, 10)
-        self.axis = (0, )
-class TestSumOp8D(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (1, 3, 1, 2, 1, 4, 3, 10)
-        self.axis = (0, 3)
-class Test1DReduce(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = 120
-        self.axis = (0, )
+import paddle
+from op_test import OpTest
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
-class Test2DReduce0(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (20, 10)
-        self.axis = (0, )
-class Test2DReduce1(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (20, 10)
-        self.axis = (1, )
-class Test3DReduce0(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (5, 6, 7)
-        self.axis = (1, )
-class Test3DReduce1(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (5, 6, 7)
-        self.axis = (2, )
-class Test3DReduce2(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (5, 6, 7)
-        self.axis = (-2, )
-class Test3DReduce3(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (5, 6, 7)
-        self.axis = (1, 2)
-class TestKeepDimReduce(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (5, 6, 10)
-        self.axis = (1, )
-        self.keep_dim = True
+paddle.enable_static()
-class TestKeepDim8DReduce(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (2, 5, 3, 2, 2, 3, 4, 2)
-        self.axis = (3, 4, 5)
-        self.keep_dim = True
+class XPUTestReduceSumOp(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'reduce_sum'
-class TestReduceAll(TestXPUReduceSumOp):
-    def initTestCase(self):
-        self.shape = (5, 6, 2, 10)
-        self.axis = (0, )
-        self.reduce_all = True
+    class XPUTestReduceSumBase(XPUOpTest):
+        def setUp(self):
+            self.place = paddle.XPUPlace(0)
+            self.init_case()
+            self.set_case()
+        def set_case(self):
+            self.op_type = 'reduce_sum'
+            self.attrs = {
+                'use_xpu': True,
+                'reduce_all': self.reduce_all,
+                'keep_dim': self.keep_dim
+            }
+            self.inputs = {'X': np.random.random(self.shape).astype("float32")}
+            if self.attrs['reduce_all']:
+                self.outputs = {'Out': self.inputs['X'].sum()}
+            else:
+                self.outputs = {
+                    'Out': self.inputs['X'].sum(axis=self.axis,
+                                                keepdims=self.attrs['keep_dim'])
+                }
+        def init_case(self):
+            self.shape = (5, 6, 10)
+            self.axis = (0, )
+            self.reduce_all = False
+            self.keep_dim = False
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+        def test_check_grad(self):
+            pass
+    class XPUTestReduceSumCase1(XPUTestReduceSumBase):
+        def init_case(self):
+            self.shape = (5, 6, 10)
+            self.axis = (0, )
+            self.reduce_all = False
+            self.keep_dim = True
+support_types = get_xpu_op_support_types('reduce_sum')
+for stype in support_types:
+    create_test_class(globals(), XPUTestReduceSumOp, stype)
if __name__ == '__main__':
unittest.main()