未验证 提交 0687bcd6 编写于 作者: W wangchaochaohu 提交者: GitHub

Refine getitem of Variable (#20729)

* add support for __get_item__ of Variable test=develop
上级 72d1d72c
...@@ -42,6 +42,7 @@ class StridedSliceOp : public framework::OperatorWithKernel { ...@@ -42,6 +42,7 @@ class StridedSliceOp : public framework::OperatorWithKernel {
auto strides = ctx->Attrs().Get<std::vector<int>>("strides"); auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
auto axes = ctx->Attrs().Get<std::vector<int>>("axes"); auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags"); auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
auto starts_size = starts.size(); auto starts_size = starts.size();
auto ends_size = ends.size(); auto ends_size = ends.size();
...@@ -90,10 +91,32 @@ class StridedSliceOp : public framework::OperatorWithKernel { ...@@ -90,10 +91,32 @@ class StridedSliceOp : public framework::OperatorWithKernel {
std::vector<int> out_dims_vector(in_dims.size(), -1); std::vector<int> out_dims_vector(in_dims.size(), -1);
if (!tensor_input) { if (!tensor_input) {
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
out_dims_vector.data(), axes.size(), true); decrease_axis, out_dims_vector.data(), axes.size(),
true);
} }
framework::DDim out_dims(framework::make_ddim(out_dims_vector)); framework::DDim out_dims(framework::make_ddim(out_dims_vector));
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (ctx->IsRuntime() && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
"decrease dim should be 1");
}
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims = framework::make_ddim(new_out_shape);
}
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("Input", /*->*/ "Out"); ctx->ShareLoD("Input", /*->*/ "Out");
} }
...@@ -177,6 +200,8 @@ class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -177,6 +200,8 @@ class StridedSliceOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"infer_flags", "(list<int>) Flags of inferring dims in attributes.") "infer_flags", "(list<int>) Flags of inferring dims in attributes.")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<int>>("decrease_axis", "(list<int>) decrease_axis")
.SetDefault({});
AddComment(R"DOC( AddComment(R"DOC(
Strided Slice Operator. Strided Slice Operator.
Instead of calling this op directly most users will want to use the Instead of calling this op directly most users will want to use the
...@@ -212,10 +237,12 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { ...@@ -212,10 +237,12 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor, const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "StartsTensor" || var_name == "EndsTensor") { if (var_name == "StartsTensor" || var_name == "EndsTensor" ||
var_name == "StridesTensor") {
return expected_kernel_type; return expected_kernel_type;
} }
if (var_name == "StartsTensorList" || var_name == "EndsTensorList") { if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
var_name == "StridesTensorList") {
return expected_kernel_type; return expected_kernel_type;
} }
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
......
...@@ -27,22 +27,34 @@ static void StridedSliceOutDims( ...@@ -27,22 +27,34 @@ static void StridedSliceOutDims(
const std::vector<int>& starts, const std::vector<int>& ends, const std::vector<int>& starts, const std::vector<int>& ends,
const std::vector<int>& strides, const std::vector<int>& axes, const std::vector<int>& strides, const std::vector<int>& axes,
const std::vector<int>& infer_flags, const framework::DDim in_dims, const std::vector<int>& infer_flags, const framework::DDim in_dims,
int* out_dims_vector, const size_t size, bool infer_shape) { const std::vector<int>& decrease_axis, int* out_dims_vector,
const size_t size, bool infer_shape) {
for (int i = 0; i < in_dims.size(); i++) { for (int i = 0; i < in_dims.size(); i++) {
out_dims_vector[i] = in_dims[i]; out_dims_vector[i] = in_dims[i];
} }
int stride_index, start_index, end_index; int stride_index, start_index, end_index;
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
int axes_index = axes[i]; int axes_index = axes[i];
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
bool decrease_axis_affect = false;
if (start_index == -1 && end_index == 0 && infer_flags[i] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
}
}
if (decrease_axis_affect) {
out_dims_vector[axes_index] = 1;
continue;
}
if (infer_shape && infer_flags[i] == -1) { if (infer_shape && infer_flags[i] == -1) {
out_dims_vector[axes_index] = -1; out_dims_vector[axes_index] = -1;
continue; continue;
} }
PADDLE_ENFORCE_NE(strides[i], 0, "stride must not to be zero"); PADDLE_ENFORCE_NE(stride_index, 0, "stride must not to be zero");
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
int axis_size = in_dims[axes_index]; int axis_size = in_dims[axes_index];
if (axis_size < 0) { if (axis_size < 0) {
continue; continue;
...@@ -77,6 +89,8 @@ static void StridedSliceOutDims( ...@@ -77,6 +89,8 @@ static void StridedSliceOutDims(
static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes, static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
int* reverse_axis, const framework::DDim dims, int* reverse_axis, const framework::DDim dims,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
const size_t size) { const size_t size) {
for (size_t axis = 0; axis < size; axis++) { for (size_t axis = 0; axis < size; axis++) {
int axis_size = dims[axes[axis]]; int axis_size = dims[axes[axis]];
...@@ -86,6 +100,15 @@ static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes, ...@@ -86,6 +100,15 @@ static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
ends[axis_index] = 1; ends[axis_index] = 1;
strides[axis_index] = 1; strides[axis_index] = 1;
} }
bool decrease_axis_affect = false;
if (starts[axis_index] == -1 && ends[axis_index] == 0 &&
infer_flags[axis_index] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(),
axes[axis_index]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
}
}
// stride must not be zero // stride must not be zero
if (starts[axis_index] < 0) { if (starts[axis_index] < 0) {
starts[axis_index] = starts[axis_index] + axis_size; starts[axis_index] = starts[axis_index] + axis_size;
...@@ -94,6 +117,13 @@ static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes, ...@@ -94,6 +117,13 @@ static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
if (ends[axis_index] < 0) { if (ends[axis_index] < 0) {
ends[axis_index] = ends[axis_index] + axis_size; ends[axis_index] = ends[axis_index] + axis_size;
} }
if (decrease_axis_affect) {
if (strides[axis_index] < 0) {
ends[axis_index] = starts[axis_index] - 1;
} else {
ends[axis_index] = starts[axis_index] + 1;
}
}
if (strides[axis_index] < 0) { if (strides[axis_index] < 0) {
reverse_axis[axis_index] = 1; reverse_axis[axis_index] = 1;
strides[axis_index] = -strides[axis_index]; strides[axis_index] = -strides[axis_index];
...@@ -151,6 +181,7 @@ class StridedSliceKernel : public framework::OpKernel<T> { ...@@ -151,6 +181,7 @@ class StridedSliceKernel : public framework::OpKernel<T> {
auto strides = context.Attr<std::vector<int>>("strides"); auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes"); auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags"); auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>(); auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
...@@ -187,12 +218,14 @@ class StridedSliceKernel : public framework::OpKernel<T> { ...@@ -187,12 +218,14 @@ class StridedSliceKernel : public framework::OpKernel<T> {
std::vector<int> out_dims_vector(in_dims.size(), -1); std::vector<int> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
out_dims_vector.data(), axes.size(), false); decrease_axis, out_dims_vector.data(), axes.size(),
false);
framework::DDim out_dims(framework::make_ddim(out_dims_vector)); framework::DDim out_dims(framework::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts.size(), 0); std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), in_dims, starts.size()); reverse_vector.data(), in_dims, infer_flags,
decrease_axis, starts.size());
for (size_t axis = 0; axis < D; axis++) { for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0; starts_indices[axis] = 0;
...@@ -209,8 +242,28 @@ class StridedSliceKernel : public framework::OpKernel<T> { ...@@ -209,8 +242,28 @@ class StridedSliceKernel : public framework::OpKernel<T> {
} }
framework::Tensor tmp; framework::Tensor tmp;
tmp.mutable_data<T>(out_dims, context.GetPlace());
auto out_dims_origin = out_dims;
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
"decrease dim should be 1");
out_dims_origin[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims_origin.size(); ++i) {
if (out_dims_origin[i] != 0) {
new_out_shape.push_back(out_dims_origin[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims_origin = framework::make_ddim(new_out_shape);
}
tmp.mutable_data<T>(out_dims, context.GetPlace());
out->Resize(out_dims); out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto in_t = auto in_t =
...@@ -225,6 +278,10 @@ class StridedSliceKernel : public framework::OpKernel<T> { ...@@ -225,6 +278,10 @@ class StridedSliceKernel : public framework::OpKernel<T> {
tmp_t.device(place) = tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices); in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
out_t.device(place) = tmp_t.reverse(reverse_axis); out_t.device(place) = tmp_t.reverse(reverse_axis);
if (decrease_axis.size() > 0) {
out->Resize(out_dims_origin);
}
} }
}; };
...@@ -276,6 +333,8 @@ class StridedSliceGradKernel : public framework::OpKernel<T> { ...@@ -276,6 +333,8 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
auto ends = context.Attr<std::vector<int>>("ends"); auto ends = context.Attr<std::vector<int>>("ends");
auto strides = context.Attr<std::vector<int>>("strides"); auto strides = context.Attr<std::vector<int>>("strides");
auto axes = context.Attr<std::vector<int>>("axes"); auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto list_new_ends_tensor = auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList"); context.MultiInput<framework::Tensor>("EndsTensorList");
...@@ -313,7 +372,8 @@ class StridedSliceGradKernel : public framework::OpKernel<T> { ...@@ -313,7 +372,8 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
std::vector<int> reverse_vector(starts.size(), 0); std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), out_dims, starts.size()); reverse_vector.data(), out_dims, infer_flags,
decrease_axis, starts.size());
for (size_t axis = 0; axis < D; axis++) { for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0; starts_indices[axis] = 0;
......
...@@ -1414,9 +1414,11 @@ class Variable(object): ...@@ -1414,9 +1414,11 @@ class Variable(object):
slice_axis = [] slice_axis = []
slice_start = [] slice_start = []
slice_end = [] slice_end = []
slice_step = []
use_strided_slice = False
reverse_axis = [] reverse_axis = []
def fill_constant(shape, dtype, value, force_cpu=False, out=None): def fill_constant(shape, value, force_cpu=False, out=None):
self.block.append_op( self.block.append_op(
type='fill_constant', type='fill_constant',
inputs={}, inputs={},
...@@ -1425,7 +1427,7 @@ class Variable(object): ...@@ -1425,7 +1427,7 @@ class Variable(object):
'shape': shape, 'shape': shape,
'dtype': out.dtype, 'dtype': out.dtype,
'value': float(value), 'value': float(value),
'force_cpu': force_cpu or force_init_on_cpu() 'force_cpu': force_cpu
}, },
stop_gradient=True) stop_gradient=True)
out.stop_gradient = True out.stop_gradient = True
...@@ -1435,15 +1437,17 @@ class Variable(object): ...@@ -1435,15 +1437,17 @@ class Variable(object):
if isinstance(slice_item, slice): if isinstance(slice_item, slice):
start = slice_item.start start = slice_item.start
end = slice_item.stop end = slice_item.stop
step = slice_item.step if slice_item.step else 1 step = slice_item.step
assert (step == 1 or step == -1) if start is None and end is None and step is None:
continue
if step == -1: if step is None:
reverse_axis.append(dim) step = 1
assert (start is None and end is None)
if start is None and end is None: if start is None and end is None:
assert (step == -1)
reverse_axis.append(dim)
continue continue
if start is None: if start is None:
...@@ -1452,16 +1456,21 @@ class Variable(object): ...@@ -1452,16 +1456,21 @@ class Variable(object):
if end is None: if end is None:
end = 10000000 end = 10000000
if step != 1:
use_strided_slice = True
slice_axis.append(dim) slice_axis.append(dim)
slice_start.append(start) slice_start.append(start)
slice_end.append(end) slice_end.append(end)
slice_step.append(step)
else: else:
decrease_axis.append(dim) decrease_axis.append(dim)
slice_axis.append(dim) slice_axis.append(dim)
slice_start.append(slice_item) slice_start.append(slice_item)
slice_step.append(1)
if isinstance(slice_item, Variable): if isinstance(slice_item, Variable):
temp_1 = self.block.create_var(dtype='int32') temp_1 = self.block.create_var(dtype='int32')
fill_constant([1], 'int32', 1, force_cpu=True, out=temp_1) fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = self.block.create_var(dtype='int32') temp_end = self.block.create_var(dtype='int32')
self.block.append_op( self.block.append_op(
type='elementwise_add', type='elementwise_add',
...@@ -1489,8 +1498,7 @@ class Variable(object): ...@@ -1489,8 +1498,7 @@ class Variable(object):
else: else:
assert (isinstance(dim, int)) assert (isinstance(dim, int))
temp_out = self.block.create_var(dtype='int32') temp_out = self.block.create_var(dtype='int32')
fill_constant( fill_constant([1], dim, force_cpu=True, out=temp_out)
[1], 'int32', dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out) new_list_tensor.append(temp_out)
return new_list_tensor return new_list_tensor
...@@ -1501,8 +1509,9 @@ class Variable(object): ...@@ -1501,8 +1509,9 @@ class Variable(object):
'ends': [], 'ends': [],
'decrease_axis': decrease_axis 'decrease_axis': decrease_axis
} }
if (use_strided_slice == True):
attrs['strides'] = []
infer_flags = list(1 for i in range(len(slice_axis))) infer_flags = list(1 for i in range(len(slice_axis)))
# starts # starts
if not contain_var(slice_start): if not contain_var(slice_start):
attrs['starts'] = slice_start attrs['starts'] = slice_start
...@@ -1525,11 +1534,23 @@ class Variable(object): ...@@ -1525,11 +1534,23 @@ class Variable(object):
infer_flags[i] = -1 infer_flags[i] = -1
else: else:
attrs['ends'].append(dim) attrs['ends'].append(dim)
# strides
if use_strided_slice == True:
if not contain_var(slice_step):
attrs['strides'] = slice_step
else:
inputs['StridesTensorList'] = get_new_list_tensor(slice_step)
for i, dim in enumerate(slice_step):
if isinstance(dim, Variable):
attrs['strides'].append(-1)
infer_flags[i] = -1
else:
attrs['strides'].append(dim)
# infer_flags # infer_flags
attrs['infer_flags'] = infer_flags attrs['infer_flags'] = infer_flags
out = self out = self
if len(slice_axis) > 0: if use_strided_slice == False and len(slice_axis) > 0:
# append slice_op here # append slice_op here
slice_out_var = self.block.create_var( slice_out_var = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name + name=unique_name.generate_with_ignorable_key(self.name +
...@@ -1543,6 +1564,18 @@ class Variable(object): ...@@ -1543,6 +1564,18 @@ class Variable(object):
attrs=attrs) attrs=attrs)
out = slice_out_var out = slice_out_var
elif use_strided_slice == True and len(slice_axis) > 0:
strided_slice_out_var = self.block.create_var(
name=unique_name.generate_with_ignorable_key(self.name +
"_strided_slice"),
dtype=self.dtype)
self.block.append_op(
type="strided_slice",
inputs=inputs,
outputs={'Out': [strided_slice_out_var]},
attrs=attrs)
out = strided_slice_out_var
if len(reverse_axis) > 0: if len(reverse_axis) > 0:
reverse_out_var = self.block.create_var( reverse_out_var = self.block.create_var(
......
...@@ -438,7 +438,7 @@ class TestStridedSliceOp_strides_Tensor(OpTest): ...@@ -438,7 +438,7 @@ class TestStridedSliceOp_strides_Tensor(OpTest):
# Test python API # Test python API
class TestSliceAPI(OpTest): class TestStridedSliceAPI(OpTest):
def test_1(self): def test_1(self):
input = np.random.random([3, 4, 5, 6]).astype("float32") input = np.random.random([3, 4, 5, 6]).astype("float32")
minus_1 = fluid.layers.fill_constant([1], "int32", -1) minus_1 = fluid.layers.fill_constant([1], "int32", -1)
...@@ -455,7 +455,6 @@ class TestSliceAPI(OpTest): ...@@ -455,7 +455,6 @@ class TestSliceAPI(OpTest):
shape=[3, 4, 5, 6], shape=[3, 4, 5, 6],
append_batch_size=False, append_batch_size=False,
dtype="float32") dtype="float32")
out_1 = fluid.layers.strided_slice( out_1 = fluid.layers.strided_slice(
x, x,
axes=[0, 1, 2], axes=[0, 1, 2],
...@@ -477,9 +476,9 @@ class TestSliceAPI(OpTest): ...@@ -477,9 +476,9 @@ class TestSliceAPI(OpTest):
out_4 = fluid.layers.strided_slice( out_4 = fluid.layers.strided_slice(
x, axes=[0, 1, 2], starts=starts, ends=ends, strides=strides) x, axes=[0, 1, 2], starts=starts, ends=ends, strides=strides)
out_5 = x[-3:3, 0:100, 2:-1] out_5 = x[-3:3, 0:100:2, -1:2:-1]
out_6 = x[minus_3:3, 0:100, :, 2:-1] out_6 = x[minus_3:3:1, 0:100:2, :, minus_1:2:minus_1]
out_7 = x[minus_1, 0:100, :, 2:minus_1] out_7 = x[minus_1, 0:100:2, :, -1:2:-1]
exe = fluid.Executor(place=fluid.CPUPlace()) exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run(
...@@ -491,14 +490,13 @@ class TestSliceAPI(OpTest): ...@@ -491,14 +490,13 @@ class TestSliceAPI(OpTest):
'strides': np.array([1, 1, 1]).astype("int32") 'strides': np.array([1, 1, 1]).astype("int32")
}, },
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7]) fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7])
assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :]) assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1]) assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1]) assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1])
assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :]) assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :])
assert np.array_equal(res_5, input[-3:3, 0:100, 2:-1, :]) assert np.array_equal(res_5, input[-3:3, 0:100:2, -1:2:-1, :])
assert np.array_equal(res_6, input[-3:3, 0:100, :, 2:-1]) assert np.array_equal(res_6, input[-3:3, 0:100:2, :, -1:2:-1])
assert np.array_equal(res_7, input[-1, 0:100, :, 2:-1]) assert np.array_equal(res_7, input[-1, 0:100:2, :, -1:2:-1])
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册