Unverified commit 47306c58 authored by houj04, committed by GitHub

[XPU] add fp16 support for tril_triu. add index_sample op. (#50655)

Parent 507af1c8
@@ -7,7 +7,7 @@ set(XPU_PROJECT "extern_xpu")
 set(XPU_API_LIB_NAME "libxpuapi.so")
 set(XPU_RT_LIB_NAME "libxpurt.so")
-set(XPU_BASE_DATE "20230215")
+set(XPU_BASE_DATE "20230220")
 set(XPU_XCCL_BASE_VERSION "1.0.8")
 if(NOT DEFINED XPU_BASE_URL)
...
@@ -350,6 +350,14 @@ XPUOpMap& get_kl2_ops() {
       {"kldiv_loss", XPUKernelSet({phi::DataType::FLOAT32})},
       {"kldiv_loss_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"index_sample",
+       XPUKernelSet({phi::DataType::INT8,
+                     phi::DataType::INT16,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT16,
+                     phi::DataType::FLOAT32,
+                     phi::DataType::BOOL})},
       {"index_select",
        XPUKernelSet({phi::DataType::FLOAT32,
                      phi::DataType::INT32,
...
@@ -642,15 +650,29 @@ XPUOpMap& get_kl2_ops() {
       {"temporal_shift", XPUKernelSet({phi::DataType::FLOAT32})},
       {"temporal_shift_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"tril_triu",
-       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
-      {"tril", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
-      {"triu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT16})},
+      {"tril",
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT16})},
+      {"triu",
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT16})},
       {"tril_triu_grad",
-       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT16})},
       {"tril_grad",
-       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT16})},
       {"triu_grad",
-       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
+       XPUKernelSet({phi::DataType::FLOAT32,
+                     phi::DataType::INT32,
+                     phi::DataType::FLOAT16})},
       {"tile",
        XPUKernelSet({phi::DataType::INT32,
                      phi::DataType::INT64,
...
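The dtype sets above are what the XPU test harness consults when it generates per-dtype test classes. A minimal sketch of inspecting them from Python (a hypothetical check script; it assumes it runs from the XPU unit-test directory, like the test files below, so the `xpu.get_test_cover_info` helper is importable):

```python
import sys

sys.path.append("..")  # assumes the XPU test directory layout used below
from xpu.get_test_cover_info import get_xpu_op_support_types

# After this change, 'index_sample' should report its six registered dtypes
# and the tril/triu family should include float16; the exact strings depend
# on the helper's dtype-to-name mapping.
for op_name in ('index_sample', 'tril_triu', 'tril', 'triu'):
    print(op_name, get_xpu_op_support_types(op_name))
```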
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/index_sample_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void IndexSampleKernel(const Context& ctx,
                       const DenseTensor& x,
                       const DenseTensor& index,
                       DenseTensor* out) {
  auto index_type = index.dtype();
  bool index_type_match =
      index_type == DataType::INT32 || index_type == DataType::INT64;
  PADDLE_ENFORCE_EQ(index_type_match,
                    true,
                    errors::InvalidArgument(
                        "Input(Index) holds the wrong type, it holds %s, but "
                        "desires to be %s or %s",
                        phi::DataTypeToString(index_type),
                        phi::DataTypeToString(DataType::INT32),
                        phi::DataTypeToString(DataType::INT64)));

  using XPUType = typename XPUTypeTrait<T>::Type;
  auto input_dim = x.dims();
  auto index_dim = index.dims();
  int64_t batch_size = input_dim[0];
  int64_t input_length = input_dim[1];
  int64_t index_length = index_dim[1];
  const T* in_data = x.data<T>();
  T* out_data = ctx.template Alloc<T>(out);

  // Signature of the underlying xdnn primitive:
  // template <typename T, typename TID>
  // DLL_EXPORT int gather_element(Context* ctx, const T* x, const TID* index,
  //                               T* y, const std::vector<int64_t>& xshape,
  //                               const std::vector<int64_t>& idxshape,
  //                               int64_t axis);
  if (index_type == DataType::INT64) {
    const int64_t* index_data = index.data<int64_t>();
    int r = xpu::gather_element<XPUType, int64_t>(
        ctx.x_context(),
        reinterpret_cast<const XPUType*>(in_data),
        index_data,
        reinterpret_cast<XPUType*>(out_data),
        {batch_size, input_length},
        {batch_size, index_length},
        1);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather_element");
  } else if (index_type == DataType::INT32) {
    const int* index_data = index.data<int>();
    int r = xpu::gather_element<XPUType, int32_t>(
        ctx.x_context(),
        reinterpret_cast<const XPUType*>(in_data),
        index_data,
        reinterpret_cast<XPUType*>(out_data),
        {batch_size, input_length},
        {batch_size, index_length},
        1);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather_element");
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(index_sample,
                   XPU,
                   ALL_LAYOUT,
                   phi::IndexSampleKernel,
                   phi::dtype::float16,
                   float,
                   int8_t,
                   int16_t,
                   int32_t,
                   bool) {}
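With the kernel registered, `paddle.index_sample` runs on XPU for the listed dtypes; the semantics match the `gather_element` call above, i.e. `out[i][j] = x[i][index[i][j]]` along axis 1. A minimal dygraph sketch, assuming an XPU device 0 is available:

```python
import numpy as np
import paddle

paddle.set_device('xpu:0')  # assumes an XPU device is present

# float16 input with int32 row-wise indices
x = paddle.to_tensor(np.arange(8).reshape(2, 4), dtype='float16')
index = paddle.to_tensor([[0, 3], [1, 1]], dtype='int32')

out = paddle.index_sample(x, index)  # out[i][j] = x[i][index[i][j]]
print(out.numpy())  # [[0., 3.], [5., 5.]]
```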
@@ -64,9 +64,24 @@ void TriuGradKernel(const Context& ctx,
 }  // namespace phi

-PD_REGISTER_KERNEL(
-    tril_grad, XPU, ALL_LAYOUT, phi::TrilGradKernel, int, float) {}
-PD_REGISTER_KERNEL(
-    triu_grad, XPU, ALL_LAYOUT, phi::TriuGradKernel, int, float) {}
-PD_REGISTER_KERNEL(
-    tril_triu_grad, XPU, ALL_LAYOUT, phi::TrilTriuGradKernel, int, float) {}
+PD_REGISTER_KERNEL(tril_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::TrilGradKernel,
+                   int,
+                   float,
+                   phi::dtype::float16) {}
+PD_REGISTER_KERNEL(triu_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::TriuGradKernel,
+                   int,
+                   float,
+                   phi::dtype::float16) {}
+PD_REGISTER_KERNEL(tril_triu_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::TrilTriuGradKernel,
+                   int,
+                   float,
+                   phi::dtype::float16) {}
@@ -64,7 +64,14 @@ void TriuKernel(const Context& ctx,
 }  // namespace phi

-PD_REGISTER_KERNEL(
-    tril_triu, XPU, ALL_LAYOUT, phi::TrilTriuKernel, int, float) {}
-PD_REGISTER_KERNEL(tril, XPU, ALL_LAYOUT, phi::TrilKernel, int, float) {}
-PD_REGISTER_KERNEL(triu, XPU, ALL_LAYOUT, phi::TriuKernel, int, float) {}
+PD_REGISTER_KERNEL(tril_triu,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::TrilTriuKernel,
+                   int,
+                   float,
+                   phi::dtype::float16) {}
+PD_REGISTER_KERNEL(
+    tril, XPU, ALL_LAYOUT, phi::TrilKernel, int, float, phi::dtype::float16) {}
+PD_REGISTER_KERNEL(
+    triu, XPU, ALL_LAYOUT, phi::TriuKernel, int, float, phi::dtype::float16) {}
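These registrations make `paddle.tril` and `paddle.triu` (and, via the grad kernels above, their backward passes) usable in float16 on XPU. A minimal sketch, assuming an XPU device is available:

```python
import paddle

paddle.set_device('xpu:0')  # assumes an XPU device is present

x = paddle.ones([3, 3], dtype='float16')
x.stop_gradient = False

y = paddle.tril(x)   # keep the lower triangle; paddle.triu is analogous
y.sum().backward()   # exercises the new float16 tril_grad kernel
print(y.numpy(), x.grad.numpy())
```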
@@ -27,7 +27,6 @@ from xpu.get_test_cover_info import (
 )

 import paddle
-import paddle.fluid.core as core

 paddle.enable_static()
@@ -40,7 +39,6 @@ class XPUTestCumsumOP(XPUOpTestWrapper):
     class TestCumsumOPBase(XPUOpTest):
         def setUp(self):
             self.place = paddle.XPUPlace(0)
-            self.xpu_version = core.get_xpu_device_version(0)
             self.init_dtype()
             self.set_case()
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

import numpy as np

sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
    XPUOpTestWrapper,
    create_test_class,
    get_xpu_op_support_types,
)

import paddle
import paddle.fluid as fluid

paddle.enable_static()


class XPUTestIndexSampleOP(XPUOpTestWrapper):
    def __init__(self):
        self.op_name = 'index_sample'
        self.use_dynamic_create_class = False

    class TestIndexSampleOPBase(XPUOpTest):
        def setUp(self):
            self.place = paddle.XPUPlace(0)
            self.init_dtype()
            self.set_case()

        def set_case(self):
            self.op_type = 'index_sample'
            self.config()

            xnp = np.random.random(self.x_shape).astype(self.dtype)
            indexnp = np.random.randint(
                low=0, high=self.x_shape[1], size=self.index_shape
            ).astype(self.index_type)
            self.inputs = {'X': xnp, 'Index': indexnp}
            index_array = []
            for i in range(self.index_shape[0]):
                for j in indexnp[i]:
                    index_array.append(xnp[i, j])
            index_array = np.array(index_array).astype(self.dtype)
            out = np.reshape(index_array, self.index_shape)
            self.outputs = {'Out': out}

        def init_dtype(self):
            self.dtype = self.in_type

        def test_check_output(self):
            self.check_output_with_place(self.place)

        def test_check_grad(self):
            self.check_grad_with_place(self.place, ['X'], 'Out')

        def config(self):
            self.x_shape = (10, 20)
            self.index_shape = (10, 10)
            self.index_type = "int32"

    class XPUTestIndexSample1(TestIndexSampleOPBase):
        def config(self):
            self.x_shape = (100, 1)
            self.index_shape = (100, 1)
            self.index_type = "int32"

    class XPUTestIndexSample2(TestIndexSampleOPBase):
        def config(self):
            self.x_shape = (100, 1)
            self.index_shape = (100, 1)
            self.index_type = "int64"

    class XPUTestIndexSample3(TestIndexSampleOPBase):
        def config(self):
            self.x_shape = (10, 100)
            self.index_shape = (10, 10)
            self.index_type = "int64"

    class XPUTestIndexSample4(TestIndexSampleOPBase):
        def config(self):
            self.x_shape = (10, 100)
            self.index_shape = (10, 10)
            self.index_type = "int32"

    class XPUTestIndexSample5(TestIndexSampleOPBase):
        def config(self):
            self.x_shape = (10, 128)
            self.index_shape = (10, 64)
            self.index_type = "int64"

    class XPUTestIndexSample6(TestIndexSampleOPBase):
        def config(self):
            self.x_shape = (10, 128)
            self.index_shape = (10, 64)
            self.index_type = "int32"
class TestIndexSampleShape(unittest.TestCase):
    def test_shape(self):
        paddle.enable_static()
        # create x value
        x_shape = (2, 5)
        x_np = np.random.random(x_shape).astype('float32')

        # create index value
        index_shape = (2, 3)
        index_type = "int32"
        index_np = np.random.randint(
            low=0, high=x_shape[1], size=index_shape
        ).astype(index_type)

        x = fluid.data(name='x', shape=[-1, 5], dtype='float32')
        index = fluid.data(name='index', shape=[-1, 3], dtype='int32')
        output = paddle.index_sample(x=x, index=index)

        place = fluid.XPUPlace(0)
        exe = fluid.Executor(place=place)
        exe.run(fluid.default_startup_program())

        feed = {'x': x_np, 'index': index_np}
        res = exe.run(feed=feed, fetch_list=[output])
class TestIndexSampleDynamic(unittest.TestCase):
    def test_result(self):
        with fluid.dygraph.guard():
            x = paddle.to_tensor(
                [
                    [1.0, 2.0, 3.0, 4.0],
                    [5.0, 6.0, 7.0, 8.0],
                    [9.0, 10.0, 11.0, 12.0],
                ],
                dtype='float32',
            )
            index = paddle.to_tensor(
                [[0, 1, 2], [1, 2, 3], [0, 0, 0]], dtype='int32'
            )
            out_z1 = paddle.index_sample(x, index)

            expected_output = np.array(
                [[1.0, 2.0, 3.0], [6.0, 7.0, 8.0], [9.0, 9.0, 9.0]]
            )
            # element-wise comparison against the expected gather result
            np.testing.assert_allclose(out_z1.numpy(), expected_output)
support_types = get_xpu_op_support_types('index_sample')
for stype in support_types:
    create_test_class(globals(), XPUTestIndexSampleOP, stype)

if __name__ == "__main__":
    unittest.main()