未验证 提交 9a1855ff 编写于 作者: L Leo Guo 提交者: GitHub

Add index_select, index_select_grad, reduce_min kernel and their unittests for...

Add index_select, index_select_grad, reduce_min kernel and their unittests for kunlun. Add registers of index_select, index_select_grad, reduce_min, sqrt, sqrt_grad to xpu2_op_list.test=kunlun. (#46557)
上级 98deee29
......@@ -301,6 +301,10 @@ XPUOpMap& get_kl2_ops() {
{"huber_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"iou_similarity",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"index_select",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"instance_norm",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"instance_norm_grad",
......@@ -422,6 +426,7 @@ XPUOpMap& get_kl2_ops() {
{"reduce_mean_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_min", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_prod", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......@@ -510,6 +515,8 @@ XPUOpMap& get_kl2_ops() {
{"split",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"square_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"squeeze2_grad",
......
......@@ -42,7 +42,7 @@ PD_REGISTER_KERNEL(
min, GPU, ALL_LAYOUT, phi::MinKernel, float, double, int, int64_t) {}
#endif
#if defined(PADDLE_WITH_XPU_KP)
#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(min, KPS, ALL_LAYOUT, phi::MinKernel, float) {}
#endif
......@@ -50,3 +50,7 @@ PD_REGISTER_KERNEL(min, KPS, ALL_LAYOUT, phi::MinKernel, float) {}
PD_REGISTER_KERNEL(
min, OneDNN, ALL_LAYOUT, phi::MinKernel, float, phi::dtype::bfloat16) {}
#endif
#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(min, XPU, ALL_LAYOUT, phi::MinKernel, float) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_select_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename T, typename Context>
void IndexSelectKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
int dim,
DenseTensor* output) {
auto input_dim = x.dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
auto* in_data = x.data<T>();
std::vector<int> in_shape = phi::vectorize<int>(input_dim);
int index_len = output->dims()[dim];
T* out_data = ctx.template Alloc<T>(output);
int r = 0;
if (index_type == phi::DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
r = xpu::gather<T, int64_t>(ctx.x_context(),
in_data,
index_data,
out_data,
in_shape,
index_len,
dim);
} else {
const int* index_data = index.data<int>();
r = xpu::gather<T, int>(ctx.x_context(),
in_data,
index_data,
out_data,
in_shape,
index_len,
dim);
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
}
} // namespace phi
PD_REGISTER_KERNEL(index_select,
XPU,
ALL_LAYOUT,
phi::IndexSelectKernel,
float,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_min_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/xpu/reduce.h"
namespace phi {
template <typename T, typename Context>
void MinRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
int r = XPUReduce<Context, T>(dev_ctx,
x,
dims.GetData(),
keep_dim,
reduce_all,
out,
xpu::reduce_min<T>);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_min");
}
} // namespace phi
PD_REGISTER_KERNEL(min_raw, XPU, ALL_LAYOUT, phi::MinRawKernel, float) {}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
sys.path.append("..")
import numpy as np
import paddle
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestIndexSelect(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'index_select'
class TestXPUIndexSelectOp(XPUOpTest):
def setUp(self):
self.op_type = "index_select"
self.place = paddle.XPUPlace(0)
self.dtype = self.in_type
self.init_dtype_type()
index_np = np.random.randint(low=0,
high=self.x_shape[self.dim],
size=self.index_size).astype(
self.index_type)
x_np = np.random.random(self.x_shape).astype(self.dtype)
self.inputs = {'X': x_np, 'Index': index_np}
self.attrs = {'dim': self.dim}
outer_loop = np.prod(self.x_shape[:self.dim])
x_reshape = [outer_loop] + list(self.x_shape[self.dim:])
x_np_reshape = np.reshape(x_np, tuple(x_reshape))
out_list = []
for i in range(outer_loop):
for j in range(self.index_size):
out_list.append(x_np_reshape[i, index_np[j]])
self.out_shape = list(self.x_shape)
self.out_shape[self.dim] = self.index_size
self.out_shape = tuple(self.out_shape)
out = np.reshape(out_list, self.out_shape)
self.outputs = {'Out': out}
def init_dtype_type(self):
self.dim = 1
self.index_type = np.int64
self.x_shape = (100, 4, 5)
self.index_size = 100
def test_check_output(self):
if paddle.is_compiled_with_xpu():
self.check_output_with_place(self.place)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
self.check_grad_with_place(self.place, ['X'], 'Out')
class TestXPUIndexSelectOpCase2(TestXPUIndexSelectOp):
def init_dtype_type(self):
self.index_type = np.int32
self.dim = -2
self.x_shape = (10, 10, 4, 10)
self.index_size = 10
class TestIndexSelectAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]])
self.data_index = np.array([0, 1, 1]).astype('int32')
def test_index_select_api(self):
self.input_data()
# case 1:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 4])
index = fluid.layers.data(name='index',
shape=[3],
dtype='int32',
append_batch_size=False)
z = paddle.index_select(x, index, axis=1)
exe = fluid.Executor(fluid.XPUPlace(0))
res, = exe.run(feed={
'x': self.data_x,
'index': self.data_index
},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[1.0, 2.0, 2.0], [5.0, 6.0, 6.0],
[9.0, 10.0, 10.0]])
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
# case 2:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 4])
index = fluid.layers.data(name='index',
shape=[3],
dtype='int32',
append_batch_size=False)
z = paddle.index_select(x, index)
exe = fluid.Executor(fluid.XPUPlace(0))
res, = exe.run(feed={
'x': self.data_x,
'index': self.data_index
},
fetch_list=[z.name],
return_numpy=False)
expect_out = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
[5.0, 6.0, 7.0, 8.0]])
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
def test_dygraph_api(self):
self.input_data()
# case 1:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
index = fluid.dygraph.to_variable(self.data_index)
z = paddle.index_select(x, index)
np_z = z.numpy()
expect_out = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
[5.0, 6.0, 7.0, 8.0]])
np.testing.assert_allclose(expect_out, np_z, rtol=1e-05)
# case 2:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
index = fluid.dygraph.to_variable(self.data_index)
z = paddle.index_select(x, index, axis=1)
np_z = z.numpy()
expect_out = np.array([[1.0, 2.0, 2.0], [5.0, 6.0, 6.0],
[9.0, 10.0, 10.0]])
np.testing.assert_allclose(expect_out, np_z, rtol=1e-05)
support_types = get_xpu_op_support_types('index_select')
for stype in support_types:
create_test_class(globals(), XPUTestIndexSelect, stype)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册