Unverified commit 3749198e authored by Lijunhui, committed by GitHub

[KP] Add Logical/compare/bitwise registry & UT (#40802)

* init commit no push

* collect compile errors

* bitwise UT

* fix compile problem

* cancel comments

* restore missed deletion

* fix compilation

* fix UT

* No stash in multiple branches at the same time

* fix error

* combine .cu from gpu and kps

* replace gpu by kps

* fix by Chen-weihang

* Revert "Fix kps compile error in Junhui logic compare bitwise"

* fix backend test

* rm comments
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Parent a058b474
......@@ -59,6 +59,32 @@ XPUOpMap& get_kp_ops() {
{"swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"thresholded_relu",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
// bitwise logical & compare
{"bitwise_and", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"bitwise_or", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"bitwise_not", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"bitwise_xor", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace())})},
{"logical_and",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"logical_or", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"logical_not",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"logical_xor",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"less_than", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"less_equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"greater_than",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"greater_equal",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
{"not_equal", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
};
return s_xpu_kp_kernels;
......
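For context, the kp_ops entries added above can be exercised from Python once a KP-enabled XPU build is available. The sketch below is illustrative only (not part of this commit) and assumes an XPU device 0 is present; it sticks to the dtypes listed in the registry: INT32/BOOL for the bitwise ops, INT32 for the logical and compare ops.

```python
# Illustrative sketch (not from this commit): exercising the ops registered
# above. Assumes a Paddle build with XPU KP support and an available XPU card.
import numpy as np
import paddle

paddle.set_device('xpu')  # assumption: XPU device 0 is present

x = paddle.to_tensor(np.random.randint(-100, 100, [2, 3]).astype('int32'))
y = paddle.to_tensor(np.random.randint(-100, 100, [2, 3]).astype('int32'))

print(paddle.bitwise_and(x, y))  # bitwise ops: registered for INT32 and BOOL
print(paddle.logical_or(x, y))   # logical ops: registered for INT32 only
print(paddle.less_than(x, y))    # compare ops: registered for INT32 only
```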
......@@ -159,7 +159,14 @@ inline Backend StringToBackend(const char* backend_cstr) {
} else if (s == std::string("GPUDNN")) {
return Backend::GPUDNN;
} else if (s == std::string("KPS")) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// NOTE(chenweihang) KPS is not yet a complete backend, and it still needs
// to be converted to GPU in the GPU environment
return Backend::GPU;
#else
return Backend::KPS;
#endif
} else if (s == std::string("IPU")) {
return Backend::IPU;
} else {
......
......@@ -15,7 +15,8 @@
#pragma once
// CUDA, XPU and HIP use same api
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || defined(__xpu__)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_XPU_KP)
#include <algorithm>
#include <cmath>
......@@ -34,7 +35,6 @@ namespace cub = hipcub;
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
......@@ -52,7 +52,9 @@ namespace cub = hipcub;
#define REDUCE_VEC_SIZE 4
namespace kps = phi::kps;
#ifdef PADDLE_WITH_XPU_KP
using dim3 = phi::kps::dim3;
#endif
namespace phi {
namespace funcs {
......@@ -82,12 +84,14 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
return strides;
}
#ifndef PADDLE_WITH_XPU_KP
// get blockDim for reduceLastDim and reduceAny
static inline int GetBlockDim(int block_dim) {
return block_dim >= kps::details::kReduceMaxThread
? kps::details::kReduceMaxThread
: GetLastPow2(block_dim);
}
#endif
// check reduce rank is valid
static inline void CheckReduceRank(int reduce_rank, int rank) {
......@@ -180,12 +184,12 @@ struct IndexCalculator {
strides = details::VectorToArray<int, kMaxRank>(full_strides);
reduce_strides = details::VectorToArray<int, kMaxRank>(cal_strides);
#ifndef PADDLE_WITH_XPU_KP
std::vector<paddle::platform::FastDivMod> cal_divmoders;
std::vector<kps::details::FastDivMod> cal_divmoders; // namespace
// fast divmod
for (auto i : cal_strides) {
cal_divmoders.push_back(paddle::platform::FastDivMod(i));
cal_divmoders.push_back(kps::details::FastDivMod(i));
}
divmoders = details::VectorToArray<paddle::platform::FastDivMod, kMaxRank>(
divmoders = details::VectorToArray<kps::details::FastDivMod, kMaxRank>(
cal_divmoders);
#endif
}
......@@ -222,7 +226,7 @@ struct IndexCalculator {
phi::Array<int, kMaxRank> strides;
phi::Array<int, kMaxRank> reduce_strides;
#ifndef PADDLE_WITH_XPU_KP
phi::Array<paddle::platform::FastDivMod, kMaxRank> divmoders;
phi::Array<kps::details::FastDivMod, kMaxRank> divmoders;
#endif
};
......@@ -579,11 +583,11 @@ struct ReduceConfig {
void SetBlockDim() {
// init
int block_num = details::GetBlockDim(reduce_num);
should_reduce_again = false;
dim3 block_dim(block_num, 1, 1);
dim3 block_dim;
dim3 grid_dim(left_num, 1, 1);
blocking_size = reduce_num;
#ifdef PADDLE_WITH_XPU_KP
if (reduce_last_dim) {
block_dim.x = 64;
......@@ -990,6 +994,7 @@ static void LaunchReduceKernel(const Tx* x_data,
}
}
#if !defined(PADDLE_WITH_XPU_KP)
template <typename Tx,
typename Ty,
template <typename> class ReduceOp,
......@@ -1044,6 +1049,7 @@ CubTensorReduceImpl(const Tx* x_data,
PADDLE_THROW(phi::errors::InvalidArgument(
"Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
}
#endif // PADDLE_WITH_XPU_KP
template <typename Tx,
typename Ty,
......
......@@ -15,7 +15,8 @@
#pragma once
// CUDA and HIP use same api
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_XPU_KP)
#include "paddle/phi/kernels/funcs/reduce_function.h"
......
......@@ -14,7 +14,12 @@ limitations under the License. */
#include "paddle/phi/kernels/bitwise_kernel.h"
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/phi/backends/xpu/xpu_context.h"
#else
#include "paddle/phi/backends/gpu/gpu_context.h"
#endif
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/bitwise_functors.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
......@@ -53,8 +58,19 @@ void BitwiseNotKernel(const Context& dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(
bitwise_and, KPS, ALL_LAYOUT, phi::BitwiseAndKernel, int, bool) {}
PD_REGISTER_KERNEL(
bitwise_or, KPS, ALL_LAYOUT, phi::BitwiseOrKernel, int, bool) {}
PD_REGISTER_KERNEL(
bitwise_xor, KPS, ALL_LAYOUT, phi::BitwiseXorKernel, int, bool) {}
PD_REGISTER_KERNEL(
bitwise_not, KPS, ALL_LAYOUT, phi::BitwiseNotKernel, int, bool) {}
#else
PD_REGISTER_KERNEL(bitwise_and,
GPU,
KPS,
ALL_LAYOUT,
phi::BitwiseAndKernel,
bool,
......@@ -65,7 +81,7 @@ PD_REGISTER_KERNEL(bitwise_and,
int64_t) {}
PD_REGISTER_KERNEL(bitwise_or,
GPU,
KPS,
ALL_LAYOUT,
phi::BitwiseOrKernel,
bool,
......@@ -76,7 +92,7 @@ PD_REGISTER_KERNEL(bitwise_or,
int64_t) {}
PD_REGISTER_KERNEL(bitwise_xor,
GPU,
KPS,
ALL_LAYOUT,
phi::BitwiseXorKernel,
bool,
......@@ -87,7 +103,7 @@ PD_REGISTER_KERNEL(bitwise_xor,
int64_t) {}
PD_REGISTER_KERNEL(bitwise_not,
GPU,
KPS,
ALL_LAYOUT,
phi::BitwiseNotKernel,
bool,
......@@ -96,3 +112,5 @@ PD_REGISTER_KERNEL(bitwise_not,
int16_t,
int,
int64_t) {}
#endif
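Unlike the logical and compare registrations in the following hunks, the KPS branch above keeps bool alongside int for the bitwise kernels. A minimal sketch of that bool path, again assuming an XPU KP build (not part of this commit):

```python
# Minimal sketch (not from this commit): bitwise_not on a bool tensor, which
# the KPS branch registers in addition to int. Assumes an XPU KP build.
import numpy as np
import paddle

paddle.set_device('xpu')  # assumption: an XPU device is available
x = paddle.to_tensor(np.random.choice([True, False], [2, 3]))
print(paddle.bitwise_not(x))  # for bool input this acts as elementwise NOT
```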
......@@ -12,17 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/compare_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/impl/compare_kernel_impl.h"
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/phi/backends/xpu/xpu_context.h"
#else
#include <thrust/fill.h>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/compare_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#endif
namespace phi {
......@@ -53,6 +57,7 @@ inline void CompareKernelImpl(const Context& ctx,
ctx, ins, &outs, axis, Functor());
}
#ifndef PADDLE_WITH_XPU_KP
template <typename T, typename Context, typename Functor>
inline void CompareAllKernelImpl(const Context& ctx,
const DenseTensor& x,
......@@ -83,11 +88,22 @@ inline void CompareAllKernelImpl(const Context& ctx,
funcs::ReduceKernel<bool, bool, BitwiseAdd, kps::IdentityFunctor<bool>>(
ctx, tmp, out, kps::IdentityFunctor<bool>(), reduce_dims);
}
#endif
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) {}
PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) {}
PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, phi::GreaterThanKernel, int) {
}
PD_REGISTER_KERNEL(
greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) {}
PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) {}
PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) {}
#else
PD_REGISTER_KERNEL(less_than,
GPU,
KPS,
ALL_LAYOUT,
phi::LessThanKernel,
bool,
......@@ -97,7 +113,7 @@ PD_REGISTER_KERNEL(less_than,
float,
double) {}
PD_REGISTER_KERNEL(less_equal,
GPU,
KPS,
ALL_LAYOUT,
phi::LessEqualKernel,
bool,
......@@ -107,7 +123,7 @@ PD_REGISTER_KERNEL(less_equal,
float,
double) {}
PD_REGISTER_KERNEL(greater_than,
GPU,
KPS,
ALL_LAYOUT,
phi::GreaterThanKernel,
bool,
......@@ -117,7 +133,7 @@ PD_REGISTER_KERNEL(greater_than,
float,
double) {}
PD_REGISTER_KERNEL(greater_equal,
GPU,
KPS,
ALL_LAYOUT,
phi::GreaterEqualKernel,
bool,
......@@ -127,7 +143,7 @@ PD_REGISTER_KERNEL(greater_equal,
float,
double) {}
PD_REGISTER_KERNEL(equal,
GPU,
KPS,
ALL_LAYOUT,
phi::EqualKernel,
bool,
......@@ -137,7 +153,7 @@ PD_REGISTER_KERNEL(equal,
float,
double) {}
PD_REGISTER_KERNEL(not_equal,
GPU,
KPS,
ALL_LAYOUT,
phi::NotEqualKernel,
bool,
......@@ -148,7 +164,7 @@ PD_REGISTER_KERNEL(not_equal,
double) {}
PD_REGISTER_KERNEL(equal_all,
GPU,
KPS,
ALL_LAYOUT,
phi::EqualAllKernel,
bool,
......@@ -156,3 +172,4 @@ PD_REGISTER_KERNEL(equal_all,
int64_t,
float,
double) {}
#endif
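Note that equal_all stays in the #else branch: it goes through CompareAllKernelImpl and the reduce path that is compiled out under PADDLE_WITH_XPU_KP, so it is only registered for the GPU build. A hedged sketch of that GPU-only path, assuming a CUDA device is available (not part of this commit):

```python
# Hedged sketch (not from this commit): equal_all is registered only in the
# non-KP (GPU) branch, so this assumes a CUDA build with a visible GPU.
import paddle

paddle.set_device('gpu')  # assumption: CUDA device 0 is present
x = paddle.to_tensor([1, 2, 3], dtype='int32')
y = paddle.to_tensor([1, 2, 3], dtype='int32')
print(paddle.equal_all(x, y))  # single-element bool tensor: True
```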
......@@ -10,11 +10,15 @@
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// limitation
#include "paddle/phi/kernels/logical_kernel.h"
#ifdef PADDLE_WITH_XPU_KP
#include "paddle/phi/backends/xpu/xpu_context.h"
#else
#include "paddle/phi/backends/gpu/gpu_context.h"
#endif
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/logical_functor.h"
......@@ -59,9 +63,15 @@ void LogicalNotKernel(const Context& dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(logical_and, KPS, ALL_LAYOUT, phi::LogicalAndKernel, int) {}
PD_REGISTER_KERNEL(logical_or, KPS, ALL_LAYOUT, phi::LogicalOrKernel, int) {}
PD_REGISTER_KERNEL(logical_not, KPS, ALL_LAYOUT, phi::LogicalNotKernel, int) {}
PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {}
#else
#define REGISTER_LOGICAL_CUDA_KERNEL(logical_and, func_type) \
PD_REGISTER_KERNEL(logical_and, \
GPU, \
KPS, \
ALL_LAYOUT, \
phi::Logical##func_type##Kernel, \
float, \
......@@ -76,3 +86,4 @@ REGISTER_LOGICAL_CUDA_KERNEL(logical_and, And)
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, Or)
REGISTER_LOGICAL_CUDA_KERNEL(logical_not, Not)
REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, Xor)
#endif
......@@ -23,7 +23,7 @@ struct dim3 {
int y;
int z;
explicit inline dim3(int split_x, int split_y = 1, int split_z = 1) {
explicit inline dim3(int split_x = 1, int split_y = 1, int split_z = 1) {
x = split_x;
y = split_y;
z = split_z;
......
......@@ -64,7 +64,11 @@ TEST(Backend, StringToBackend) {
EXPECT_EQ(phi::Backend::NPU, pexp::StringToBackend("NPU"));
EXPECT_EQ(phi::Backend::MKLDNN, pexp::StringToBackend("MKLDNN"));
EXPECT_EQ(phi::Backend::GPUDNN, pexp::StringToBackend("GPUDNN"));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
EXPECT_EQ(phi::Backend::GPU, pexp::StringToBackend("KPS"));
#else
EXPECT_EQ(phi::Backend::KPS, pexp::StringToBackend("KPS"));
#endif
EXPECT_EQ(static_cast<phi::Backend>(
static_cast<size_t>(phi::Backend::NUM_BACKENDS) + 1),
pexp::StringToBackend("CustomBackend"));
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
from op_test import OpTest
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
################## TEST OP: BitwiseAnd ##################
class XPUTestBitwiseAnd(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'bitwise_and'
class XPUTestBitwiseAndBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'bitwise_and'
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
y = np.random.randint(
self.low, self.high, self.y_shape, dtype=self.dtype)
out = np.bitwise_and(x, y)
self.attrs = {'use_xpu': True}
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
}
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class XPUTestBitwiseAndCase1(XPUTestBitwiseAndBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
class XPUTestBitwiseAndCase2(XPUTestBitwiseAndBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]
self.low = -100
self.high = 100
class XPUTestBitwiseAndCase3(XPUTestBitwiseAndBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = 0
self.high = 100
support_types = get_xpu_op_support_types('bitwise_and')
for stype in support_types:
create_test_class(globals(), XPUTestBitwiseAnd, stype)
################## TEST OP: BitwiseOr ##################
class XPUTestBitwiseOr(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'bitwise_or'
class XPUTestBitwiseOrBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'bitwise_or'
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
y = np.random.randint(
self.low, self.high, self.y_shape, dtype=self.dtype)
out = np.bitwise_or(x, y)
self.attrs = {'use_xpu': True}
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
}
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class XPUTestBitwiseOrCase1(XPUTestBitwiseOrBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
class XPUTestBitwiseOrCase2(XPUTestBitwiseOrBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]
self.low = -100
self.high = 100
class XPUTestBitwiseOrCase3(XPUTestBitwiseOrBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = 0
self.high = 100
support_types = get_xpu_op_support_types('bitwise_or')
for stype in support_types:
create_test_class(globals(), XPUTestBitwiseOr, stype)
################## TEST OP: BitwiseXor ##################
class XPUTestBitwiseXor(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'bitwise_xor'
class XPUTestBitwiseXorBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'bitwise_xor'
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
y = np.random.randint(
self.low, self.high, self.y_shape, dtype=self.dtype)
out = np.bitwise_xor(x, y)
self.attrs = {'use_xpu': True}
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
}
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class XPUTestBitwiseXorCase1(XPUTestBitwiseXorBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
class XPUTestBitwiseXorCase2(XPUTestBitwiseXorBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]
self.low = -100
self.high = 100
class XPUTestBitwiseXorCase3(XPUTestBitwiseXorBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = 0
self.high = 100
support_types = get_xpu_op_support_types('bitwise_xor')
for stype in support_types:
create_test_class(globals(), XPUTestBitwiseXor, stype)
################## TEST OP: BitwiseNot ##################
class XPUTestBitwiseNot(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'bitwise_not'
class XPUTestBitwiseNotBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'bitwise_not'
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
out = np.bitwise_not(x)
self.attrs = {'use_xpu': True}
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class XPUTestBitwiseNotBool(XPUTestBitwiseNotBase):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'bitwise_not'
x = np.random.choice([True, False], self.x_shape)
out = np.bitwise_not(x)
self.attrs = {'use_xpu': True}
self.inputs = {'X': x}
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.bool
self.x_shape = [2, 3, 4, 5]
support_types = get_xpu_op_support_types('bitwise_not')
for stype in support_types:
create_test_class(globals(), XPUTestBitwiseNot, stype)
if __name__ == '__main__':
unittest.main()
......@@ -65,7 +65,7 @@ def create_test_class(op_type, typename, callback):
globals()[cls_name] = Cls
for _type_name in {'float32', 'int32', 'int64'}:
for _type_name in {'int32'}:
if _type_name == 'float64' and core.is_compiled_with_rocm():
_type_name = 'float32'
......
......@@ -13,232 +13,220 @@
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from paddle.fluid.op import Operator
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle
from op_test import OpTest
from op_test_xpu import XPUOpTest
from paddle.static import Program, program_guard
SUPPORTED_DTYPES = [
bool, np.int8, np.int16, np.int32, np.int64, np.float32, np.float64
]
TEST_META_OP_DATA = [{
'op_str': 'logical_and',
'binary_op': True
}, {
'op_str': 'logical_or',
'binary_op': True
}, {
'op_str': 'logical_not',
'binary_op': False
}]
TEST_META_SHAPE_DATA = {
'XDimLargerThanYDim1': {
'x_shape': [2, 3, 4, 5],
'y_shape': [4, 5]
},
'XDimLargerThanYDim2': {
'x_shape': [2, 3, 4, 5],
'y_shape': [4, 1]
},
'XDimLargerThanYDim3': {
'x_shape': [2, 3, 4, 5],
'y_shape': [1, 4, 1]
},
'XDimLargerThanYDim4': {
'x_shape': [2, 3, 4, 5],
'y_shape': [3, 4, 1]
},
'XDimLargerThanYDim5': {
'x_shape': [2, 3, 1, 5],
'y_shape': [3, 1, 1]
},
'XDimLessThanYDim1': {
'x_shape': [4, 1],
'y_shape': [2, 3, 4, 5]
},
'XDimLessThanYDim2': {
'x_shape': [1, 4, 1],
'y_shape': [2, 3, 4, 5]
},
'XDimLessThanYDim3': {
'x_shape': [3, 4, 1],
'y_shape': [2, 3, 4, 5]
},
'XDimLessThanYDim4': {
'x_shape': [3, 1, 1],
'y_shape': [2, 3, 1, 5]
},
'XDimLessThanYDim5': {
'x_shape': [4, 5],
'y_shape': [2, 3, 4, 5]
},
'Axis1InLargerDim': {
'x_shape': [1, 4, 5],
'y_shape': [2, 3, 1, 5]
},
'EqualDim1': {
'x_shape': [10, 7],
'y_shape': [10, 7]
},
'EqualDim2': {
'x_shape': [1, 1, 4, 5],
'y_shape': [2, 3, 1, 5]
}
}
TEST_META_WRONG_SHAPE_DATA = {
'ErrorDim1': {
'x_shape': [2, 3, 4, 5],
'y_shape': [3, 4]
},
'ErrorDim2': {
'x_shape': [2, 3, 4, 5],
'y_shape': [4, 3]
}
}
def run_static_xpu(x_np, y_np, op_str, binary_op=True):
paddle.enable_static()
startup_program = fluid.Program()
main_program = fluid.Program()
place = paddle.XPUPlace(0)
exe = fluid.Executor(place)
with fluid.program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
op = getattr(paddle, op_str)
feed_list = {'x': x_np}
if not binary_op:
res = op(x)
else:
y = paddle.static.data(name='y', shape=y_np.shape, dtype=y_np.dtype)
feed_list['y'] = y_np
res = op(x, y)
exe.run(startup_program)
static_result = exe.run(main_program, feed=feed_list, fetch_list=[res])
return static_result
def run_dygraph_xpu(x_np, y_np, op_str, binary_op=True):
place = paddle.XPUPlace(0)
paddle.disable_static(place)
op = getattr(paddle, op_str)
x = paddle.to_tensor(x_np, dtype=x_np.dtype)
if not binary_op:
dygraph_result = op(x)
else:
y = paddle.to_tensor(y_np, dtype=y_np.dtype)
dygraph_result = op(x, y)
return dygraph_result
def np_data_generator(np_shape, dtype, *args, **kwargs):
if dtype == bool:
return np.random.choice(a=[True, False], size=np_shape).astype(bool)
else:
return np.random.randn(*np_shape).astype(dtype)
def test_xpu(unit_test, test_error=False):
for op_data in TEST_META_OP_DATA:
meta_data = dict(op_data)
np_op = getattr(np, meta_data['op_str'])
META_DATA = dict(TEST_META_SHAPE_DATA)
if test_error:
META_DATA = dict(TEST_META_WRONG_SHAPE_DATA)
for shape_data in META_DATA.values():
for data_type in SUPPORTED_DTYPES:
meta_data['x_np'] = np_data_generator(
shape_data['x_shape'], dtype=data_type)
meta_data['y_np'] = np_data_generator(
shape_data['y_shape'], dtype=data_type)
if meta_data['binary_op'] and test_error:
# catch C++ Exception
unit_test.assertRaises(BaseException, run_static_xpu,
**meta_data)
continue
static_result = run_static_xpu(**meta_data)
dygraph_result = run_dygraph_xpu(**meta_data)
if meta_data['binary_op']:
np_result = np_op(meta_data['x_np'], meta_data['y_np'])
else:
np_result = np_op(meta_data['x_np'])
unit_test.assertTrue((static_result == np_result).all())
unit_test.assertTrue((dygraph_result.numpy() == np_result).all(
))
def test_type_error(unit_test, type_str_map):
def check_type(op_str, x, y, binary_op):
op = getattr(paddle, op_str)
error_type = ValueError
if isinstance(x, np.ndarray):
x = paddle.to_tensor(x)
y = paddle.to_tensor(y)
error_type = BaseException
if binary_op:
if type_str_map['x'] != type_str_map['y']:
unit_test.assertRaises(error_type, op, x=x, y=y)
if not fluid._non_static_mode():
error_type = TypeError
unit_test.assertRaises(error_type, op, x=x, y=y, out=1)
else:
if not fluid._non_static_mode():
error_type = TypeError
unit_test.assertRaises(error_type, op, x=x, out=1)
place = paddle.XPUPlace(0)
for op_data in TEST_META_OP_DATA:
meta_data = dict(op_data)
binary_op = meta_data['binary_op']
paddle.disable_static(place)
x = np.random.choice(a=[0, 1], size=[10]).astype(type_str_map['x'])
y = np.random.choice(a=[0, 1], size=[10]).astype(type_str_map['y'])
check_type(meta_data['op_str'], x, y, binary_op)
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
name='x', shape=[10], dtype=type_str_map['x'])
y = paddle.static.data(
name='y', shape=[10], dtype=type_str_map['y'])
check_type(meta_data['op_str'], x, y, binary_op)
def type_map_factory():
return [{
'x': x_type,
'y': y_type
} for x_type in SUPPORTED_DTYPES for y_type in SUPPORTED_DTYPES]
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPU(unittest.TestCase):
def test(self):
test_xpu(self, True)
def test_error(self):
test_xpu(self, True)
def test_type_error(self):
type_map_list = type_map_factory()
for type_map in type_map_list:
test_type_error(self, type_map)
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
################## TEST OP: logical_and ##################
class XPUTestLogicalAnd(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'logical_and'
class XPUTestLogicalAndBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'logical_and'
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
y = np.random.randint(
self.low, self.high, self.y_shape, dtype=self.dtype)
out = np.logical_and(x, y)
self.attrs = {'use_xpu': True}
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
}
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class XPUTestLogicalAndCase1(XPUTestLogicalAndBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
support_types = get_xpu_op_support_types('logical_and')
for stype in support_types:
create_test_class(globals(), XPUTestLogicalAnd, stype)
################## TEST OP: logical_or ##################
class XPUTestLogicalOr(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'logical_or'
class XPUTestLogicalOrBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'logical_or'
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
y = np.random.randint(
self.low, self.high, self.y_shape, dtype=self.dtype)
out = np.logical_or(x, y)
self.attrs = {'use_xpu': True}
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
}
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class XPUTestLogicalOrCase1(XPUTestLogicalOrBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
support_types = get_xpu_op_support_types('logical_or')
for stype in support_types:
create_test_class(globals(), XPUTestLogicalOr, stype)
################## TEST OP: logical_xor ##################
class XPUTestLogicalXor(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'logical_xor'
class XPUTestLogicalXorBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'logical_xor'
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
y = np.random.randint(
self.low, self.high, self.y_shape, dtype=self.dtype)
out = np.logical_xor(x, y)
self.attrs = {'use_xpu': True}
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
}
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int64
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
class XPUTestLogicalXorCase1(XPUTestLogicalXorBase):
def init_case(self):
self.dtype = np.int32
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
support_types = get_xpu_op_support_types('logical_xor')
for stype in support_types:
create_test_class(globals(), XPUTestLogicalXor, stype)
################## TEST OP: LogicalNot ##################
class XPUTestLogicalNot(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'logical_not'
class XPUTestLogicalNotBase(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_case()
self.set_case()
def set_case(self):
self.op_type = 'logical_not'
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
out = np.logical_not(x)
self.attrs = {'use_xpu': True}
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def init_case(self):
self.dtype = np.int32
self.x_shape = [2, 3, 4, 5]
self.low = -100
self.high = 100
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
pass
support_types = get_xpu_op_support_types('logical_not')
for stype in support_types:
create_test_class(globals(), XPUTestLogicalNot, stype)
if __name__ == '__main__':
unittest.main()