未验证 提交 798b527c 编写于 作者: D duanyanhui 提交者: GitHub

[XPU] Add kernels for VITDET (#50992)

* add support of int64 add for xpu

* add transpose support for int64

* add randperm kernel

* fix randperm

* add distribute_fpn_proposal kernel

* fix comment

* add reduce_sum_int32
上级 a548e70c
......@@ -168,6 +168,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"distribute_fpn_proposals", XPUKernelSet({phi::DataType::FLOAT32})},
{"diagonal",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
......@@ -496,6 +497,16 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"prod_raw", XPUKernelSet({phi::DataType::FLOAT32})},
{"range", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT64})},
{"randperm_raw",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32,
phi::DataType::FLOAT64})},
{"randperm",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32,
phi::DataType::FLOAT64})},
{"reciprocal", XPUKernelSet({phi::DataType::FLOAT32})},
{"reciprocal_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......@@ -509,6 +520,7 @@ XPUOpMap& get_kl2_ops() {
{"reduce_sum",
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::INT8,
phi::DataType::FLOAT32})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32})},
......
......@@ -324,7 +324,8 @@ PD_REGISTER_KERNEL(
divide, XPU, ALL_LAYOUT, phi::DivideKernel, phi::dtype::float16, float) {}
PD_REGISTER_KERNEL(
add, XPU, ALL_LAYOUT, phi::AddKernel, phi::dtype::float16, float) {}
add, XPU, ALL_LAYOUT, phi::AddKernel, phi::dtype::float16, float, int64_t) {
}
PD_REGISTER_KERNEL(multiply,
XPU,
......
......@@ -35,7 +35,8 @@ inline std::vector<size_t> GetLodFromRoisNum(const Context& dev_ctx,
std::vector<size_t> rois_lod;
auto* rois_num_data = rois_num->data<int>();
DenseTensor cpu_tensor;
if (paddle::platform::is_gpu_place(rois_num->place())) {
if (paddle::platform::is_gpu_place(rois_num->place()) ||
paddle::platform::is_xpu_place(rois_num->place())) {
Copy<Context>(dev_ctx, *rois_num, phi::CPUPlace(), true, &cpu_tensor);
rois_num_data = cpu_tensor.data<int>();
}
......
......@@ -90,6 +90,7 @@ PD_REGISTER_KERNEL(sum,
float,
phi::dtype::float16,
int8_t,
int,
int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/distribute_fpn_proposals_kernel.h"
#include "paddle/phi/kernels/funcs/distribute_fpn_proposals_functor.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
static void Sort(const XPUContext& dev_ctx,
const DenseTensor& value,
DenseTensor* index_out) {
auto* value_data = value.data<T>();
auto place = dev_ctx.GetPlace();
auto cpu_place = phi::CPUPlace();
DenseTensor scores_slice_cpu;
scores_slice_cpu.Resize({value.numel()});
T* scores_slice_cpu_data = dev_ctx.template HostAlloc<T>(&scores_slice_cpu);
paddle::memory::Copy(cpu_place,
scores_slice_cpu_data,
place,
value_data,
sizeof(T) * value.numel());
// Sort index
DenseTensor index_t;
index_t.Resize({value.numel()});
int* index = dev_ctx.template HostAlloc<int>(&index_t);
for (int i = 0; i < value.numel(); ++i) {
index[i] = i;
}
auto compare = [scores_slice_cpu_data](const int64_t& i, const int64_t& j) {
return scores_slice_cpu_data[i] < scores_slice_cpu_data[j];
};
std::sort(index, index + value.numel(), compare);
index_out->Resize({index_t.numel()});
int* idx_out = dev_ctx.template Alloc<int>(index_out);
paddle::memory::Copy(
place, idx_out, cpu_place, index, sizeof(T) * index_t.numel());
}
template <typename T, typename Context>
void DistributeFpnProposalsKernel(
const Context& dev_ctx,
const DenseTensor& fpn_rois,
const paddle::optional<DenseTensor>& rois_num,
int min_level,
int max_level,
int refer_level,
int refer_scale,
bool pixel_offset,
std::vector<DenseTensor*> multi_fpn_rois,
std::vector<DenseTensor*> multi_level_rois_num,
DenseTensor* restore_index) {
const int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty
if (!rois_num.get_ptr()) {
PADDLE_ENFORCE_EQ(
fpn_rois.lod().size(),
1UL,
errors::InvalidArgument("DistributeFpnProposalsOp needs LoD"
"with one level"));
}
using XPUType = typename XPUTypeTrait<T>::Type;
std::vector<size_t> fpn_rois_lod;
if (rois_num.get_ptr()) {
fpn_rois_lod = funcs::GetLodFromRoisNum(dev_ctx, rois_num.get_ptr());
} else {
fpn_rois_lod = fpn_rois.lod().back();
}
int lod_size = fpn_rois_lod.size() - 1;
// the total num of roi
int roi_num = fpn_rois_lod[lod_size];
DenseTensor sub_lod_list;
sub_lod_list.Resize({num_level, lod_size});
int* sub_lod_list_data = dev_ctx.template Alloc<int>(&sub_lod_list);
phi::funcs::SetConstant<phi::XPUContext, int> set_zero;
set_zero(dev_ctx, &sub_lod_list, static_cast<int>(0));
DenseTensor target_lvls;
target_lvls.Resize({roi_num});
int* target_lvls_data = dev_ctx.template Alloc<int>(&target_lvls);
std::vector<int> rois_lod_vec(fpn_rois_lod.size(), 0);
for (size_t i = 0; i < fpn_rois_lod.size(); ++i) {
rois_lod_vec[i] = static_cast<int>(fpn_rois_lod[i]);
}
xpu::VectorParam<int> rois_lod = {
rois_lod_vec.data(), static_cast<int>(rois_lod_vec.size()), nullptr};
int r = xpu::distribute_fpn_proposals_helper<XPUType, int, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(fpn_rois.data<T>()),
rois_lod,
sub_lod_list_data,
target_lvls_data,
static_cast<int64_t>(min_level),
static_cast<int64_t>(max_level),
static_cast<int64_t>(refer_level),
static_cast<int64_t>(refer_scale),
pixel_offset);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "distribute_fpn_proposals_helper");
DenseTensor index_out_t;
Sort<int>(dev_ctx, target_lvls, &index_out_t);
Sort<int>(dev_ctx, index_out_t, restore_index);
restore_index->Resize({roi_num, 1});
int start = 0;
std::vector<int> sub_lod_list_cpu(lod_size * num_level);
phi::TensorToVector<int>(sub_lod_list, dev_ctx, &sub_lod_list_cpu);
for (int i = 0; i < num_level; ++i) {
DenseTensor sub_lod = sub_lod_list.Slice(i, i + 1);
// transfer length-based lod to offset-based lod
std::vector<size_t> offset(1, 0);
for (int j = 0; j < lod_size; ++j) {
offset.emplace_back(offset.back() + sub_lod_list_cpu[i * lod_size + j]);
}
int sub_rois_num = offset.back();
int end = start + sub_rois_num;
if (end > start) {
DenseTensor sub_idx = index_out_t.Slice(start, end);
start = end;
multi_fpn_rois[i]->Resize({sub_rois_num, funcs::kBoxDim});
dev_ctx.template Alloc<T>(multi_fpn_rois[i]);
std::vector<int> fpn_rois_shape(fpn_rois.dims().size());
for (int i = 0; i < fpn_rois.dims().size(); ++i) {
fpn_rois_shape[i] = fpn_rois.dims()[i];
}
int r1 = xpu::gather<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(fpn_rois.data<T>()),
sub_idx.data<int>(),
reinterpret_cast<XPUType*>(multi_fpn_rois[i]->data<T>()),
fpn_rois_shape,
sub_idx.numel(),
0);
PADDLE_ENFORCE_XDNN_SUCCESS(r1, "distribute_fpn_proposals_helper");
} else {
multi_fpn_rois[i]->Resize({sub_rois_num, funcs::kBoxDim});
dev_ctx.template Alloc<T>(multi_fpn_rois[i]);
}
if (multi_level_rois_num.size() > 0) {
DenseTensor* rois_num_t = multi_level_rois_num[i];
Copy(dev_ctx, sub_lod, dev_ctx.GetPlace(), true, rois_num_t);
rois_num_t->Resize({lod_size});
}
LoD lod;
lod.emplace_back(offset);
multi_fpn_rois[i]->set_lod(lod);
}
}
} // namespace phi
PD_REGISTER_KERNEL(distribute_fpn_proposals,
XPU,
ALL_LAYOUT,
phi::DistributeFpnProposalsKernel,
float) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/randperm_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void RandpermRawKernel(
const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
} else {
engine = dev_ctx.GetGenerator()->GetCPUEngine();
}
if (dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU) {
T* out_data = dev_ctx.template HostAlloc<T>(out);
for (int i = 0; i < n; ++i) {
out_data[i] = static_cast<T>(i);
}
std::shuffle(out_data, out_data + n, *engine);
} else {
dev_ctx.template Alloc<T>(out);
phi::DenseTensor tmp_tensor;
tmp_tensor.Resize(phi::make_ddim({n}));
T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp_tensor);
for (int i = 0; i < n; ++i) {
tmp_data[i] = static_cast<T>(i);
}
std::shuffle(tmp_data, tmp_data + n, *engine);
Copy(dev_ctx, tmp_tensor, dev_ctx.GetPlace(), true, out);
}
}
template <typename T, typename Context>
void RandpermKernel(const Context& dev_ctx,
int n,
DataType dtype,
DenseTensor* out) {
RandpermRawKernel<T, Context>(dev_ctx, n, dtype, 0, out);
}
} // namespace phi
PD_REGISTER_KERNEL(randperm_raw,
XPU,
ALL_LAYOUT,
phi::RandpermRawKernel,
int,
int64_t,
float,
double) {}
PD_REGISTER_KERNEL(randperm,
XPU,
ALL_LAYOUT,
phi::RandpermKernel,
int,
int64_t,
float,
double) {}
......@@ -57,4 +57,5 @@ PD_REGISTER_KERNEL(sum_raw,
float,
phi::dtype::float16,
int8_t,
int,
int64_t) {}
......@@ -2164,7 +2164,10 @@ class OpTest(unittest.TestCase):
else:
abs_a = 1 if abs_a < 1e-3 else abs_a
diff_mat = np.abs(a - b) / abs_a
if self.dtype == np.bool:
diff_mat = np.abs(a ^ b) / abs_a
else:
diff_mat = np.abs(a - b) / abs_a
max_diff = np.max(diff_mat)
def err_msg():
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test_xpu import XPUOpTest
import paddle
paddle.enable_static()
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
def distribute_fpn_proposals_wrapper(
fpn_rois,
rois_num,
min_level,
max_level,
refer_level,
refer_scale,
pixel_offset,
):
return paddle.vision.ops.distribute_fpn_proposals(
fpn_rois=fpn_rois,
min_level=min_level,
max_level=max_level,
refer_level=refer_level,
refer_scale=refer_scale,
rois_num=rois_num,
)
class XPUTestDistributeFPNProposalsOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'distribute_fpn_proposals'
self.use_dynamic_create_class = False
class TestDistributeFPNProposalsOp(XPUOpTest):
def setUp(self):
self.op_type = "distribute_fpn_proposals"
self.python_api = distribute_fpn_proposals_wrapper
self.python_out_sig = ['MultiFpnRois', 'RestoreIndex']
self.dtype = self.in_type
self.init_test_case()
self.make_rois()
self.rois_fpn, self.rois_idx_restore = self.calc_rois_distribute()
self.inputs = {'FpnRois': (self.rois[:, 1:5], self.rois_lod)}
self.attrs = {
'max_level': self.roi_max_level,
'min_level': self.roi_min_level,
'refer_scale': self.canonical_scale,
'refer_level': self.canonical_level,
'pixel_offset': self.pixel_offset,
}
output = [
('out%d' % i, self.rois_fpn[i])
for i in range(len(self.rois_fpn))
]
self.outputs = {
'MultiFpnRois': output,
'RestoreIndex': self.rois_idx_restore.reshape(-1, 1),
}
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0))
def init_test_case(self):
self.roi_max_level = 5
self.roi_min_level = 2
self.canonical_scale = 224
self.canonical_level = 4
self.images_shape = [512, 512]
self.pixel_offset = True
def boxes_area(self, boxes):
offset = 1 if self.pixel_offset else 0
w = boxes[:, 2] - boxes[:, 0] + offset
h = boxes[:, 3] - boxes[:, 1] + offset
areas = w * h
assert np.all(areas >= 0), 'Negative areas founds'
return areas
def map_rois_to_fpn_levels(self, rois, lvl_min, lvl_max):
s = np.sqrt(self.boxes_area(rois))
s0 = self.canonical_scale
lvl0 = self.canonical_level
target_lvls = np.floor(lvl0 + np.log2(s / s0 + 1e-8))
target_lvls = np.clip(target_lvls, lvl_min, lvl_max)
return target_lvls
def get_sub_lod(self, sub_lvl):
sub_lod = [0, 0]
max_batch_id = sub_lvl[-1]
for i in range(max_batch_id.astype(np.int32) + 1):
sub_lod[i] = np.where(sub_lvl == i)[0].size
return sub_lod
def add_multilevel_roi(self, rois, target_lvls, lvl_min, lvl_max):
rois_idx_order = np.empty((0,))
rois_fpn = []
for lvl in range(lvl_min, lvl_max + 1):
idx_lvl = np.where(target_lvls == lvl)[0]
if len(idx_lvl) == 0:
rois_fpn.append((np.empty(shape=(0, 4)), [[0, 0]]))
continue
sub_lod = self.get_sub_lod(rois[idx_lvl, 0])
rois_fpn.append((rois[idx_lvl, 1:], [sub_lod]))
rois_idx_order = np.concatenate((rois_idx_order, idx_lvl))
rois_idx_restore = np.argsort(rois_idx_order).astype(
np.int32, copy=False
)
return rois_fpn, rois_idx_restore
def calc_rois_distribute(self):
lvl_min = self.roi_min_level
lvl_max = self.roi_max_level
target_lvls = self.map_rois_to_fpn_levels(
self.rois[:, 1:5], lvl_min, lvl_max
)
rois_fpn, rois_idx_restore = self.add_multilevel_roi(
self.rois, target_lvls, lvl_min, lvl_max
)
return rois_fpn, rois_idx_restore
def make_rois(self):
self.rois_lod = [[10, 4]]
rois = []
lod = self.rois_lod[0]
bno = 0
for roi_num in lod:
for i in range(roi_num):
xywh = np.random.rand(4)
xy1 = xywh[0:2] * 20
wh = xywh[2:4] * (self.images_shape - xy1)
xy2 = xy1 + wh
roi = [bno, xy1[0], xy1[1], xy2[0], xy2[1]]
rois.append(roi)
bno += 1
self.rois = np.array(rois).astype("float32")
class TestDistributeFPNProposalsOpWithRoisNum(TestDistributeFPNProposalsOp):
def setUp(self):
self.op_type = "distribute_fpn_proposals"
self.python_api = distribute_fpn_proposals_wrapper
self.python_out_sig = [
'MultiFpnRois',
'MultiLevelRoIsNum',
'RestoreIndex',
]
self.dtype = self.in_type
self.init_test_case()
self.make_rois()
self.rois_fpn, self.rois_idx_restore = self.calc_rois_distribute()
self.inputs = {
'FpnRois': (self.rois[:, 1:5], self.rois_lod),
'RoisNum': np.array(self.rois_lod[0]).astype('int32'),
}
self.attrs = {
'max_level': self.roi_max_level,
'min_level': self.roi_min_level,
'refer_scale': self.canonical_scale,
'refer_level': self.canonical_level,
'pixel_offset': self.pixel_offset,
}
output = [
('out%d' % i, self.rois_fpn[i])
for i in range(len(self.rois_fpn))
]
rois_num_per_level = [
(
'rois_num%d' % i,
np.array(self.rois_fpn[i][1][0]).astype('int32'),
)
for i in range(len(self.rois_fpn))
]
self.outputs = {
'MultiFpnRois': output,
'RestoreIndex': self.rois_idx_restore.reshape(-1, 1),
'MultiLevelRoIsNum': rois_num_per_level,
}
class TestDistributeFPNProposalsOpNoOffset(
TestDistributeFPNProposalsOpWithRoisNum
):
def init_test_case(self):
self.roi_max_level = 5
self.roi_min_level = 2
self.canonical_scale = 224
self.canonical_level = 4
self.images_shape = [512, 512]
self.pixel_offset = False
support_types = get_xpu_op_support_types('distribute_fpn_proposals')
for stype in support_types:
create_test_class(globals(), XPUTestDistributeFPNProposalsOp, stype)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.static import Program, program_guard
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
paddle.enable_static()
def check_randperm_out(n, data_np):
assert isinstance(
data_np, np.ndarray
), "The input data_np should be np.ndarray."
gt_sorted = np.arange(n)
out_sorted = np.sort(data_np)
return list(gt_sorted == out_sorted)
def error_msg(data_np):
return (
"The sorted ground truth and sorted out should "
+ "be equal, out = "
+ str(data_np)
)
def convert_dtype(dtype_str):
dtype_str_list = [np.int32, np.int64, np.float32, np.float64]
dtype_num_list = [
core.VarDesc.VarType.INT32,
core.VarDesc.VarType.INT64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP64,
]
assert dtype_str in dtype_str_list, (
dtype_str + " should in " + str(dtype_str_list)
)
return dtype_num_list[dtype_str_list.index(dtype_str)]
class XPUTestRandpermOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = "randperm"
self.use_dynamic_create_class = False
class TestXPURandpermOp(XPUOpTest):
"""Test randperm op."""
def setUp(self):
self.init_op_type()
self.initTestCase()
self.dtype = self.in_type
self.use_xpu = True
self.use_mkldnn = False
self.inputs = {}
self.outputs = {"Out": np.zeros((self.n)).astype(self.dtype)}
self.attrs = {
"n": self.n,
"dtype": convert_dtype(self.dtype),
}
def init_op_type(self):
self.op_type = "randperm"
self.use_mkldnn = False
def initTestCase(self):
self.n = 200
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
out_np = np.array(outs[0])
self.assertTrue(
check_randperm_out(self.n, out_np), msg=error_msg(out_np)
)
class TestXPURandpermOpN(TestXPURandpermOp):
def initTestCase(self):
self.n = 10000
class TestRandpermImperative(unittest.TestCase):
def test_out(self):
paddle.disable_static()
n = 10
dtype = self.in_type
data_p = paddle.randperm(n, dtype)
data_np = data_p.numpy()
self.assertTrue(
check_randperm_out(n, data_np), msg=error_msg(data_np)
)
paddle.enable_static()
class TestRandpermEager(unittest.TestCase):
def test_out(self):
paddle.disable_static()
n = 10
dtype = self.in_type
data_p = paddle.randperm(n, dtype)
data_np = data_p.numpy()
self.assertTrue(
check_randperm_out(n, data_np), msg=error_msg(data_np)
)
paddle.enable_static()
support_types = get_xpu_op_support_types("randperm")
for stype in support_types:
create_test_class(globals(), XPUTestRandpermOp, stype)
class TestRandpermAPI(unittest.TestCase):
def test_out(self):
n = 10
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
else:
place = paddle.CPUPlace()
with program_guard(Program(), Program()):
x1 = paddle.randperm(n)
x2 = paddle.randperm(n, 'float32')
exe = paddle.static.Executor(place)
res = exe.run(fetch_list=[x1, x2])
self.assertEqual(res[0].dtype, np.int64)
self.assertEqual(res[1].dtype, np.float32)
self.assertTrue(check_randperm_out(n, res[0]))
self.assertTrue(check_randperm_out(n, res[1]))
if __name__ == "__main__":
unittest.main()
......@@ -19,119 +19,127 @@ import numpy as np
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
import paddle
class TestXPUTransposeOp(XPUOpTest):
def setUp(self):
self.init_op_type()
self.initTestCase()
self.use_xpu = True
self.use_mkldnn = False
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
self.attrs = {
'axis': list(self.axis),
'use_mkldnn': False,
'use_xpu': True,
}
self.outputs = {
'XShape': np.random.random(self.shape).astype("float32"),
'Out': self.inputs['X'].transpose(self.axis),
}
def init_op_type(self):
self.op_type = "transpose2"
self.use_mkldnn = False
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place=place, no_check_set=['XShape'])
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
def initTestCase(self):
self.shape = (3, 40)
self.axis = (1, 0)
class TestCase_ZeroDim(TestXPUTransposeOp):
def initTestCase(self):
self.shape = ()
self.axis = ()
class TestCase0(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (100,)
self.axis = (0,)
class TestCase1(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (3, 4, 10)
self.axis = (0, 2, 1)
class TestCase2(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5)
self.axis = (0, 2, 3, 1)
class TestCase3(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
self.axis = (4, 2, 3, 1, 0)
class TestCase4(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6, 1)
self.axis = (4, 2, 3, 1, 0, 5)
class TestCase5(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 16, 96)
self.axis = (0, 2, 1)
class TestCase6(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 10, 12, 16)
self.axis = (3, 1, 2, 0)
class TestCase7(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 10, 2, 16)
self.axis = (0, 1, 3, 2)
class TestCase8(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
self.axis = (0, 1, 3, 2, 4, 5, 6, 7)
class TestCase9(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
class TestCase10(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 2)
self.axis = (-1, 1, -3)
class XPUTestXPUTransposeOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'transpose'
self.use_dynamic_create_class = False
class TestXPUTransposeOp(XPUOpTest):
def setUp(self):
self.init_op_type()
self.init_type()
self.initTestCase()
self.use_xpu = True
self.use_mkldnn = False
self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)}
self.attrs = {
'axis': list(self.axis),
'use_mkldnn': False,
'use_xpu': True,
}
self.outputs = {
'XShape': np.random.random(self.shape).astype(self.dtype),
'Out': self.inputs['X'].transpose(self.axis),
}
def init_op_type(self):
self.op_type = "transpose2"
self.use_mkldnn = False
def init_type(self):
self.dtype = self.in_type
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(
place=place, no_check_set=['XShape']
)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
def initTestCase(self):
self.shape = (3, 40)
self.axis = (1, 0)
class TestCase_ZeroDim(TestXPUTransposeOp):
def initTestCase(self):
self.shape = ()
self.axis = ()
class TestCase0(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (100,)
self.axis = (0,)
class TestCase1(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (3, 4, 10)
self.axis = (0, 2, 1)
class TestCase2(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5)
self.axis = (0, 2, 3, 1)
class TestCase3(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
self.axis = (4, 2, 3, 1, 0)
class TestCase4(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6, 1)
self.axis = (4, 2, 3, 1, 0, 5)
class TestCase5(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 16, 96)
self.axis = (0, 2, 1)
class TestCase6(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 10, 12, 16)
self.axis = (3, 1, 2, 0)
class TestCase7(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 10, 2, 16)
self.axis = (0, 1, 3, 2)
class TestCase8(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
self.axis = (0, 1, 3, 2, 4, 5, 6, 7)
class TestCase9(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
class TestCase10(TestXPUTransposeOp):
def initTestCase(self):
self.shape = (200, 3, 2)
self.axis = (-1, 1, -3)
support_types = get_xpu_op_support_types('transpose')
for stype in support_types:
create_test_class(globals(), XPUTestXPUTransposeOp, stype)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册