Unverified commit 5a4d2186, authored by ykkk2333 and committed by GitHub

add roll and roll_grad kernels and strided_slice and strided_slice_grad kernels, test=kunlun (#47368)

* add stat tool

* add roll and roll_grad kernels and strided_slice and strided_slice_grad kernels, test=kunlun
Parent 8a7e54d5
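For context, a minimal dynamic-graph sketch of the Python ops that these new XPU kernels back. The shapes, shifts, and slice arguments below are purely illustrative, and device placement is omitted; this sketch is not part of the change itself.

import paddle

x = paddle.arange(24, dtype='float32').reshape([2, 3, 4])

# Circular shift along axes 0 and 2; elements pushed off one end re-enter at the other.
y = paddle.roll(x, shifts=[1, -1], axis=[0, 2])

# Strided slice along axes 1 and 2; the XPU kernels accept only positive strides.
z = paddle.strided_slice(x, axes=[1, 2], starts=[0, 1], ends=[3, 4], strides=[1, 2])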
@@ -472,6 +472,8 @@ XPUOpMap& get_kl2_ops() {
{"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roll", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roll_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"scale",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
@@ -567,6 +569,16 @@ XPUOpMap& get_kl2_ops() {
{"stack_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"strided_slice",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"strided_slice_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"sum",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
@@ -68,3 +68,13 @@ PD_REGISTER_KERNEL(strided_slice_grad,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(strided_slice_grad,
XPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
int,
int16_t,
float,
phi::dtype::float16) {}
#endif
@@ -59,3 +59,13 @@ PD_REGISTER_KERNEL(strided_slice,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(strided_slice,
XPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
int,
int16_t,
float,
phi::dtype::float16) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roll_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void RollGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& shifts,
const std::vector<int64_t>& axis,
DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto shifts_data = shifts.GetData();
dev_ctx.template Alloc<T>(x_grad);
DDim input_dim = x.dims();
std::vector<int> xshape;
size_t nums = shifts_data.size();
for (int i = 0; i < input_dim.size(); ++i) {
xshape.emplace_back(input_dim[i]);
}
auto dims = axis;
// axis = none, reshape to 1-D tensor
if (dims.size() == 0) {
dims.push_back(0l);
input_dim = phi::Dim<1>(x.numel());
}
std::vector<int> shifts_in;
std::vector<int> axis_in;
for (size_t i = 0; i < nums; ++i) {
int a = dims[i];
if (a < 0) {
a += (input_dim.size());
}
axis_in.emplace_back(a);
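    // The gradient of roll is roll with the shift negated; normalize the
    // negated shift into the range [0, dim).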
int sh = (0 - shifts_data[i]) % input_dim[a];
if (sh < 0) {
sh += input_dim[a];
}
shifts_in.emplace_back(sh);
}
int r = xpu::roll(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
reinterpret_cast<XPUType*>(x_grad->data<T>()),
xshape,
shifts_in,
axis_in);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "roll");
}
} // namespace phi
PD_REGISTER_KERNEL(roll_grad, XPU, ALL_LAYOUT, phi::RollGradKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/roll_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void RollKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shifts,
const std::vector<int64_t>& axis,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto shifts_data = shifts.GetData();
dev_ctx.template Alloc<T>(out);
DDim input_dim = x.dims();
std::vector<int> xshape;
size_t nums = shifts_data.size();
for (int i = 0; i < input_dim.size(); ++i) {
xshape.emplace_back(input_dim[i]);
}
auto dims = axis;
// axis = none, reshape to 1-D tensor
if (dims.size() == 0) {
dims.push_back(0l);
input_dim = phi::Dim<1>(x.numel());
}
std::vector<int> shifts_in;
std::vector<int> axis_in;
for (size_t i = 0; i < nums; ++i) {
int a = dims[i];
if (a < 0) {
a += (input_dim.size());
}
axis_in.emplace_back(a);
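    // Normalize the (possibly negative) shift into [0, dim) before calling xpu::roll.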
int sh = shifts_data[i] % input_dim[a];
if (sh < 0) {
sh += input_dim[a];
}
shifts_in.emplace_back(sh);
}
int r = xpu::roll(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
xshape,
shifts_in,
axis_in);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "roll");
}
} // namespace phi
PD_REGISTER_KERNEL(roll, XPU, ALL_LAYOUT, phi::RollKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceRawGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
DDim in_dims = x.dims();
dev_ctx.template Alloc<T>(x_grad);
auto starts_ = starts.GetData();
auto ends_ = ends.GetData();
auto strides_ = strides.GetData();
std::vector<int> xshape;
std::vector<int> starts_in(in_dims.size(), 0);
std::vector<int> ends_in;
std::vector<int> strides_in(in_dims.size(), 1);
for (int i = 0; i < in_dims.size(); ++i) {
xshape.emplace_back(in_dims[i]);
ends_in.emplace_back(in_dims[i]);
}
int num = axes.size();
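  // Resolve negative start/end indices and clamp them to the input extent on each sliced axis.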
for (int i = 0; i < num; ++i) {
PADDLE_ENFORCE_EQ(
strides_[i] > 0,
true,
        errors::InvalidArgument(
            "Currently, the XPU strided slice kernel does not support "
            "reverse strided slice."));
int cur_axe = axes[i];
int st = starts_[i];
if (st > xshape[cur_axe]) {
st = xshape[cur_axe];
}
if (st < 0) {
st += xshape[cur_axe];
}
starts_in[cur_axe] = st;
int end = ends_[i];
if (end > xshape[cur_axe]) {
end = xshape[cur_axe];
}
if (end < 0) {
end += xshape[cur_axe];
}
ends_in[cur_axe] = end;
strides_in[cur_axe] = strides_[i];
}
int r = xpu::strided_slice_grad(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
reinterpret_cast<XPUType*>(x_grad->data<T>()),
xshape,
starts_in,
ends_in,
strides_in);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(strided_slice_raw_grad,
XPU,
ALL_LAYOUT,
phi::StridedSliceRawGradKernel,
int,
int16_t,
float,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const IntArray& starts,
const IntArray& ends,
const IntArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
DDim in_dims = x.dims();
auto starts_ = starts.GetData();
auto ends_ = ends.GetData();
auto strides_ = strides.GetData();
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
funcs::StridedSliceOutDims(starts_,
ends_,
strides_,
axes,
infer_flags,
in_dims,
decrease_axis,
out_dims_vector.data(),
axes.size(),
false);
DDim out_dims(phi::make_ddim(out_dims_vector));
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
std::vector<int> xshape;
std::vector<int> starts_in(in_dims.size(), 0);
std::vector<int> ends_in;
std::vector<int> strides_in(in_dims.size(), 1);
for (int i = 0; i < in_dims.size(); ++i) {
xshape.emplace_back(in_dims[i]);
ends_in.emplace_back(in_dims[i]);
}
int num = axes.size();
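  // Resolve negative indices and clamp starts/ends per sliced axis, mirroring the grad kernel above.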
for (int i = 0; i < num; ++i) {
PADDLE_ENFORCE_EQ(
strides_[i] > 0,
true,
        errors::InvalidArgument(
            "Currently, the XPU strided slice kernel does not support "
            "reverse strided slice."));
int cur_axe = axes[i];
int st = starts_[i];
if (st > xshape[cur_axe]) {
st = xshape[cur_axe];
}
if (st < 0) {
st += xshape[cur_axe];
}
starts_in[cur_axe] = st;
int end = ends_[i];
if (end > xshape[cur_axe]) {
end = xshape[cur_axe];
}
if (end < 0) {
end += xshape[cur_axe];
}
ends_in[cur_axe] = end;
PADDLE_ENFORCE_EQ(
st < end,
true,
        errors::InvalidArgument(
            "The end index should be larger than the start index; this OP "
            "does not support reverse (negative-stride) slicing."));
strides_in[cur_axe] = strides_[i];
}
int r = xpu::strided_slice(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
xshape,
starts_in,
ends_in,
strides_in);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");
}
} // namespace phi
PD_REGISTER_KERNEL(strided_slice_raw,
XPU,
ALL_LAYOUT,
phi::StridedSliceRawKernel,
int,
int16_t,
float,
phi::dtype::float16) {}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import sys
import unittest
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
create_test_class,
get_xpu_op_support_types,
XPUOpTestWrapper,
)
paddle.enable_static()
class XPUTestRollOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = "roll"
self.use_dynamic_create_class = False
class TestXPURollOp(XPUOpTest):
def setUp(self):
self.op_type = "roll"
self.dtype = self.in_type
self.init_shapes()
self.inputs = {
'X': np.random.random(self.x_shape).astype(self.dtype)
}
self.attrs = {'shifts': self.shifts, 'axis': self.axis}
self.outputs = {
'Out': np.roll(
self.inputs['X'], self.attrs['shifts'], self.attrs['axis']
)
}
def init_shapes(self):
self.x_shape = (100, 4, 5)
self.shifts = [101, -1]
self.axis = [0, -2]
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0))
def test_check_grad(self):
self.check_grad_with_place(paddle.XPUPlace(0), ['X'], 'Out')
class TestRollOpCase2(TestXPURollOp):
def init_shapes(self):
self.x_shape = (100, 10, 5)
self.shifts = [8, -1]
self.axis = [-1, -2]
class TestRollOpCase3(TestXPURollOp):
def init_shapes(self):
self.x_shape = (100, 10, 5, 10, 15)
self.shifts = [50, -1, 3]
self.axis = [-1, -2, 1]
class TestRollOpCase4(TestXPURollOp):
def init_shapes(self):
self.x_shape = (100, 10, 5, 10, 15)
self.shifts = [8, -1]
self.axis = [-1, -2]
class TestRollOpCase5(TestXPURollOp):
def init_shapes(self):
self.x_shape = (100, 10, 5, 10)
self.shifts = [20, -1]
self.axis = [0, -2]
support_types = get_xpu_op_support_types('roll')
for stype in support_types:
create_test_class(globals(), XPUTestRollOp, stype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import sys
import unittest
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
create_test_class,
get_xpu_op_support_types,
XPUOpTestWrapper,
)
paddle.enable_static()
def strided_slice_native_forward(input, axes, starts, ends, strides):
dim = input.ndim
start = []
end = []
stride = []
for i in range(dim):
start.append(0)
end.append(input.shape[i])
stride.append(1)
for i in range(len(axes)):
start[axes[i]] = starts[i]
end[axes[i]] = ends[i]
stride[axes[i]] = strides[i]
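    # Dispatch on the tensor rank and apply the equivalent basic-slicing expression.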
result = {
1: lambda input, start, end, stride: input[
start[0] : end[0] : stride[0]
],
2: lambda input, start, end, stride: input[
start[0] : end[0] : stride[0], start[1] : end[1] : stride[1]
],
3: lambda input, start, end, stride: input[
start[0] : end[0] : stride[0],
start[1] : end[1] : stride[1],
start[2] : end[2] : stride[2],
],
4: lambda input, start, end, stride: input[
start[0] : end[0] : stride[0],
start[1] : end[1] : stride[1],
start[2] : end[2] : stride[2],
start[3] : end[3] : stride[3],
],
5: lambda input, start, end, stride: input[
start[0] : end[0] : stride[0],
start[1] : end[1] : stride[1],
start[2] : end[2] : stride[2],
start[3] : end[3] : stride[3],
start[4] : end[4] : stride[4],
],
6: lambda input, start, end, stride: input[
start[0] : end[0] : stride[0],
start[1] : end[1] : stride[1],
start[2] : end[2] : stride[2],
start[3] : end[3] : stride[3],
start[4] : end[4] : stride[4],
start[5] : end[5] : stride[5],
],
}[dim](input, start, end, stride)
return result
class XPUTestStrideSliceOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'strided_slice'
self.use_dynamic_create_class = False
class XPUTestStrideSliceOp(XPUOpTest):
def setUp(self):
self.op_type = 'strided_slice'
self.dtype = self.in_type
self.initTestCase()
self.input = np.random.random(self.inshape).astype(self.dtype)
self.python_api = paddle.strided_slice
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides
)
self.inputs = {'Input': self.input}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags,
}
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0))
def test_check_grad(self):
self.check_grad_with_place(paddle.XPUPlace(0), ['Input'], 'Out')
def initTestCase(self):
self.inshape = 100
self.axes = [0]
self.starts = [-4]
self.ends = [-1]
self.strides = [1]
self.infer_flags = [1]
class XPUTestStrideSliceOp1(XPUTestStrideSliceOp):
def initTestCase(self):
self.inshape = 100
self.axes = [0]
self.starts = [3]
self.ends = [8]
self.strides = [1]
self.infer_flags = [1]
class XPUTestStrideSliceOp2(XPUTestStrideSliceOp):
def initTestCase(self):
self.inshape = (4, 8, 12)
self.axes = [0, 1, 2]
self.starts = [3, 4, 5]
self.ends = [4, 5, 6]
self.strides = [1, 1, 1]
self.infer_flags = [1, 1, 1]
class XPUTestStrideSliceOp3(XPUTestStrideSliceOp):
def initTestCase(self):
self.inshape = (4, 8, 12, 4, 40)
self.axes = [0, 1, 2, 3, 4]
self.starts = [3, 4, 5, 1, 10]
self.ends = [4, 5, 6, 2, 30]
self.strides = [1, 1, 1, 2, 2]
self.infer_flags = [1, 1, 1, 1, 1]
class XPUTestStrideSliceOp5(XPUTestStrideSliceOp):
def initTestCase(self):
self.inshape = (5, 5, 5)
self.axes = [0, 1, 2]
self.starts = [1, 0, 0]
self.ends = [2, 1, 3]
self.strides = [1, 1, 1]
self.infer_flags = [1, 1, 1]
class XPUTestStrideSliceOp7(XPUTestStrideSliceOp):
def initTestCase(self):
self.inshape = (5, 5, 5)
self.axes = [0, 1, 2]
self.starts = [1, 0, 0]
self.ends = [2, 2, 3]
self.strides = [1, 1, 1]
self.infer_flags = [1, 1, 1]
class XPUTestStrideSliceOp8(XPUTestStrideSliceOp):
def initTestCase(self):
self.inshape = (3, 3, 3, 6, 7, 8)
self.axes = [0, 1, 2, 3, 4, 5]
self.starts = [1, 0, 0, 0, 1, 2]
self.ends = [2, 2, 3, 1, 2, 8]
self.strides = [1, 1, 1, 1, 1, 2]
        self.infer_flags = [1, 1, 1, 1, 1, 1]
support_types = get_xpu_op_support_types('strided_slice')
for stype in support_types:
create_test_class(globals(), XPUTestStrideSliceOp, stype)
if __name__ == "__main__":
unittest.main()