Unverified commit 7743cdf2, authored by baoachun, committed by GitHub

add strided_slice_grad op for npu (#35204)

* add strided_slice_grad op for npu
Parent 5fa7d9ce
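For reference, a minimal Python-side sketch of what this kernel enables (a sketch only: it assumes a PaddlePaddle build with Ascend NPU support and an available device "npu:0"; the shapes and slice parameters are illustrative, not taken from this commit):

import paddle

paddle.set_device('npu:0')  # route forward and backward to the NPU kernels

x = paddle.rand([3, 4, 5, 6])
x.stop_gradient = False  # request a gradient for the sliced input

# Forward: y = x[1:3, 0:3, 2:4] expressed via strided_slice.
y = paddle.strided_slice(
    x, axes=[0, 1, 2], starts=[1, 0, 2], ends=[3, 3, 4], strides=[1, 1, 1])

# Backward now dispatches to the strided_slice_grad NPU kernel added here:
# the out-grad is scattered back into a zero tensor with x's shape.
y.sum().backward()
print(x.grad.shape)  # [3, 4, 5, 6]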
......@@ -226,14 +226,204 @@ class StridedSliceNPUKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class StridedSliceGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<LoDTensorArray>();
PADDLE_ENFORCE_EQ(is_tensor_array, false,
platform::errors::InvalidArgument(
"Tensor array as input is not supported."));
int rank = ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
StridedSliceGradCompute<1>(ctx);
break;
case 2:
StridedSliceGradCompute<2>(ctx);
break;
case 3:
StridedSliceGradCompute<3>(ctx);
break;
case 4:
StridedSliceGradCompute<4>(ctx);
break;
case 5:
StridedSliceGradCompute<5>(ctx);
break;
case 6:
StridedSliceGradCompute<6>(ctx);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input is supported up to 6."));
break;
}
}
private:
template <size_t D>
void StridedSliceGradCompute(const framework::ExecutionContext& ctx) const {
auto place = ctx.GetPlace();
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* input = ctx.Input<framework::Tensor>("Input");
auto input_dims = input->dims();
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("Input"));
dx->mutable_data<T>(input_dims, place);
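    // Slice parameters default to the "starts"/"ends"/"strides" attributes;
    // the tensor-list and tensor inputs handled below, when present, override them.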
auto starts_int = ctx.Attr<std::vector<int>>("starts");
auto ends_int = ctx.Attr<std::vector<int>>("ends");
auto strides_int = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
auto axes = ctx.Attr<std::vector<int>>("axes");
auto infer_flags = ctx.Attr<std::vector<int>>("infer_flags");
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
auto list_new_ends_tensor =
ctx.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
ctx.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_strides_tensor =
ctx.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (ctx.HasInput("StartsTensor")) {
auto* starts_tensor = ctx.Input<framework::Tensor>("StartsTensor");
starts = GetDataFromTensor<int64_t>(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
} else if (ctx.HasInput("EndsTensor")) {
auto* ends_tensor = ctx.Input<framework::Tensor>("EndsTensor");
ends = GetDataFromTensor<int64_t>(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = GetDataFromTensorList<int64_t>(list_new_strides_tensor);
} else if (ctx.HasInput("StridesTensor")) {
auto* strides_tensor = ctx.Input<framework::Tensor>("StridesTensor");
strides = GetDataFromTensor<int64_t>(strides_tensor);
}
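    // Normalize the slice description: StridedSliceOutDims computes the sliced
    // output shape, and StridedSliceFunctor rewrites negative strides as
    // positive-stride slices plus a reverse, recorded in reverse_vector.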
std::vector<int64_t> out_dims_vector(input_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, input_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
false);
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), input_dims, infer_flags,
decrease_axis, starts.size());
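    // Build full-rank begin/end/strides vectors for the NPU operator: unsliced
    // axes default to begin 0, end = dim extent, stride 1, while the axes listed
    // in "axes" receive the normalized slice parameters.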
std::vector<int64_t> starts_indices_vector(D, 0);
std::vector<int64_t> ends_indices_vector(out_dims_vector.begin(),
out_dims_vector.end());
std::vector<int64_t> strides_indices_vector(D, 1);
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices_vector[axis_index] = starts[axis];
ends_indices_vector[axis_index] = ends[axis];
strides_indices_vector[axis_index] = strides[axis];
}
Tensor starts_indices_tensor;
Tensor ends_indices_tensor;
Tensor strides_indices_tensor;
starts_indices_tensor.mutable_data<int64_t>({D}, place);
ends_indices_tensor.mutable_data<int64_t>({D}, place);
strides_indices_tensor.mutable_data<int64_t>({D}, place);
TensorFromVector(starts_indices_vector, dev_ctx, &starts_indices_tensor);
TensorFromVector(ends_indices_vector, dev_ctx, &ends_indices_tensor);
TensorFromVector(strides_indices_vector, dev_ctx, &strides_indices_tensor);
std::vector<int64_t> input_dims_vector;
for (int i = 0; i < input_dims.size(); i++) {
input_dims_vector.push_back(input_dims[i]);
}
Tensor input_dims_tensor;
TensorFromVector(input_dims_vector, dev_ctx, &input_dims_tensor);
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
auto stream = dev_ctx.stream();
framework::NPUAttributeMap attr_input = {{"begin_mask", 0},
{"end_mask", 0},
{"ellipsis_mask", 0},
{"new_axis_mask", 0},
{"shrink_axis_mask", 0}};
if (need_reverse) {
Tensor reverse_axis;
std::vector<int> reverse_axis_vector;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
reverse_axis_vector.push_back(axes[axis]);
}
}
reverse_axis.mutable_data<int>(
{static_cast<int>(reverse_axis_vector.size())}, place);
TensorFromVector(reverse_axis_vector, dev_ctx, &reverse_axis);
Tensor dout_tmp;
dout_tmp.mutable_data<T>(dout->dims(), place);
const auto& runner_reverse =
NpuOpRunner("ReverseV2", {*dout, reverse_axis}, {dout_tmp});
runner_reverse.Run(stream);
const auto& runner =
NpuOpRunner("StridedSliceGrad",
{input_dims_tensor, starts_indices_tensor,
ends_indices_tensor, strides_indices_tensor, dout_tmp},
{*dx}, attr_input);
runner.Run(stream);
} else {
const auto& runner =
NpuOpRunner("StridedSliceGrad",
{input_dims_tensor, starts_indices_tensor,
ends_indices_tensor, strides_indices_tensor, *dout},
{*dx}, attr_input);
runner.Run(stream);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
strided_slice, ops::StridedSliceNPUKernel<plat::NPUDeviceContext, bool>,
ops::StridedSliceNPUKernel<plat::NPUDeviceContext, int>,
ops::StridedSliceNPUKernel<plat::NPUDeviceContext, int64_t>,
ops::StridedSliceNPUKernel<plat::NPUDeviceContext, float>,
ops::StridedSliceNPUKernel<plat::NPUDeviceContext, double>);

REGISTER_OP_NPU_KERNEL(
strided_slice_grad,
ops::StridedSliceGradNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::StridedSliceGradNPUKernel<plat::NPUDeviceContext, float>,
ops::StridedSliceGradNPUKernel<plat::NPUDeviceContext, double>,
ops::StridedSliceGradNPUKernel<plat::NPUDeviceContext, int>,
ops::StridedSliceGradNPUKernel<plat::NPUDeviceContext, int64_t>);
......@@ -56,11 +56,11 @@ def strided_slice_native_forward(input, axes, starts, ends, strides):
return result
class TestStridedSliceOp(OpTest):
def setUp(self):
self.initTestCase()
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = 'strided_slice'
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
......@@ -75,12 +75,17 @@ class TestStridedSliceOp(OpTest):
'infer_flags': self.infer_flags
}
def set_npu(self):
self.__class__.use_npu = True
def test_check_output(self):
        self.check_output_with_place(self.place)
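    # check_grad_with_place compares the kernel's analytic gradient on the NPU
    # place against a numerically computed (finite-difference) gradient.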
def test_check_grad(self):
self.check_grad_with_place(self.place, ['Input'], 'Out')
def initTestCase(self):
        self.input = np.random.rand(100)
self.axes = [0]
self.starts = [2]
self.ends = [7]
......@@ -283,12 +288,12 @@ class TestStridedSliceOpBool6D(TestStridedSliceOpBool):
self.infer_flags = [1, 1, 1, 1, 1]
class TestStridedSliceOp_starts_ListTensor(OpTest):
def setUp(self):
self.place = paddle.NPUPlace(0)
self.op_type = "strided_slice"
self.config()
self.set_npu()
starts_tensor = []
for index, ele in enumerate(self.starts):
......@@ -305,6 +310,9 @@ class TestStridedSliceOp_starts_ListTensor(OpTest):
'infer_flags': self.infer_flags
}
def set_npu(self):
self.__class__.use_npu = True
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 2]
......@@ -318,16 +326,18 @@ class TestStridedSliceOp_starts_ListTensor(OpTest):
self.starts_infer = [1, 10, 2]
def test_check_output(self):
        self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], 'Out')
class TestStridedSliceOp_ends_ListTensor(OpTest):
def setUp(self):
self.place = paddle.NPUPlace(0)
self.op_type = "strided_slice"
self.config()
self.set_npu()
ends_tensor = []
for index, ele in enumerate(self.ends):
......@@ -344,6 +354,9 @@ class TestStridedSliceOp_ends_ListTensor(OpTest):
'infer_flags': self.infer_flags
}
def set_npu(self):
self.__class__.use_npu = True
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 0]
......@@ -357,16 +370,19 @@ class TestStridedSliceOp_ends_ListTensor(OpTest):
self.ends_infer = [3, 1, 4]
def test_check_output(self):
        self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], 'Out')
class TestStridedSliceOp_starts_Tensor(OpTest):
def setUp(self):
self.place = paddle.NPUPlace(0)
self.op_type = "strided_slice"
self.config()
self.set_npu()
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
......@@ -381,6 +397,9 @@ class TestStridedSliceOp_starts_Tensor(OpTest):
'infer_flags': self.infer_flags,
}
def set_npu(self):
self.__class__.use_npu = True
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 2]
......@@ -392,16 +411,19 @@ class TestStridedSliceOp_starts_Tensor(OpTest):
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
        self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], 'Out')
class TestStridedSliceOp_ends_Tensor(OpTest):
def setUp(self):
self.place = paddle.NPUPlace(0)
self.op_type = "strided_slice"
self.config()
self.set_npu()
self.inputs = {
'Input': self.input,
"EndsTensor": np.array(
......@@ -416,6 +438,9 @@ class TestStridedSliceOp_ends_Tensor(OpTest):
'infer_flags': self.infer_flags,
}
def set_npu(self):
self.__class__.use_npu = True
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 2]
......@@ -427,20 +452,23 @@ class TestStridedSliceOp_ends_Tensor(OpTest):
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
        self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], 'Out')
class TestStridedSliceOp_listTensor_Tensor(OpTest):
def setUp(self):
self.place = paddle.NPUPlace(0)
self.op_type = "strided_slice"
self.set_npu()
self.config()
ends_tensor = []
for index, ele in enumerate(self.ends):
ends_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.op_type = "strided_slice"
self.inputs = {
'Input': self.input,
......@@ -457,6 +485,9 @@ class TestStridedSliceOp_listTensor_Tensor(OpTest):
'infer_flags': self.infer_flags,
}
def set_npu(self):
self.__class__.use_npu = True
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 2]
......@@ -468,16 +499,19 @@ class TestStridedSliceOp_listTensor_Tensor(OpTest):
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
        self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], 'Out')
class TestStridedSliceOp_strides_Tensor(OpTest):
def setUp(self):
self.place = paddle.NPUPlace(0)
self.op_type = "strided_slice"
self.set_npu()
self.config()
self.inputs = {
'Input': self.input,
"StridesTensor": np.array(
......@@ -492,6 +526,9 @@ class TestStridedSliceOp_strides_Tensor(OpTest):
'infer_flags': self.infer_flags,
}
def set_npu(self):
self.__class__.use_npu = True
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, -1, 2]
......@@ -503,8 +540,10 @@ class TestStridedSliceOp_strides_Tensor(OpTest):
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
        self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], 'Out')
# Test python API
......