未验证 提交 332a73b1 编写于 作者: L lijin23 提交者: GitHub

[XPU][PHI Kernels] add index_put kernel for xpu (#56169)

* add inverse kernel for xpu

* add more kernels

* add index_put kernel for xpu

* add index_put kernel for xpu

* remove unused headers

* refine test

* wait to avoid memory bugs for xpu

* refine inverse
上级 3e55f255
......@@ -325,6 +325,7 @@ XPUOpMap& get_kl2_ops() {
{"fill_any_like",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"fill_diagonal_tensor",
......@@ -449,6 +450,10 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"index_put",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"index_sample_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"index_sample",
XPUKernelSet({phi::DataType::INT8,
......@@ -465,6 +470,8 @@ XPUOpMap& get_kl2_ops() {
{"instance_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"instance_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"inverse",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
{"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})},
{"label_smooth", XPUKernelSet({phi::DataType::FLOAT32})},
{"lamb", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......@@ -687,6 +694,12 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64,
phi::DataType::FLOAT16,
phi::DataType::BOOL})},
{"set_value_with_tensor",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL})},
{"set_value_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
......@@ -966,7 +979,9 @@ XPUOpMap& get_kl2_ops() {
{"cos", XPUKernelSet({phi::DataType::FLOAT32})},
{"cos_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"linspace",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"group_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"group_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
......
......@@ -105,6 +105,9 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
}
SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices);
#ifdef PADDLE_WITH_XPU
dev_ctx.Wait();
#endif
} else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
(indices_v[i]->dtype() == phi::DataType::INT32)) {
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_put_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"
#include "paddle/phi/kernels/stack_kernel.h"
namespace phi {
// Merges a list of integer index tensors into one int64 tensor.
//
// Each entry of `int_indices_v` is converted to int64 when its dtype is INT32,
// expanded to `bd_dim` when its dims differ from `bd_dim`, and the results are
// stacked along a new last axis into `out`, which is resized to
// `bd_dim + [int_indices_v.size()]`.
template <typename Context>
void XPUDealWithIndices(const Context& dev_ctx,
                        const std::vector<const DenseTensor*>& int_indices_v,
                        DDim bd_dim,
                        DenseTensor* out) {
  const std::vector<int64_t> bd_shape = vectorize<int64_t>(bd_dim);

  std::vector<DenseTensor> broadcast_indices;
  for (const DenseTensor* index : int_indices_v) {
    // All indices are handled as int64 because they are merged into a single
    // tensor below.
    DenseTensor index_i64;
    if (index->dtype() == DataType::INT32) {
      index_i64 = phi::Cast<int, Context>(dev_ctx, *index, DataType::INT64);
    } else {
      index_i64 = *index;
    }

    DenseTensor index_bd(DataType::INT64);
    if (index_i64.dims() == bd_dim) {
      // Already at the broadcast shape; no expansion required.
      index_bd = index_i64;
    } else {
      index_bd.Resize(bd_dim);
      ExpandKernel<int64_t, Context>(
          dev_ctx, index_i64, IntArray(bd_shape), &index_bd);
    }
    broadcast_indices.emplace_back(index_bd);
  }

  // Output shape is the broadcast shape plus one trailing axis that holds one
  // entry per index tensor.
  std::vector<int64_t> stacked_shape(bd_shape);
  stacked_shape.push_back(int_indices_v.size());
  out->Resize(make_ddim(stacked_shape));

  // Pointers are collected only after `broadcast_indices` has stopped growing,
  // so they stay valid for the stack call.
  std::vector<const DenseTensor*> stack_inputs;
  for (const DenseTensor& tensor : broadcast_indices) {
    stack_inputs.push_back(&tensor);
  }

  StackKernel<int64_t, Context>(dev_ctx, stack_inputs, -1, out);
  // NOTE(review): presumably the wait keeps the local temporaries alive until
  // the queued device work has consumed them -- confirm against the XPU
  // runtime's execution model.
  dev_ctx.Wait();
}
// XPU implementation of index_put: writes (or, when `accumulate` is true,
// accumulates) `value` into `x` at the positions selected by `indices` and
// stores the result in `out`.
//
// Parameters:
//   dev_ctx    - device context used for allocation and for the xpu call.
//   x          - input tensor; at most 6 dimensions are accepted.
//   indices    - non-empty list of index tensors; they are first passed through
//                funcs::DealWithBoolIndices to obtain integer indices.
//   value      - values to put; must have the same dtype as `x`. It is expanded
//                to the broadcast index shape followed by the un-indexed
//                trailing dims of `x` when its shape differs.
//   accumulate - forwarded unchanged to xpu::index_put.
//   out        - output tensor.
//
// Raises InvalidArgument when the dtypes differ, `indices` is empty, `x` has
// more than 6 dims, or there are more integer indices than dims of `x`.
template <typename T, typename Context>
void IndexPutKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const std::vector<const DenseTensor*>& indices,
                    const DenseTensor& value,
                    bool accumulate,
                    DenseTensor* out) {
  PADDLE_ENFORCE_EQ(
      x.dtype(),
      value.dtype(),
      phi::errors::InvalidArgument(
          "The data type of tensor value must be same to the data type "
          "of tensor x."));
  PADDLE_ENFORCE_EQ(indices.empty(),
                    false,
                    phi::errors::InvalidArgument("Indices cannot be empty."));

  const int64_t total_dims = x.dims().size();
  PADDLE_ENFORCE_LE(
      total_dims,
      6,
      errors::InvalidArgument("Dims of input tensor should be less than 7."));

  // All bool indices are converted to integers currently
  std::vector<DenseTensor> tmp_args;
  std::vector<const DenseTensor*> int_indices_v =
      funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
  if (int_indices_v.empty()) {
    // No integer indices were produced, so nothing is written; `out` only
    // needs to carry the contents of `x` if it does not hold data yet.
    if (!out->initialized()) {
      phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
    }
    return;
  }

  // Reject more indices than dimensions up front. Without this check the shape
  // arithmetic below would wrap around (unsigned subtraction) and the iterator
  // `x_shape.begin() + num_indices` would point past the end of `x_shape`.
  const int64_t num_indices = static_cast<int64_t>(int_indices_v.size());
  PADDLE_ENFORCE_LE(
      num_indices,
      total_dims,
      errors::InvalidArgument(
          "The number of integer indices (%d) should not be greater than the "
          "number of dimensions of input tensor (%d).",
          num_indices,
          total_dims));

  using XPUT = typename XPUTypeTrait<T>::Type;
  auto out_data = dev_ctx.template Alloc<T>(out);
  auto bd_dims = funcs::BroadCastTensorsDims(int_indices_v);
  DenseTensor res_indices(DataType::INT64);
  // Broadcast and merge indices
  XPUDealWithIndices<Context>(dev_ctx, int_indices_v, bd_dims, &res_indices);
  auto index_shape = vectorize<int64_t>(res_indices.dims());
  auto x_shape = vectorize<int64_t>(x.dims());
  const T* value_data = value.data<T>();

  // Broadcast value: target shape is the merged index shape without its last
  // axis, followed by the dims of `x` that are not addressed by an index.
  auto value_shape = vectorize<int64_t>(value.dims());
  const int64_t value_rank =
      static_cast<int64_t>(bd_dims.size()) + (total_dims - num_indices);
  std::vector<int64_t> value_shape_bd(value_rank);
  std::copy(index_shape.begin(), index_shape.end() - 1, value_shape_bd.begin());
  std::copy(x_shape.begin() + num_indices,
            x_shape.end(),
            value_shape_bd.begin() + index_shape.size() - 1);
  DenseTensor value_bd(value.dtype());
  if (value_shape != value_shape_bd) {
    value_bd.Resize(make_ddim(value_shape_bd));
    ExpandKernel<T, Context>(
        dev_ctx, value, IntArray(value_shape_bd), &value_bd);
    value_data = value_bd.data<T>();
  }

  int r =
      xpu::index_put<XPUT, int64_t>(dev_ctx.x_context(),
                                    reinterpret_cast<const XPUT*>(x.data<T>()),
                                    reinterpret_cast<const XPUT*>(value_data),
                                    res_indices.data<int64_t>(),
                                    reinterpret_cast<XPUT*>(out_data),
                                    x_shape,
                                    index_shape,
                                    accumulate);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_put");
  // NOTE(review): presumably the wait keeps `res_indices`, `value_bd` and
  // `tmp_args` alive until the device has finished with them -- confirm.
  dev_ctx.Wait();
}
} // namespace phi
// Registers phi::IndexPutKernel under the name `index_put` for the XPU backend
// (ALL_LAYOUT) with element types float, int and int64_t.
PD_REGISTER_KERNEL(
    index_put, XPU, ALL_LAYOUT, phi::IndexPutKernel, float, int, int64_t) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/inverse_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
// XPU implementation of matrix inverse.
//
// `x` must have at least 2 dimensions. The size of the last dimension is used
// as the matrix order `n`; for inputs with more than 2 dimensions the leading
// dimensions are flattened into a batch of `x.numel() / (n * n)` matrices. A
// single matrix may occupy at most 8192 bytes. The result is written to `out`.
//
// Raises InvalidArgument when `x` has fewer than 2 dims or a single matrix
// exceeds the size limit; a non-success status from xpu::inverse is reported
// through PADDLE_ENFORCE_XDNN_SUCCESS.
template <typename T, typename Context>
void InverseKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   DenseTensor* out) {
  using XPUT = typename XPUTypeTrait<T>::Type;
  auto out_data = dev_ctx.template Alloc<T>(out);

  int64_t x_dims_len = x.dims().size();
  PADDLE_ENFORCE_GT(
      x_dims_len,
      1,
      phi::errors::InvalidArgument(
          "Dimensions of input should be greater than 1, but got %d.",
          x_dims_len));

  // Nothing to invert for an empty input. Returning here also avoids the
  // integer division by zero below when the matrix order is 0.
  if (x.numel() == 0) {
    return;
  }

  int64_t n = x.dims()[x_dims_len - 1];
  int64_t batch = x_dims_len > 2 ? x.numel() / (n * n) : 1;
  PADDLE_ENFORCE_LE(n * n * sizeof(T),
                    8192,
                    phi::errors::InvalidArgument(
                        "The size of a single matrix (%d bytes) exceeds the "
                        "maxinum numbers of bytes xpu supports (8192).",
                        n * n * sizeof(T)));

  // Scratch buffer with one int per matrix, passed to xpu::inverse as its
  // info argument; it is owned by the guard and not read back here.
  auto RAII_GUARD = xpu::ctx_guard(dev_ctx.x_context());
  auto* info_xpu = RAII_GUARD.alloc_l3_or_gm<int>(batch);
  // Xpu inverse api has check for singularity itself.
  int r = xpu::inverse<XPUT>(dev_ctx.x_context(),
                             reinterpret_cast<const XPUT*>(x.data<T>()),
                             reinterpret_cast<XPUT*>(out_data),
                             info_xpu,
                             batch,
                             n);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "inverse");
}
} // namespace phi
// Registers phi::InverseKernel under the name `inverse` for the XPU backend
// (ALL_LAYOUT) with element types float and double.
PD_REGISTER_KERNEL(
    inverse, XPU, ALL_LAYOUT, phi::InverseKernel, float, double) {}
......@@ -81,7 +81,7 @@ void LinspaceKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(
linspace, XPU, ALL_LAYOUT, phi::LinspaceKernel, float, int32_t) {
linspace, XPU, ALL_LAYOUT, phi::LinspaceKernel, float, int32_t, int64_t) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
......
......@@ -429,4 +429,5 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
float,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
bool) {}
......@@ -34,7 +34,8 @@ void SplitKernel(const Context& dev_ctx,
for (size_t j = 0; j < outs.size(); ++j) {
dev_ctx.template Alloc<T>(outs[j]);
out_ptrs.push_back(reinterpret_cast<XPUType*>(outs[j]->data<T>()));
split_lists.push_back(outs[j]->dims()[axis]);
split_lists.push_back(axis < outs[j]->dims().size() ? outs[j]->dims()[axis]
: 1);
}
if (x.numel() == 0) {
return;
......
......@@ -75,6 +75,8 @@ class XPUTestFillAnyLikeOp(XPUOpTestWrapper):
def set_value(self):
if self.dtype == "float16":
self.value = 0.05
elif self.dtype == np.bool_:
self.value = 1.0
else:
self.value = 5.0
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import numpy as np
from get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
from op_test_xpu import XPUOpTest
import paddle
paddle.enable_static()
def compute_index_put_ref(x_np, indices_np, value_np, accumulate=False):
    """NumPy reference for index_put.

    Assigns ``value_np`` to ``x_np[indices_np]``, or adds it to the existing
    entries when ``accumulate`` is true. ``x_np`` is modified in place and the
    same array object is returned.
    """
    if not accumulate:
        x_np[indices_np] = value_np
    else:
        x_np[indices_np] += value_np
    return x_np
def has_duplicate_index(indices, shapes):
    """Return True when the broadcast index arrays address a position twice.

    Every array in ``indices`` is broadcast to the common shape of ``shapes``
    and flattened; the i-th elements of all arrays together form one
    coordinate. The result is True if any coordinate occurs more than once.
    """
    common_shape = np.broadcast_shapes(*shapes)
    flat_columns = [
        list(np.broadcast_to(index, common_shape).flatten())
        for index in indices
    ]
    coordinates = list(zip(*flat_columns))
    return len(coordinates) != len(set(coordinates))
def gen_indices_np(x_shape, indices_shapes, index_type, is_all_false):
    """Generate index arrays for the index_put tests.

    Args:
        x_shape: shape of the tensor being indexed; ``x_shape[i]`` is the
            exclusive upper bound for the values of the i-th integer index.
        indices_shapes: one shape per index array. For bool indices only the
            first shape is used.
        index_type: NumPy dtype of the indices (``np.bool_`` or an integer
            dtype).
        is_all_false: for bool indices, leave the mask entirely False.

    Returns:
        A tuple of NumPy arrays. For ``np.bool_`` it holds a single mask of
        shape ``indices_shapes[0]`` in which every even position along the
        first axis is True (unless ``is_all_false``). For integer dtypes it
        holds one random array per shape, redrawn until
        ``has_duplicate_index`` reports no duplicated coordinate.
    """
    if index_type == np.bool_:
        mask = np.zeros(indices_shapes[0], dtype=np.bool_)
        if not is_all_false:
            # Selects alternate entries along the FIRST axis (whole rows for a
            # multi-dimensional mask), not alternate flattened elements.
            mask[::2] = True
        return (mask,)

    while True:
        indices = []
        for i in range(len(indices_shapes)):
            # Called without an argument, so the global generator is reseeded
            # before each draw rather than following a fixed seed.
            np.random.seed()
            indices.append(
                np.random.randint(
                    low=0,
                    high=x_shape[i],
                    size=indices_shapes[i],
                    dtype=index_type,
                )
            )
        # Duplicated coordinates are rejected; redraw until all are unique.
        if not has_duplicate_index(indices, indices_shapes):
            return tuple(indices)
class XPUTestIndexPut(XPUOpTestWrapper):
    """Wrapper that groups the index_put op tests for XPU."""

    def __init__(self):
        self.op_name = "index_put"
        self.use_dynamic_create_class = False

    class TestXPUIndexPutOp(XPUOpTest):
        """Base case; subclasses override ``set_case`` to vary the setup."""

        def setUp(self):
            self.op_type = "index_put"
            self.x_dtype = self.in_type
            # Defaults that individual cases may override in set_case().
            self.mixed_indices = False
            self.is_all_false = False
            self.place = paddle.XPUPlace(0)
            self.set_case()
            self.init_data()

        def set_case(self):
            self.x_shape = (100, 110)
            self.indices_shapes = [(21,), (21,)]
            self.value_shape = (21,)
            self.index_dtype = np.int64
            self.accumulate = False

        def init_data(self):
            """Build inputs, attrs and the NumPy reference output."""
            # Random data centred on zero, scaled by 10.
            x_np = (np.random.random(self.x_shape) - 0.5) * 10.0
            x_np = x_np.astype(self.x_dtype)
            value_np = (np.random.random(self.value_shape) - 0.5) * 10.0
            value_np = value_np.astype(self.x_dtype)

            first_indices = gen_indices_np(
                self.x_shape,
                self.indices_shapes,
                self.index_dtype,
                self.is_all_false,
            )
            if self.mixed_indices:
                # Append a second group of indices with its own dtype/shapes.
                second_indices = gen_indices_np(
                    self.x_shape,
                    self.indices_shapes1,
                    self.index_dtype1,
                    self.is_all_false,
                )
                self.indices_np = tuple(
                    list(first_indices) + list(second_indices)
                )
            else:
                self.indices_np = first_indices

            named_indices = list(
                zip(self.get_indices_names(), self.indices_np)
            )
            self.inputs = {
                'x': x_np,
                'indices': named_indices,
                'value': value_np,
            }
            self.attrs = {'accumulate': self.accumulate}

            if self.is_all_false:
                # An all-False mask selects nothing, so x is the expectation.
                expected = x_np
            else:
                expected = compute_index_put_ref(
                    copy.deepcopy(x_np),
                    self.indices_np,
                    value_np,
                    self.accumulate,
                )
            self.outputs = {'out': expected}

        def get_indices_names(self):
            """Return one name ``index_<i>`` per generated index array."""
            return [f"index_{i}" for i in range(len(self.indices_np))]

        def test_check_output(self):
            self.check_output_with_place(self.place)

        def test_check_grad(self):
            self.check_grad_with_place(self.place, ["x", "value"], "out")

    # 4-D x, four int64 indices with broadcasting, plain assignment.
    class TestXPUIndexPut1(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 42, 56, 56)
            self.indices_shapes = [(16, 16), (16, 16), (1, 16), (1, 16)]
            self.value_shape = (16, 16)
            self.index_dtype = np.int64
            self.accumulate = False

    # Same as TestXPUIndexPut1 with accumulate enabled.
    class TestXPUIndexPut2(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 42, 56, 56)
            self.indices_shapes = [(16, 16), (16, 16), (1, 16), (1, 16)]
            self.value_shape = (16, 16)
            self.index_dtype = np.int64
            self.accumulate = True

    # Bool mask with the same shape as x.
    class TestXPUIndexPut3(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 94)
            self.indices_shapes = [(110, 94)]
            self.value_shape = (5170,)
            self.index_dtype = np.bool_
            self.accumulate = False

    # Same as TestXPUIndexPut3 with accumulate enabled.
    class TestXPUIndexPut4(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 94)
            self.indices_shapes = [(110, 94)]
            self.value_shape = (5170,)
            self.index_dtype = np.bool_
            self.accumulate = True

    # int32 indices covering the first three dims of a 4-D x.
    class TestXPUIndexPut5(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 42, 56, 56)
            self.indices_shapes = ((16, 16), (16, 16), (1, 16))
            self.value_shape = (16, 16, 56)
            self.index_dtype = np.int32
            self.accumulate = False

    # Same as TestXPUIndexPut5 with accumulate enabled.
    class TestXPUIndexPut6(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 42, 56, 56)
            self.indices_shapes = ((16, 16), (16, 16), (1, 16))
            self.value_shape = (16, 16, 56)
            self.index_dtype = np.int32
            self.accumulate = True

    # 1-D bool mask over the first dim that is entirely False.
    class TestXPUIndexPut7(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 94)
            self.indices_shapes = [(110,)]
            self.value_shape = (55, 94)
            self.index_dtype = np.bool_
            self.accumulate = False
            self.is_all_false = True

    # 1-D bool mask over the first dim, accumulate enabled.
    class TestXPUIndexPut8(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 94)
            self.indices_shapes = [(110,)]
            self.value_shape = (55, 94)
            self.index_dtype = np.bool_
            self.accumulate = True

    # int64 indices with a value of shape (56,) that must be broadcast.
    class TestXPUIndexPut9(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 42, 56, 56)
            self.indices_shapes = ((16, 16), (16, 16), (1, 16))
            self.value_shape = (56,)
            self.index_dtype = np.int64
            self.accumulate = False

    # Same as TestXPUIndexPut9 with accumulate enabled.
    class TestXPUIndexPut10(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 42, 56, 56)
            self.indices_shapes = ((16, 16), (16, 16), (1, 16))
            self.value_shape = (56,)
            self.index_dtype = np.int64
            self.accumulate = True

    # int64 indices with a single-element value.
    class TestXPUIndexPut11(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 42, 56, 56)
            self.indices_shapes = ((16, 16), (16, 16), (1, 16))
            self.value_shape = (1,)
            self.index_dtype = np.int64
            self.accumulate = False

    # Same as TestXPUIndexPut11 with accumulate enabled.
    class TestXPUIndexPut12(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 42, 56, 56)
            self.indices_shapes = ((16, 16), (16, 16), (1, 16))
            self.value_shape = (1,)
            self.index_dtype = np.int64
            self.accumulate = True

    # 1-D bool mask with a value of shape (94,).
    class TestXPUIndexPut13(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (44, 94)
            self.indices_shapes = [(44,)]
            self.value_shape = (94,)
            self.index_dtype = np.bool_
            self.accumulate = False

    # Same as TestXPUIndexPut13 with accumulate enabled.
    class TestXPUIndexPut14(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (44, 94)
            self.indices_shapes = [(44,)]
            self.value_shape = (94,)
            self.index_dtype = np.bool_
            self.accumulate = True

    # 1-D bool mask with a single-element value.
    class TestXPUIndexPut15(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (44, 94)
            self.indices_shapes = [(44,)]
            self.value_shape = (1,)
            self.index_dtype = np.bool_
            self.accumulate = False

    # Same as TestXPUIndexPut15 with accumulate enabled.
    class TestXPUIndexPut16(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (44, 94)
            self.indices_shapes = [(44,)]
            self.value_shape = (1,)
            self.index_dtype = np.bool_
            self.accumulate = True

    # Base-case shapes with int32 indices.
    class TestXPUIndexPut17(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (100, 110)
            self.indices_shapes = [(21,), (21,)]
            self.value_shape = (21,)
            self.index_dtype = np.int32
            self.accumulate = False

    # Same as TestXPUIndexPut17 with accumulate enabled.
    class TestXPUIndexPut18(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (100, 110)
            self.indices_shapes = [(21,), (21,)]
            self.value_shape = (21,)
            self.index_dtype = np.int32
            self.accumulate = True

    # Two int32 indices followed by one bool mask.
    class TestXPUIndexPutMixedIndices(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 42, 32, 56)
            self.indices_shapes = ((16, 16), (16, 16))
            self.value_shape = (16, 16, 56)
            self.index_dtype = np.int32
            self.accumulate = False
            self.mixed_indices = True
            self.indices_shapes1 = [(32,)]
            self.index_dtype1 = np.bool_

    # Same as TestXPUIndexPutMixedIndices with accumulate enabled.
    class TestXPUIndexPutMixedIndices1(TestXPUIndexPutOp):
        def set_case(self):
            self.x_shape = (110, 42, 32, 56)
            self.indices_shapes = ((16, 16), (16, 16))
            self.value_shape = (16, 16, 56)
            self.index_dtype = np.int32
            self.accumulate = True
            self.mixed_indices = True
            self.indices_shapes1 = [(32,)]
            self.index_dtype1 = np.bool_
# Call create_test_class once for every dtype that get_xpu_op_support_types
# reports for "index_put", passing this module's globals and the wrapper class.
supported_type = get_xpu_op_support_types("index_put")
for stype in supported_type:
    create_test_class(globals(), XPUTestIndexPut, stype)
class TestIndexPutInplaceAPI(unittest.TestCase):
    """Checks ``paddle.index_put_`` in dygraph mode against the NumPy reference.

    Subclasses override ``init_dtype_type`` to change dtypes, shapes or the
    ``accumulate`` flag.
    """

    def setUp(self):
        self.init_dtype_type()
        self.setPlace()
        self.x_np = np.random.random(self.x_shape).astype(self.dtype_np)
        self.value_np = np.random.random(self.value_shape).astype(
            self.dtype_np
        )
        self.indices_np = gen_indices_np(
            self.x_shape, self.indices_shapes, self.index_type_np, False
        )

    def init_dtype_type(self):
        self.dtype_np = np.float32
        self.index_type_np = np.int64
        self.x_shape = (100, 110)
        self.indices_shapes = [(21,), (21,)]
        self.value_shape = (21,)
        self.dtype_pd = paddle.float32
        self.index_type_pd = paddle.int64
        self.accumulate = False

    def setPlace(self):
        self.place = ['xpu']

    def test_dygraph_forward(self):
        paddle.disable_static()
        # Restore static mode even when an assertion fails, so a failure here
        # does not leave later tests running in dygraph mode.
        try:
            for place in self.place:
                paddle.device.set_device(place)
                self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd)
                self.value_pd = paddle.to_tensor(
                    self.value_np, dtype=self.dtype_pd
                )
                self.indices_pd = tuple(
                    paddle.to_tensor(indice, dtype=self.index_type_pd)
                    for indice in self.indices_np
                )
                # compute_index_put_ref writes into its first argument, so it
                # gets a copy and the self.x_np fixture stays untouched.
                ref_res = compute_index_put_ref(
                    self.x_np.copy(),
                    self.indices_np,
                    self.value_np,
                    self.accumulate,
                )
                x_pd_bk = self.x_pd.clone()
                pd_res = paddle.index_put_(
                    x_pd_bk, self.indices_pd, self.value_pd, self.accumulate
                )
                # Both the returned tensor and the in-place updated input must
                # match the reference.
                np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7)
                np.testing.assert_allclose(
                    ref_res, x_pd_bk.numpy(), atol=1e-7
                )
        finally:
            paddle.enable_static()
class TestIndexPutInplaceAPI1(TestIndexPutInplaceAPI):
    """Same configuration as the base in-place API test, with accumulate on."""

    def init_dtype_type(self):
        # Shapes
        self.x_shape = (100, 110)
        self.indices_shapes = [(21,), (21,)]
        self.value_shape = (21,)
        # NumPy / Paddle dtypes for data and indices
        self.dtype_np = np.float32
        self.dtype_pd = paddle.float32
        self.index_type_np = np.int64
        self.index_type_pd = paddle.int64
        self.accumulate = True
# Run this module's tests when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
from op_test_xpu import XPUOpTest
import paddle
from paddle import fluid
paddle.enable_static()
class XPUTestInverseOp(XPUOpTestWrapper):
    """Wrapper that groups the inverse op tests for XPU."""

    def __init__(self):
        self.op_name = "inverse"
        self.use_dynamic_create_class = False

    class TestXPUInverseOp(XPUOpTest):
        """Inverts a random 10x10 matrix and compares with np.linalg.inv."""

        def setUp(self):
            self.op_type = "inverse"
            self.place = paddle.XPUPlace(0)
            self.set_dtype()
            self.set_shape()
            self.init_input_output()

        def set_shape(self):
            self.input_shape = [10, 10]

        def init_input_output(self):
            # Fixed seed so every run inverts the same matrix.
            np.random.seed(123)
            matrix = np.random.random(self.input_shape).astype(self.dtype)
            expected = np.linalg.inv(matrix).astype(self.dtype)
            self.inputs = {"Input": matrix}
            self.outputs = {"Output": expected}

        def set_dtype(self):
            self.dtype = self.in_type

        def test_check_output(self):
            self.check_output_with_place(self.place)

        def test_check_grad(self):
            self.check_grad_with_place(self.place, ['Input'], 'Output')

    # Batch of eight 4x4 matrices.
    class TestXPUInverseOpBatched(TestXPUInverseOp):
        def set_shape(self):
            self.input_shape = [8, 4, 4]

    # Single 32x32 matrix.
    class TestXPUInverseOpLarge(TestXPUInverseOp):
        def set_shape(self):
            self.input_shape = [32, 32]
# Call create_test_class once for every dtype that get_xpu_op_support_types
# reports for "inverse", passing this module's globals and the wrapper class.
support_types = get_xpu_op_support_types("inverse")
for stype in support_types:
    create_test_class(globals(), XPUTestInverseOp, stype)
class TestInverseSingularAPI(unittest.TestCase):
    """Expects ``paddle.inverse`` to raise OSError for an all-ones 4x4 input."""

    def setUp(self):
        self.places = [fluid.XPUPlace(0)]

    def check_static_result(self, place):
        main_prog = fluid.Program()
        startup_prog = fluid.Program()
        with fluid.program_guard(main_prog, startup_prog):
            data = paddle.static.data(
                name="input", shape=[4, 4], dtype="float32"
            )
            result = paddle.inverse(x=data)

            # The all-ones matrix is singular.
            singular_np = np.ones([4, 4]).astype("float32")

            exe = fluid.Executor(place)
            with self.assertRaises(OSError):
                fetches = exe.run(
                    fluid.default_main_program(),
                    feed={"input": singular_np},
                    fetch_list=[result],
                )

    def test_static(self):
        for place in self.places:
            self.check_static_result(place=place)

    def test_dygraph(self):
        for place in self.places:
            with fluid.dygraph.guard(place):
                singular_np = np.ones([4, 4]).astype("float32")
                tensor = fluid.dygraph.to_variable(singular_np)
                with self.assertRaises(OSError):
                    result = paddle.inverse(tensor)
# Run this module's tests when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册