未验证 提交 0d12afea 编写于 作者: W wangshengxiang 提交者: GitHub

xpu: bind op scatter_nd_add. add data type for transpose2, clip & assign_value (#50825)

* [XPU] bind op scatter_nd_add

* [XPU] add more data type for op: clip, transpose2 & assign_value
上级 a36cdd6b
...@@ -49,7 +49,11 @@ XPUOpMap& get_kl2_ops() { ...@@ -49,7 +49,11 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT16, phi::DataType::FLOAT16,
phi::DataType::INT64, phi::DataType::INT64,
phi::DataType::BOOL})}, phi::DataType::BOOL})},
{"assign_value", XPUKernelSet({phi::DataType::FLOAT32})}, {"assign_value",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL})},
{"atan", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"atan", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"atan_grad", {"atan_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
...@@ -120,10 +124,15 @@ XPUOpMap& get_kl2_ops() { ...@@ -120,10 +124,15 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32})}, phi::DataType::INT32})},
{"check_finite_and_unscale", {"check_finite_and_unscale",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"clip", XPUKernelSet({phi::DataType::FLOAT32})}, {"clip",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::INT32})},
{"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"clip_grad", {"clip_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::INT32})},
{"coalesce_tensor", {"coalesce_tensor",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"concat_grad", {"concat_grad",
...@@ -545,6 +554,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -545,6 +554,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::INT64, XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::FLOAT32})}, phi::DataType::FLOAT32})},
{"scatter_nd_add", XPUKernelSet({phi::DataType::FLOAT32})},
{"sampling_id", {"sampling_id",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
{"set_value", {"set_value",
...@@ -692,13 +702,29 @@ XPUOpMap& get_kl2_ops() { ...@@ -692,13 +702,29 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT32})}, phi::DataType::FLOAT32})},
{"tile_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"tile_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"transpose2_grad", {"transpose2_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose2", {"transpose2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose_grad", {"transpose_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"transpose", {"transpose",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL})},
{"truncated_gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})}, {"truncated_gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})},
{"top_k", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"top_k", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"top_k_v2", {"top_k_v2",
......
...@@ -34,11 +34,11 @@ void ClipGradKernel(const Context& ctx, ...@@ -34,11 +34,11 @@ void ClipGradKernel(const Context& ctx,
reinterpret_cast<const XPUDataType*>(out_grad.data<T>()), reinterpret_cast<const XPUDataType*>(out_grad.data<T>()),
reinterpret_cast<XPUDataType*>(x_grad->data<T>()), reinterpret_cast<XPUDataType*>(x_grad->data<T>()),
x.numel(), x.numel(),
min.to<T>(), min.to<XPUDataType>(),
max.to<T>()); max.to<XPUDataType>());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_grad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_grad");
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
clip_grad, XPU, ALL_LAYOUT, phi::ClipGradKernel, float, int) {} clip_grad, XPU, ALL_LAYOUT, phi::ClipGradKernel, float, int64_t, int) {}
...@@ -33,8 +33,8 @@ void ClipKernel(const Context& dev_ctx, ...@@ -33,8 +33,8 @@ void ClipKernel(const Context& dev_ctx,
x_data, x_data,
out_data, out_data,
x.numel(), x.numel(),
min.to<float>(), min.to<XPUDataType>(),
max.to<float>()); max.to<XPUDataType>());
PADDLE_ENFORCE_EQ(r, PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS, XPU_SUCCESS,
...@@ -46,4 +46,5 @@ void ClipKernel(const Context& dev_ctx, ...@@ -46,4 +46,5 @@ void ClipKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(clip, XPU, ALL_LAYOUT, phi::ClipKernel, float) {} PD_REGISTER_KERNEL(
clip, XPU, ALL_LAYOUT, phi::ClipKernel, float, int64_t, int) {}
...@@ -57,8 +57,6 @@ void GatherNdGradKernel(const Context &ctx, ...@@ -57,8 +57,6 @@ void GatherNdGradKernel(const Context &ctx,
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::INT64)); phi::DataType::INT64));
int index_size =
static_cast<int>(index.dims().size() == 0 ? 1 : index.dims()[0]);
auto x_shape = phi::vectorize<int64_t>(x_grad->dims()); auto x_shape = phi::vectorize<int64_t>(x_grad->dims());
auto index_shape = phi::vectorize<int64_t>(index.dims()); auto index_shape = phi::vectorize<int64_t>(index.dims());
if (index_shape.size() == 1) { if (index_shape.size() == 1) {
...@@ -70,6 +68,7 @@ void GatherNdGradKernel(const Context &ctx, ...@@ -70,6 +68,7 @@ void GatherNdGradKernel(const Context &ctx,
DenseTensor index_cpu(index.type()); DenseTensor index_cpu(index.type());
phi::Copy(ctx, index, phi::CPUPlace(), false, &index_cpu); phi::Copy(ctx, index, phi::CPUPlace(), false, &index_cpu);
int index_size = static_cast<int>(index.numel());
if (index_type == phi::DataType::INT32) { if (index_type == phi::DataType::INT32) {
auto index_data = const_cast<int *>(index.data<int>()); auto index_data = const_cast<int *>(index.data<int>());
xpu::VectorParam<int> index_vec{ xpu::VectorParam<int> index_vec{
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/scatter_nd_add_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ScatterNdAddKernel(const Context &ctx,
const DenseTensor &x,
const DenseTensor &index,
const DenseTensor &updates,
DenseTensor *out) {
const T *x_ptr = x.data<T>();
const T *updates_ptr = updates.data<T>();
T *out_ptr = ctx.template Alloc<T>(out);
int r = xpu::copy(ctx.x_context(), x_ptr, out_ptr, x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
if (updates.numel() == 0) return;
if (index.numel() == 0) {
int loop_time =
static_cast<int>(index.dims().size() == 0 ? 1 : index.dims()[0]);
for (int i = 0; i < loop_time; i++) {
// xpu::add only support float or float16 template typename
// now, register this op only with float type
r = xpu::add<T>(ctx.x_context(),
updates_ptr + out->numel() * i,
out_ptr,
out_ptr,
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
}
return;
}
const phi::DataType index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s], but "
"desires to be [%s] or [%s].",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
auto x_shape = phi::vectorize<int64_t>(x.dims());
auto index_shape = phi::vectorize<int64_t>(index.dims());
if (index_shape.size() == 1) {
index_shape.insert(index_shape.begin(), 1);
}
xpu::VectorParam<int64_t> x_vec = {
x_shape.data(), static_cast<int>(x_shape.size()), nullptr};
DenseTensor index_cpu(index.type());
phi::Copy(ctx, index, phi::CPUPlace(), false, &index_cpu);
int index_size = static_cast<int>(index.numel());
if (index_type == phi::DataType::INT32) {
xpu::VectorParam<int> index_vec{index_cpu.data<int>(), index_size, nullptr};
r = xpu::scatter_nd<T, int>(ctx.x_context(),
nullptr,
updates_ptr,
out_ptr,
index_vec,
x_vec,
index_shape,
false);
} else {
xpu::VectorParam<int64_t> index_vec{
index_cpu.data<int64_t>(), index_size, nullptr};
r = xpu::scatter_nd<T, int64_t>(ctx.x_context(),
nullptr,
updates_ptr,
out_ptr,
index_vec,
x_vec,
index_shape,
false);
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scatter_nd_add");
}
} // namespace phi
// Register phi::ScatterNdAddKernel as the XPU kernel for scatter_nd_add.
// float is the only element type listed in this registration.
PD_REGISTER_KERNEL(
    scatter_nd_add, XPU, ALL_LAYOUT, phi::ScatterNdAddKernel, float) {}
...@@ -58,4 +58,7 @@ PD_REGISTER_KERNEL(transpose_grad, ...@@ -58,4 +58,7 @@ PD_REGISTER_KERNEL(transpose_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::TransposeGradKernel, phi::TransposeGradKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16,
int64_t,
int,
bool) {}
...@@ -54,4 +54,7 @@ PD_REGISTER_KERNEL(transpose, ...@@ -54,4 +54,7 @@ PD_REGISTER_KERNEL(transpose,
ALL_LAYOUT, ALL_LAYOUT,
phi::TransposeKernel, phi::TransposeKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16,
int64_t,
int,
bool) {}
...@@ -55,7 +55,7 @@ class XPUTestAssignValueOp(XPUOpTestWrapper): ...@@ -55,7 +55,7 @@ class XPUTestAssignValueOp(XPUOpTestWrapper):
self.outputs = {"Out": self.value} self.outputs = {"Out": self.value}
def init_data(self): def init_data(self):
self.value = np.random.random(size=(2, 5)).astype(self.dtype) self.value = np.random.random(size=(2, 5)).astype(np.float32)
self.attrs["fp32_values"] = [float(v) for v in self.value.flat] self.attrs["fp32_values"] = [float(v) for v in self.value.flat]
def test_forward(self): def test_forward(self):
......
...@@ -165,6 +165,18 @@ class XPUTestGatherNd(XPUOpTestWrapper): ...@@ -165,6 +165,18 @@ class XPUTestGatherNd(XPUOpTestWrapper):
self.inp = np.array([1, 2]).astype("int64") self.inp = np.array([1, 2]).astype("int64")
self.output = self.xnp[tuple(self.inp.T)] self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpMultiDimIndex1(XPUTestGatherNdBase):
    # gather_nd case on a (10, 10) input with a single int32 coordinate
    # [2, 2]; the expected output is the element xnp[2, 2].
    def init_data(self):
        self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
        self.inp = np.array([2, 2]).astype("int32")
        # tuple(self.inp.T) == (2, 2), i.e. plain NumPy element indexing.
        self.output = self.xnp[tuple(self.inp.T)]
class XPUTestGatherNdOpMultiDimIndex2(XPUTestGatherNdBase):
    # Same data layout as the int32 variant, but the coordinate [2, 2] is
    # stored as int64; the expected output is the element xnp[2, 2].
    def init_data(self):
        self.xnp = np.random.uniform(0, 100, (10, 10)).astype(self.in_type)
        self.inp = np.array([2, 2]).astype("int64")
        # tuple(self.inp.T) == (2, 2), i.e. plain NumPy element indexing.
        self.output = self.xnp[tuple(self.inp.T)]
support_types = get_xpu_op_support_types('gather_nd') support_types = get_xpu_op_support_types('gather_nd')
for stype in support_types: for stype in support_types:
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
import paddle
paddle.enable_static()
def numpy_scatter_nd(ref, index, updates, fun):
    """NumPy reference for scatter_nd-style ops.

    The last axis of ``index`` holds coordinates into the leading axes of
    ``ref``. For every coordinate row, the addressed slice of ``ref`` is
    replaced by ``fun(current_slice, matching_update_slice)``.

    Args:
        ref: array to scatter into. It may be modified in place, because the
            working buffer comes from ``ref.reshape`` (callers in this file
            pass a copy).
        index: integer array; ``index.shape[-1]`` is the coordinate depth.
        updates: array reshaped to ``(rows, slice_size)`` before use.
        fun: callable ``fun(current, update)`` returning the new slice.

    Returns:
        The combined result, reshaped to ``ref.shape``.
    """
    target_dims = ref.shape
    idx_dims = index.shape
    depth = idx_dims[-1]

    # Cast the products to int32 so they can be used as reshape dimensions.
    num_rows = np.prod(idx_dims[:-1]).astype("int32")
    chunk = np.prod(target_dims[depth:]).astype("int32")

    coord_rows = index.reshape([num_rows] + list(idx_dims[-1:]))
    update_rows = updates.reshape((num_rows, chunk))
    result = ref.reshape(list(target_dims[:depth]) + [chunk])

    for row in range(len(coord_rows)):
        where = tuple(coord_rows[row])
        result[where] = fun(result[where], update_rows[row])
    return result.reshape(ref.shape)
def numpy_scatter_nd_add(ref, index, updates):
    """NumPy reference for scatter_nd_add.

    Delegates to ``numpy_scatter_nd`` with element-wise addition as the
    combining function and returns its result.
    """

    def accumulate(current, delta):
        return current + delta

    return numpy_scatter_nd(ref, index, updates, accumulate)
def judge_update_shape(ref, index):
    """Return the shape (as a list) that ``updates`` must have.

    The result is every dimension of ``index`` except the last, followed by
    the dimensions of ``ref`` starting at position ``index.shape[-1]``.
    """
    ref_dims = ref.shape
    idx_dims = index.shape
    depth = idx_dims[-1]
    return list(idx_dims[:-1]) + list(ref_dims[depth:])
class XPUTestScatterNdAdd(XPUOpTestWrapper):
    """Container for the XPU scatter_nd_add test cases.

    The nested test classes are turned into concrete tests by
    ``create_test_class`` at module level, once per supported dtype.
    """

    def __init__(self):
        self.op_name = 'scatter_nd_add'

    class TestScatterNdAdd(XPUOpTest):
        """Base case: 1-D input of 100 elements with 100 int32 indices.

        Subclasses only override ``init_data`` to supply other shapes and
        index dtypes.
        """

        def setUp(self):
            self.op_type = "scatter_nd_add"
            # get data type here
            # NOTE(review): in_type is presumably injected per dtype by
            # create_test_class -- confirm against get_test_cover_info.
            self.dtype = self.in_type
            self.__class__.no_need_check_grad = True
            self.place = paddle.XPUPlace(0)
            # Only float32 is exercised because it is the only registered
            # kernel type.
            self.init_data()

            self.inputs = {
                'X': self.x_np,
                'Index': self.index_np,
                'Updates': self.updates_np,
            }
            # Pass a copy: the NumPy reference may write into its first
            # argument, and x_np is also used as the op input.
            output = numpy_scatter_nd_add(
                self.x_np.copy(), self.index_np, self.updates_np
            )
            self.outputs = {'Out': output}

        def test_check_output(self):
            self.check_output_with_place(self.place)

        def init_data(self):
            self.x_np = np.random.random([100]).astype(self.dtype)
            self.index_np = np.random.randint(0, 100, [100, 1]).astype("int32")
            self.updates_np = np.random.random([100]).astype(self.dtype)

        def infer_dtype_from_inputs_outputs(self, inputs, outputs):
            # Use the dtype chosen in setUp instead of inferring it from the
            # input/output arrays.
            self.__class__.dtype = self.dtype
            self.output_dtype = self.dtype

    class TestScatterNdAddWithEmptyIndex(TestScatterNdAdd):
        """Index of shape (2, 0): each update slice covers the whole input."""

        def init_data(self):
            self.x_np = np.random.random((10, 10)).astype(self.dtype)
            self.index_np = np.array([[], []]).astype("int32")
            self.updates_np = np.random.random((2, 10, 10)).astype(self.dtype)

    class TestScatterNdAddOpWithHighRankSame(TestScatterNdAdd):
        """Rank-5 input with 100 full-depth int32 coordinates."""

        def init_data(self):
            shape = (3, 2, 2, 1, 10)
            self.x_np = np.random.rand(*shape).astype(self.dtype)
            # One random column per input dimension, each within that
            # dimension's bounds.
            self.index_np = np.vstack(
                [np.random.randint(0, s, size=100) for s in shape]
            ).T.astype("int32")
            update_shape = judge_update_shape(self.x_np, self.index_np)
            self.updates_np = np.random.rand(*update_shape).astype(self.dtype)

    class TestScatterNdAddWithHighRankDiff(TestScatterNdAdd):
        """Rank-5 input with a rank-4 int64 index of shape (10, 5, 10, 5)."""

        def init_data(self):
            shape = (8, 2, 2, 1, 10)
            self.x_np = np.random.rand(*shape).astype(self.dtype)
            index_tmp = np.vstack(
                [np.random.randint(0, s, size=500) for s in shape]
            ).T
            self.index_np = index_tmp.reshape([10, 5, 10, 5]).astype("int64")
            update_shape = judge_update_shape(self.x_np, self.index_np)
            self.updates_np = np.random.rand(*update_shape).astype(self.dtype)

    class TestScatterNdAddWithMultiDimIndex(TestScatterNdAdd):
        """Rank-4 input with 796 full-depth int32 coordinates."""

        def init_data(self):
            shape = (16, 3, 20, 20)
            self.x_np = np.random.rand(*shape).astype(self.dtype)
            # Draw each coordinate column within its own dimension. Casting
            # np.random.rand output (values in [0, 1)) to int32 would make
            # every coordinate 0 and exercise a single element only.
            self.index_np = np.vstack(
                [np.random.randint(0, s, size=796) for s in shape]
            ).T.astype("int32")
            update_shape = judge_update_shape(self.x_np, self.index_np)
            self.updates_np = np.random.rand(*update_shape).astype(self.dtype)
# Build one concrete set of test classes per dtype returned for
# 'scatter_nd_add' (presumably the dtypes registered for the op on XPU --
# confirm against get_test_cover_info).
support_types = get_xpu_op_support_types('scatter_nd_add')
for stype in support_types:
    create_test_class(globals(), XPUTestScatterNdAdd, stype)

# Allow running this file directly as a unittest script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册