未验证 提交 382d099d 编写于 作者: W wangchaochaohu 提交者: GitHub

add support tensor and tensorlist for strided_slice OP (#19929)

* add support tensor and tensorlist for strided_slice OP test=develop

* fix the commnet test=develop

* fix test=develop

* fix the bug test=develop

* delete log test=develop

* fix API.spec test=develop

* fix test=develop
上级 fe218df3
...@@ -251,7 +251,7 @@ paddle.fluid.layers.sampling_id (ArgSpec(args=['x', 'min', 'max', 'seed', 'dtype ...@@ -251,7 +251,7 @@ paddle.fluid.layers.sampling_id (ArgSpec(args=['x', 'min', 'max', 'seed', 'dtype
paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'shape', 'input_dim_idx', 'output_dim_idx', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0, 0, 0.0, 1.0, 0, 'float32')), ('document', 'b24d0b21361c4bb8ef2cec8c26fb12b2')) paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'shape', 'input_dim_idx', 'output_dim_idx', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0, 0, 0.0, 1.0, 0, 'float32')), ('document', 'b24d0b21361c4bb8ef2cec8c26fb12b2'))
paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'f4b60847cb0f1ae00823ba6fb1b11310')) paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'f4b60847cb0f1ae00823ba6fb1b11310'))
paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '315b4870f294e33a27ecbdf440bed3ff')) paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '315b4870f294e33a27ecbdf440bed3ff'))
paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', 'a2e5296d34c081f2a67890aaa5f02238')) paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', '340d8d656272ea396b441aab848429a2'))
paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b')) paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'bf61c8f79d795a8371bdb3b5468aa82b'))
paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3')) paddle.fluid.layers.rank (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '096df0e0273145ab80ed119a4c294db3'))
paddle.fluid.layers.size (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'cf2e156beae36378722666c4c33bebfe')) paddle.fluid.layers.size (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', 'cf2e156beae36378722666c4c33bebfe'))
......
...@@ -15,7 +15,9 @@ limitations under the License. */ ...@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h" #include "paddle/fluid/operators/strided_slice_op.h"
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/operators/slice_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,7 +28,7 @@ class StridedSliceOp : public framework::OperatorWithKernel { ...@@ -26,7 +28,7 @@ class StridedSliceOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
"Input (Input) of slice op should not be null."); "Input (Input) of slice op should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
...@@ -39,56 +41,56 @@ class StridedSliceOp : public framework::OperatorWithKernel { ...@@ -39,56 +41,56 @@ class StridedSliceOp : public framework::OperatorWithKernel {
auto ends = ctx->Attrs().Get<std::vector<int>>("ends"); auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides"); auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
auto axes = ctx->Attrs().Get<std::vector<int>>("axes"); auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
PADDLE_ENFORCE_EQ(starts.size(), ends.size(), auto starts_size = starts.size();
"starts and ends dim size must to be same"); auto ends_size = ends.size();
PADDLE_ENFORCE_EQ(ends.size(), strides.size(), auto strides_size = strides.size();
"ends and strides dim size must to be same");
PADDLE_ENFORCE_EQ(ends.size(), axes.size(),
"axes, end and start dim size must to be same");
// we need to analysis strided slice op is valid for if (ctx->HasInputs("StartsTensorList")) {
// the parameter that we get from python front auto StartsTensorList = ctx->Inputs("StartsTensorList");
int stride_index, start_index, end_index; PADDLE_ENFORCE_GT(StartsTensorList.size(), 0,
std::vector<int> out_dims_vector(in_dims.size()); "StartsTensorList size can't be zero");
for (int i = 0; i < in_dims.size(); i++) { starts_size = StartsTensorList.size();
out_dims_vector[i] = in_dims[i];
} }
for (size_t i = 0; i < starts.size(); i++) { if (ctx->HasInputs("EndsTensorList")) {
PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero"); auto EndsTensorList = ctx->Inputs("EndsTensorList");
int axes_index = axes[i]; PADDLE_ENFORCE_GT(EndsTensorList.size(), 0,
start_index = starts[i]; "EndsTensorList size can't be zero");
end_index = ends[i]; ends_size = EndsTensorList.size();
stride_index = strides[i]; }
int axis_size = in_dims[axes_index]; if (ctx->HasInputs("StridesTensorList")) {
if (axis_size < 0) { auto StridesTensorList = ctx->Inputs("StridesTensorList");
continue; PADDLE_ENFORCE_GT(StridesTensorList.size(), 0,
"StridesTensorList size can't be zero");
strides_size = StridesTensorList.size();
} }
if (start_index < 0) { auto tensor_input = false;
start_index = start_index + axis_size; if (ctx->HasInput("EndsTensor") || ctx->HasInput("StartsTensor") ||
ctx->HasInput("StridesTensor")) {
tensor_input = true;
} }
if (end_index < 0) { if (ctx->HasInput("EndsTensor") == false) {
end_index = end_index + axis_size; PADDLE_ENFORCE_EQ(ends_size, axes.size(),
"The size of ends must be equal to the size of axes.");
} }
if (ctx->HasInput("StartsTensor") == false) {
if (stride_index < 0) { PADDLE_ENFORCE_EQ(
start_index = start_index + 1; starts_size, axes.size(),
end_index = end_index + 1; "The size of starts must be equal to the size of axes.");
} }
if (ctx->HasInput("StridesTensor") == false) {
bool zero_dim_condition = PADDLE_ENFORCE_EQ(
((stride_index < 0 && (start_index <= end_index)) || strides_size, axes.size(),
(stride_index > 0 && (start_index >= end_index))); "The size of strides must be equal to the size of axes.");
PADDLE_ENFORCE_EQ(zero_dim_condition, false, }
"starts and end must meet requirement in different " // we need to analysis strided slice op is valid for
"stride conditiont"); // the parameter that we get from python front
int left = std::max(0, std::min(start_index, end_index)); std::vector<int> out_dims_vector(in_dims.size(), -1);
int right = std::min(axis_size, std::max(start_index, end_index)); if (!tensor_input) {
int step = std::abs(stride_index); StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
auto out_dims_index = (std::abs(right - left) + step - 1) / step; out_dims_vector.data(), axes.size(), true);
out_dims_vector[axes_index] = out_dims_index;
} }
framework::DDim out_dims(framework::make_ddim(out_dims_vector)); framework::DDim out_dims(framework::make_ddim(out_dims_vector));
...@@ -98,26 +100,83 @@ class StridedSliceOp : public framework::OperatorWithKernel { ...@@ -98,26 +100,83 @@ class StridedSliceOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(), return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.Input<Tensor>("Input")->place()); ctx.Input<Tensor>("Input")->place());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
var_name == "StridesTensor") {
return expected_kernel_type;
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
var_name == "StridesTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}; };
class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker { class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Input", "Tensor of data to extract slices from."); AddInput("Input", "Tensor of data to extract slices from.");
AddOutput("Out", "Sliced data tensor."); AddOutput("Out", "Strided Sliced data tensor.");
AddInput("StartsTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of StartsTensor, StartsTensorList "
"and attr(starts).")
.AsDispensable();
AddInput("EndsTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of EndsTensor, EndsTensorList and "
"attr(ends).")
.AsDispensable();
AddInput(
"StridesTensor",
"(Tensor<int32>, optional) If provided, slice will use this."
"It has the highest priority of StridesTensor, StridesTensorList and "
"attr(ends).")
.AsDispensable();
AddInput(
"StartsTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(starts).")
.AsDuplicable()
.AsDispensable();
AddInput(
"EndsTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(ends).")
.AsDuplicable()
.AsDispensable();
AddInput(
"StridesTensorList",
"(vector<Tensor<int32>>, optional) If provided, slice will use this."
"The shape of the tensor in vector MUST BE [1]."
"It has higher priority compare with attr(strides).")
.AsDuplicable()
.AsDispensable();
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"axes", "(list<int> Axes stride from the start to the end)"); "axes", "(list<int>) Axes that `starts` and `ends` apply to.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"starts", "(list<int>) start that the tensor slice start."); "starts", "(list<int>) Start indices for the strided slice start.")
.SetDefault({});
AddAttr<std::vector<int>>("ends", AddAttr<std::vector<int>>("ends",
"(list<int>) end that the tensor slice end"); "(list<int>) End indices the tensor slice end")
.SetDefault({});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "(list<int> stride stride from the start to the end)"); "strides", "(list<int> Stride step from the start to the end)")
.SetDefault({});
AddAttr<std::vector<int>>(
"infer_flags", "(list<int>) Flags of inferring dims in attributes.")
.SetDefault({});
AddComment(R"DOC( AddComment(R"DOC(
Strided Slice Operator. Strided Slice Operator.
Instead of calling this op directly most users will want to use the Instead of calling this op directly most users will want to use the
...@@ -133,7 +192,7 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { ...@@ -133,7 +192,7 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input should not be null"); PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, "Input should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
...@@ -145,11 +204,23 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { ...@@ -145,11 +204,23 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(), ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace()); ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor") {
return expected_kernel_type;
}
if (var_name == "StartsTensorList" || var_name == "EndsTensorList") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}; };
class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker { class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker {
...@@ -158,9 +229,15 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -158,9 +229,15 @@ class StridedSliceOpGradMaker : public framework::SingleGradOpDescMaker {
protected: protected:
std::unique_ptr<framework::OpDesc> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
auto* bind = new framework::OpDesc(); auto *bind = new framework::OpDesc();
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetInput("Input", Input("Input")); bind->SetInput("Input", Input("Input"));
bind->SetInput("StartsTensor", Input("StartsTensor"));
bind->SetInput("EndsTensor", Input("EndsTensor"));
bind->SetInput("StridesTensor", Input("StridesTensor"));
bind->SetInput("StartsTensorList", Input("StartsTensorList"));
bind->SetInput("EndsTensorList", Input("EndsTensorList"));
bind->SetInput("StridesTensorList", Input("StridesTensorList"));
bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input")); bind->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
bind->SetAttrMap(Attrs()); bind->SetAttrMap(Attrs());
bind->SetType("strided_slice_grad"); bind->SetType("strided_slice_grad");
......
...@@ -19,9 +19,62 @@ limitations under the License. */ ...@@ -19,9 +19,62 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static void StridedSliceOutDims(
const std::vector<int>& starts, const std::vector<int>& ends,
const std::vector<int>& strides, const std::vector<int>& axes,
const std::vector<int>& infer_flags, const framework::DDim in_dims,
int* out_dims_vector, const size_t size, bool infer_shape) {
for (int i = 0; i < in_dims.size(); i++) {
out_dims_vector[i] = in_dims[i];
}
int stride_index, start_index, end_index;
for (size_t i = 0; i < size; i++) {
int axes_index = axes[i];
if (infer_shape && infer_flags[i] == -1) {
out_dims_vector[axes_index] = -1;
continue;
}
PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero");
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
int axis_size = in_dims[axes_index];
if (axis_size < 0) {
continue;
}
if (start_index < 0) {
start_index = start_index + axis_size;
}
if (end_index < 0) {
end_index = end_index + axis_size;
}
if (stride_index < 0) {
start_index = start_index + 1;
end_index = end_index + 1;
}
bool zero_dim_condition =
((stride_index < 0 && (start_index <= end_index)) ||
(stride_index > 0 && (start_index >= end_index)));
PADDLE_ENFORCE_EQ(zero_dim_condition, false,
"starts and end must meet requirement in different "
"stride conditiont");
int left = std::max(0, std::min(start_index, end_index));
int right = std::min(axis_size, std::max(start_index, end_index));
int step = std::abs(stride_index);
auto out_dims_index = (std::abs(right - left) + step - 1) / step;
out_dims_vector[axes_index] = out_dims_index;
}
}
static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes, static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
int* reverse_axis, const framework::DDim dims, int* reverse_axis, const framework::DDim dims,
const size_t size) { const size_t size) {
...@@ -91,19 +144,52 @@ class StridedSliceKernel : public framework::OpKernel<T> { ...@@ -91,19 +144,52 @@ class StridedSliceKernel : public framework::OpKernel<T> {
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
auto in = context.Input<framework::Tensor>("Input"); auto in = context.Input<framework::Tensor>("Input");
auto out = context.Output<framework::Tensor>("Out"); auto out = context.Output<framework::Tensor>("Out");
auto out_dims = out->dims();
auto in_dims = in->dims(); auto in_dims = in->dims();
auto starts = context.Attr<std::vector<int>>("starts"); auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends"); auto ends = context.Attr<std::vector<int>>("ends");
auto strides = context.Attr<std::vector<int>>("strides"); auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes"); auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>(); auto reverse_axis = Eigen::array<bool, D>();
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_strides_tensor =
context.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = get_new_data_from_tensorlist(list_new_strides_tensor);
} else if (context.HasInput("StridesTensor")) {
auto* strides_tensor = context.Input<framework::Tensor>("StridesTensor");
strides = get_new_data_from_tensor(strides_tensor);
}
std::vector<int> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
out_dims_vector.data(), axes.size(), false);
framework::DDim out_dims(framework::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts.size(), 0); std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), in_dims, starts.size()); reverse_vector.data(), in_dims, starts.size());
...@@ -112,6 +198,7 @@ class StridedSliceKernel : public framework::OpKernel<T> { ...@@ -112,6 +198,7 @@ class StridedSliceKernel : public framework::OpKernel<T> {
starts_indices[axis] = 0; starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis]; ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1; strides_indices[axis] = 1;
reverse_axis[axis] = false;
} }
for (size_t axis = 0; axis < axes.size(); axis++) { for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis]; int axis_index = axes[axis];
...@@ -124,6 +211,7 @@ class StridedSliceKernel : public framework::OpKernel<T> { ...@@ -124,6 +211,7 @@ class StridedSliceKernel : public framework::OpKernel<T> {
framework::Tensor tmp; framework::Tensor tmp;
tmp.mutable_data<T>(out_dims, context.GetPlace()); tmp.mutable_data<T>(out_dims, context.GetPlace());
out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto in_t = auto in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From( framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
...@@ -189,6 +277,34 @@ class StridedSliceGradKernel : public framework::OpKernel<T> { ...@@ -189,6 +277,34 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
auto strides = context.Attr<std::vector<int>>("strides"); auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes"); auto axes = context.Attr<std::vector<int>>("axes");
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_strides_tensor =
context.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = get_new_data_from_tensorlist(list_new_strides_tensor);
} else if (context.HasInput("StridesTensor")) {
auto* strides_tensor = context.Input<framework::Tensor>("StridesTensor");
strides = get_new_data_from_tensor(strides_tensor);
}
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
......
...@@ -11396,57 +11396,148 @@ def strided_slice(input, axes, starts, ends, strides): ...@@ -11396,57 +11396,148 @@ def strided_slice(input, axes, starts, ends, strides):
axes = [0, 1] axes = [0, 1]
starts = [1, 0] starts = [1, 0]
ends = [2, 3] ends = [2, 3]
strides = [1, 1] strides=[1, 1]
Then: Then:
result = [ [5, 6, 7] ] result = [ [5, 6, 7], ]
Case2: Case2:
Given: Given:
data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] data = [ [1, 2, 3, 4], [5, 6, 7, 8], ]
axes = [0, 1] axes = [0, 1]
starts = [0, -1] starts = [0, 1]
ends = [-1, 0] ends = [-1, 1000]
strides = [1, -1] strides = [1, 3]
Then: Then:
result = [ [4, 3, 2] ] result = [ [2], ]
Atrgs: Args:
input (Varibale): the input variable. input (Variable): ${input_comment}.
axes(List):axis we need to slice axes (List): ${axes_comment}
starts (List): the start index in axis starts (List|Variable): ${starts_comment}
ends (List): the end index in axis ends (List|Variable): ${ends_comment}
strides (List): the stride length when we do slice operation
Returns Returns:
out(Variable): the result by strided_slice Op out (Variable): ${out_comment}
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
starts = [1, 0, 2]
ends = [3, 3, 4]
axes = [0, 1, 2]
strides= [1, 1, 1]
input = fluid.layers.data( input = fluid.layers.data(
name="input", shape=[3, 4, 5, 6], dtype='float32') name="input", shape=[3, 4, 5, 6], dtype='float32')
out = fluid.layers.strided_slice(input, axes=axes, starts=starts, ends=ends, strides=strides) # example 1:
# attr starts is a list which doesn't contain tensor Variable.
axes = [0, 1, 2]
starts = [-3, 0, 2]
ends = [3, 2, 4]
strides=[1, 1, 1]
sliced_1 = fluid.layers.strided_slice(input, axes=axes, starts=starts, ends=ends, strides=strides)
# example 2:
# attr starts is a list which contain tensor Variable.
minus_3 = fluid.layers.fill_constant([1], "int32", -3)
sliced_2 = fluid.layers.strided_slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends, strides=strides)
""" """
if not isinstance(starts, (list, tuple, Variable)):
raise ValueError(
"Input starts must be an Variable, python list or tuple.")
if not isinstance(ends, (list, tuple, Variable)):
raise ValueError(
"Input ends must be an Variable, python list or tuple.")
if not isinstance(strides, (list, tuple, Variable)):
raise ValueError(
"Input strides must be an Variable, python list or tuple.")
helper = LayerHelper('strided_slice', **locals()) helper = LayerHelper('strided_slice', **locals())
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input'))
helper.append_op( def contain_var(one_list):
type='strided_slice', for ele in one_list:
inputs={'Input': input}, if isinstance(ele, Variable):
outputs={'Out': out}, return True
attrs={ return False
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
inputs = {'Input': input}
attrs = {'axes': axes}
infer_flags = list(1 for i in range(len(axes)))
if in_dygraph_mode():
inputs = {'Input': input}
attrs = {
'axes': axes, 'axes': axes,
'starts': starts, 'starts': starts,
'ends': ends, 'ends': ends,
'strides': strides 'strides': strides,
}) 'infer_flags': infer_flags
}
else:
# starts
if isinstance(starts, Variable):
starts.stop_gradient = True
inputs['StartsTensor'] = starts
elif isinstance(starts, (list, tuple)):
attrs['starts'] = []
if not contain_var(starts):
attrs['starts'] = starts
else:
inputs['StartsTensorList'] = get_new_list_tensor(starts)
for i, dim in enumerate(starts):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
# ends
if isinstance(ends, Variable):
ends.stop_gradient = True
inputs['EndsTensor'] = ends
elif isinstance(ends, (list, tuple)):
attrs['ends'] = []
if not contain_var(ends):
attrs['ends'] = ends
else:
inputs['EndsTensorList'] = get_new_list_tensor(ends)
for i, dim in enumerate(ends):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
# strides
if isinstance(strides, Variable):
strides.stop_gradient = True
inputs['StridesTensor'] = strides
elif isinstance(strides, (list, tuple)):
attrs['strides'] = []
if not contain_var(strides):
attrs['strides'] = strides
else:
inputs['StridesTensorList'] = get_new_list_tensor(strides)
for i, dim in enumerate(strides):
if isinstance(dim, Variable):
attrs['strides'].append(-1)
infer_flags[i] = -1
else:
attrs['strides'].append(dim)
attrs['infer_flags'] = infer_flags
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input'))
helper.append_op(
type='strided_slice', inputs=inputs, attrs=attrs, outputs={'Out': out})
return out return out
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from op_test import OpTest from op_test import OpTest
import numpy as np import numpy as np
import unittest import unittest
import paddle.fluid as fluid
def strided_slice_native_forward(input, axes, starts, ends, strides): def strided_slice_native_forward(input, axes, starts, ends, strides):
...@@ -63,7 +64,8 @@ class TestStrideSliceOp(OpTest): ...@@ -63,7 +64,8 @@ class TestStrideSliceOp(OpTest):
'axes': self.axes, 'axes': self.axes,
'starts': self.starts, 'starts': self.starts,
'ends': self.ends, 'ends': self.ends,
'strides': self.strides 'strides': self.strides,
'infer_flags': self.infer_flags
} }
def test_check_output(self): def test_check_output(self):
...@@ -78,6 +80,7 @@ class TestStrideSliceOp(OpTest): ...@@ -78,6 +80,7 @@ class TestStrideSliceOp(OpTest):
self.starts = [-4] self.starts = [-4]
self.ends = [-3] self.ends = [-3]
self.strides = [1] self.strides = [1]
self.infer_flags = [1]
class TestStrideSliceOp1(TestStrideSliceOp): class TestStrideSliceOp1(TestStrideSliceOp):
...@@ -87,6 +90,7 @@ class TestStrideSliceOp1(TestStrideSliceOp): ...@@ -87,6 +90,7 @@ class TestStrideSliceOp1(TestStrideSliceOp):
self.starts = [3] self.starts = [3]
self.ends = [8] self.ends = [8]
self.strides = [1] self.strides = [1]
self.infer_flags = [1]
class TestStrideSliceOp2(TestStrideSliceOp): class TestStrideSliceOp2(TestStrideSliceOp):
...@@ -96,6 +100,7 @@ class TestStrideSliceOp2(TestStrideSliceOp): ...@@ -96,6 +100,7 @@ class TestStrideSliceOp2(TestStrideSliceOp):
self.starts = [5] self.starts = [5]
self.ends = [0] self.ends = [0]
self.strides = [-1] self.strides = [-1]
self.infer_flags = [1]
class TestStrideSliceOp3(TestStrideSliceOp): class TestStrideSliceOp3(TestStrideSliceOp):
...@@ -105,6 +110,7 @@ class TestStrideSliceOp3(TestStrideSliceOp): ...@@ -105,6 +110,7 @@ class TestStrideSliceOp3(TestStrideSliceOp):
self.starts = [-1] self.starts = [-1]
self.ends = [-3] self.ends = [-3]
self.strides = [-1] self.strides = [-1]
self.infer_flags = [1]
class TestStrideSliceOp4(TestStrideSliceOp): class TestStrideSliceOp4(TestStrideSliceOp):
...@@ -114,6 +120,7 @@ class TestStrideSliceOp4(TestStrideSliceOp): ...@@ -114,6 +120,7 @@ class TestStrideSliceOp4(TestStrideSliceOp):
self.starts = [0, -1, 0] self.starts = [0, -1, 0]
self.ends = [2, -3, 5] self.ends = [2, -3, 5]
self.strides = [1, -1, 1] self.strides = [1, -1, 1]
self.infer_flags = [1, 1, 1]
class TestStrideSliceOp5(TestStrideSliceOp): class TestStrideSliceOp5(TestStrideSliceOp):
...@@ -123,6 +130,7 @@ class TestStrideSliceOp5(TestStrideSliceOp): ...@@ -123,6 +130,7 @@ class TestStrideSliceOp5(TestStrideSliceOp):
self.starts = [1, 0, 0] self.starts = [1, 0, 0]
self.ends = [2, 1, 3] self.ends = [2, 1, 3]
self.strides = [1, 1, 1] self.strides = [1, 1, 1]
self.infer_flags = [1, 1, 1]
class TestStrideSliceOp6(TestStrideSliceOp): class TestStrideSliceOp6(TestStrideSliceOp):
...@@ -132,6 +140,7 @@ class TestStrideSliceOp6(TestStrideSliceOp): ...@@ -132,6 +140,7 @@ class TestStrideSliceOp6(TestStrideSliceOp):
self.starts = [1, -1, 0] self.starts = [1, -1, 0]
self.ends = [2, -3, 3] self.ends = [2, -3, 3]
self.strides = [1, -1, 1] self.strides = [1, -1, 1]
self.infer_flags = [1, 1, 1]
class TestStrideSliceOp7(TestStrideSliceOp): class TestStrideSliceOp7(TestStrideSliceOp):
...@@ -141,6 +150,7 @@ class TestStrideSliceOp7(TestStrideSliceOp): ...@@ -141,6 +150,7 @@ class TestStrideSliceOp7(TestStrideSliceOp):
self.starts = [1, 0, 0] self.starts = [1, 0, 0]
self.ends = [2, 2, 3] self.ends = [2, 2, 3]
self.strides = [1, 1, 1] self.strides = [1, 1, 1]
self.infer_flags = [1, 1, 1]
class TestStrideSliceOp8(TestStrideSliceOp): class TestStrideSliceOp8(TestStrideSliceOp):
...@@ -150,6 +160,7 @@ class TestStrideSliceOp8(TestStrideSliceOp): ...@@ -150,6 +160,7 @@ class TestStrideSliceOp8(TestStrideSliceOp):
self.starts = [1] self.starts = [1]
self.ends = [2] self.ends = [2]
self.strides = [1] self.strides = [1]
self.infer_flags = [1]
class TestStrideSliceOp9(TestStrideSliceOp): class TestStrideSliceOp9(TestStrideSliceOp):
...@@ -159,6 +170,7 @@ class TestStrideSliceOp9(TestStrideSliceOp): ...@@ -159,6 +170,7 @@ class TestStrideSliceOp9(TestStrideSliceOp):
self.starts = [-1] self.starts = [-1]
self.ends = [-2] self.ends = [-2]
self.strides = [-1] self.strides = [-1]
self.infer_flags = [1]
class TestStrideSliceOp10(TestStrideSliceOp): class TestStrideSliceOp10(TestStrideSliceOp):
...@@ -168,6 +180,7 @@ class TestStrideSliceOp10(TestStrideSliceOp): ...@@ -168,6 +180,7 @@ class TestStrideSliceOp10(TestStrideSliceOp):
self.starts = [1, 0] self.starts = [1, 0]
self.ends = [2, 2] self.ends = [2, 2]
self.strides = [1, 1] self.strides = [1, 1]
self.infer_flags = [1, 1]
class TestStrideSliceOp11(TestStrideSliceOp): class TestStrideSliceOp11(TestStrideSliceOp):
...@@ -177,6 +190,7 @@ class TestStrideSliceOp11(TestStrideSliceOp): ...@@ -177,6 +190,7 @@ class TestStrideSliceOp11(TestStrideSliceOp):
self.starts = [1, 0, 0, 0] self.starts = [1, 0, 0, 0]
self.ends = [2, 2, 3, 4] self.ends = [2, 2, 3, 4]
self.strides = [1, 1, 1, 2] self.strides = [1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1]
class TestStrideSliceOp12(TestStrideSliceOp): class TestStrideSliceOp12(TestStrideSliceOp):
...@@ -186,6 +200,7 @@ class TestStrideSliceOp12(TestStrideSliceOp): ...@@ -186,6 +200,7 @@ class TestStrideSliceOp12(TestStrideSliceOp):
self.starts = [1, 0, 0, 0, 0] self.starts = [1, 0, 0, 0, 0]
self.ends = [2, 2, 3, 4, 4] self.ends = [2, 2, 3, 4, 4]
self.strides = [1, 1, 1, 1, 1] self.strides = [1, 1, 1, 1, 1]
self.infer_flags = [1, 1, 1, 1]
class TestStrideSliceOp13(TestStrideSliceOp): class TestStrideSliceOp13(TestStrideSliceOp):
...@@ -195,6 +210,295 @@ class TestStrideSliceOp13(TestStrideSliceOp): ...@@ -195,6 +210,295 @@ class TestStrideSliceOp13(TestStrideSliceOp):
self.starts = [1, 0, 0, 0, 1, 2] self.starts = [1, 0, 0, 0, 1, 2]
self.ends = [2, 2, 3, 1, 2, 8] self.ends = [2, 2, 3, 1, 2, 8]
self.strides = [1, 1, 1, 1, 1, 2] self.strides = [1, 1, 1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1, 1]
class TestStridedSliceOp_starts_ListTensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
starts_tensor = []
for index, ele in enumerate(self.starts):
starts_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts_infer,
'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [1, -1, 1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
self.starts_infer = [1, 10, 2]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestStridedSliceOp_ends_ListTensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
ends_tensor = []
for index, ele in enumerate(self.ends):
ends_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {'Input': self.input, 'EndsTensorList': ends_tensor}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends_infer,
'strides': self.strides,
'infer_flags': self.infer_flags
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 0]
self.ends = [3, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 2]
self.infer_flags = [1, -1, 1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
self.ends_infer = [3, 1, 4]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestStridedSliceOp_starts_Tensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32")
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
#'starts': self.starts,
'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestStridedSliceOp_ends_Tensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
self.inputs = {
'Input': self.input,
"EndsTensor": np.array(
self.ends, dtype="int32")
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
#'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestStridedSliceOp_listTensor_Tensor(OpTest):
def setUp(self):
self.config()
ends_tensor = []
for index, ele in enumerate(self.ends):
ends_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.op_type = "strided_slice"
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32"),
"EndsTensorList": ends_tensor
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
#'starts': self.starts,
#'ends': self.ends,
'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2]
self.ends = [2, 3, 4]
self.axes = [0, 1, 2]
self.strides = [1, 1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
class TestStridedSliceOp_strides_Tensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
self.config()
self.inputs = {
'Input': self.input,
"StridesTensor": np.array(
self.strides, dtype="int32")
}
self.outputs = {'Out': self.output}
self.attrs = {
'axes': self.axes,
'starts': self.starts,
'ends': self.ends,
#'strides': self.strides,
'infer_flags': self.infer_flags,
}
def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, -1, 2]
self.ends = [2, 0, 4]
self.axes = [0, 1, 2]
self.strides = [1, -1, 1]
self.infer_flags = [-1, -1, -1]
self.output = strided_slice_native_forward(
self.input, self.axes, self.starts, self.ends, self.strides)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['Input'], 'Out', max_relative_error=0.006)
# Test python API
class TestSliceAPI(OpTest):
def test_1(self):
input = np.random.random([3, 4, 5, 6]).astype("float32")
minus_1 = fluid.layers.fill_constant([1], "int32", -1)
minus_3 = fluid.layers.fill_constant([1], "int32", -3)
starts = fluid.layers.data(
name='starts', shape=[3], append_batch_size=False)
ends = fluid.layers.data(
name='ends', shape=[3], append_batch_size=False)
strides = fluid.layers.data(
name='strides', shape=[3], append_batch_size=False)
x = fluid.layers.data(
name="x",
shape=[3, 4, 5, 6],
append_batch_size=False,
dtype="float32")
out_1 = fluid.layers.strided_slice(
x,
axes=[0, 1, 2],
starts=[-3, 0, 2],
ends=[3, 100, -1],
strides=[1, 1, 1])
out_2 = fluid.layers.strided_slice(
x,
axes=[0, 1, 3],
starts=[minus_3, 0, 2],
ends=[3, 100, -1],
strides=[1, 1, 1])
out_3 = fluid.layers.strided_slice(
x,
axes=[0, 1, 3],
starts=[minus_3, 0, 2],
ends=[3, 100, minus_1],
strides=[1, 1, 1])
out_4 = fluid.layers.strided_slice(
x, axes=[0, 1, 2], starts=starts, ends=ends, strides=strides)
out_5 = x[-3:3, 0:100, 2:-1]
out_6 = x[minus_3:3, 0:100, :, 2:-1]
out_7 = x[minus_1, 0:100, :, 2:minus_1]
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run(
fluid.default_main_program(),
feed={
"x": input,
'starts': np.array([-3, 0, 2]).astype("int32"),
'ends': np.array([3, 100, -1]).astype("int32"),
'strides': np.array([1, 1, 1]).astype("int32")
},
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7])
assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_5, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_6, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_7, input[-1, 0:100, :, 2:-1])
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册