提交 8cee5c3d 编写于 作者: 李寅

Make slice more general

上级 5b12c75f
......@@ -32,12 +32,16 @@ struct StridedSliceFunctor {
int end_mask,
int ellipsis_mask,
int new_axis_mask,
int shrink_axis_mask)
int shrink_axis_mask,
bool is_slice = false)
: begin_mask_(begin_mask),
end_mask_(end_mask),
ellipsis_mask_(ellipsis_mask),
new_axis_mask_(new_axis_mask),
shrink_axis_mask_(shrink_axis_mask) {}
shrink_axis_mask_(shrink_axis_mask),
is_slice_(is_slice),
tmp_strides_tensor_(GetDeviceAllocator(D),
DataTypeToEnum<int32_t>::v()) {}
MaceStatus operator()(const Tensor *input,
const Tensor *begin_indices,
......@@ -49,6 +53,14 @@ struct StridedSliceFunctor {
MACE_CHECK(ellipsis_mask_ == 0 && new_axis_mask_ == 0,
"ellipsis_mask and new_axis_mask are not supported yet.");
if (strides == nullptr) {
tmp_strides_tensor_.Resize({begin_indices->size()});
Tensor::MappingGuard strides_guard(&tmp_strides_tensor_);
int32_t *strides_data = tmp_strides_tensor_.mutable_data<int32_t>();
std::fill(strides_data, strides_data + tmp_strides_tensor_.size(), 1);
strides = &tmp_strides_tensor_;
}
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard begin_indices_guard(begin_indices);
Tensor::MappingGuard end_indices_guard(end_indices);
......@@ -56,6 +68,19 @@ struct StridedSliceFunctor {
const T *input_data = input->data<T>();
const int32_t *begin_indices_data = begin_indices->data<int32_t>();
const int32_t *end_indices_data = end_indices->data<int32_t>();
std::vector<int32_t> slice_end_data;
if (is_slice_) {
// if this op is slice, the end_indices_data is size actually
slice_end_data.resize(end_indices->size());
for (int i = 0; i < slice_end_data.size(); ++i) {
if (end_indices_data[i] == -1) {
slice_end_data[i] = input->dim(i);
} else {
slice_end_data[i] = begin_indices_data[i] + end_indices_data[i];
}
}
end_indices_data = slice_end_data.data();
}
const int32_t *strides_data = strides->data<int32_t>();
std::vector<index_t> output_shape;
......@@ -152,6 +177,8 @@ struct StridedSliceFunctor {
int ellipsis_mask_;
int new_axis_mask_;
int shrink_axis_mask_;
bool is_slice_;
Tensor tmp_strides_tensor_;
};
} // namespace kernels
......
......@@ -30,13 +30,18 @@ class StridedSliceOp : public Operator<D, T> {
OperatorBase::GetOptionalArg<int>("end_mask", 0),
OperatorBase::GetOptionalArg<int>("ellipsis_mask", 0),
OperatorBase::GetOptionalArg<int>("new_axis_mask", 0),
OperatorBase::GetOptionalArg<int>("shrink_axis_mask", 0)) {}
OperatorBase::GetOptionalArg<int>("shrink_axis_mask", 0),
OperatorBase::GetOptionalArg<bool>("slice",
false)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *begin_indices = this->Input(BEGIN);
const Tensor *end_indices = this->Input(END);
const Tensor *strides = this->Input(STRIDES);
const Tensor *strides = nullptr;
if (this->InputSize() > 3) {
strides = this->Input(STRIDES);
}
Tensor *output = this->Output(OUTPUT);
return functor_(input, begin_indices, end_indices, strides, output, future);
......
......@@ -23,32 +23,27 @@ class StridedSliceOpTest : public OpsTestBase {};
namespace {
void TestSlice(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int32_t> &begin_indices,
const std::vector<int32_t> &end_indices,
const std::vector<int32_t> &strides,
const int begin_mask,
const int end_mask,
const int ellipsis_mask,
const int new_axis_mask,
const int shrink_axis_mask,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
void TestStridedSlice(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int32_t> &begin_indices,
const std::vector<int32_t> &end_indices,
const std::vector<int32_t> &strides,
const int begin_mask,
const int end_mask,
const int ellipsis_mask,
const int new_axis_mask,
const int shrink_axis_mask,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>("BeginIndices",
{static_cast<int32_t>(
input_shape.size())},
begin_indices);
net.AddInputFromArray<CPU, int32_t>("EndIndices",
{static_cast<int32_t>(
input_shape.size())},
end_indices);
net.AddInputFromArray<CPU, int32_t>("Strides",
{static_cast<int32_t>(
input_shape.size())},
strides);
net.AddInputFromArray<CPU, int32_t>(
"BeginIndices", {static_cast<int32_t>(input_shape.size())},
begin_indices);
net.AddInputFromArray<CPU, int32_t>(
"EndIndices", {static_cast<int32_t>(input_shape.size())}, end_indices);
net.AddInputFromArray<CPU, int32_t>(
"Strides", {static_cast<int32_t>(input_shape.size())}, strides);
OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("Input")
......@@ -70,47 +65,92 @@ void TestSlice(const std::vector<index_t> &input_shape,
*net.GetOutput("Output"));
}
void TestSlice(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int32_t> &begin_indices,
const std::vector<int32_t> &indices_size,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>(
"BeginIndices", {static_cast<int32_t>(input_shape.size())},
begin_indices);
net.AddInputFromArray<CPU, int32_t>(
"IndicesSize", {static_cast<int32_t>(indices_size.size())}, indices_size);
OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("Input")
.Input("BeginIndices")
.Input("IndicesSize")
.Output("Output")
.AddIntArg("slice", 1)
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
TEST_F(StridedSliceOpTest, TestSliceByFirstAxis) {
TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 0, 0},
{2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 0, {1, 3, 2},
{7, 8, 9, 10, 11, 12});
TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 0, 0},
{2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 1, {3, 2}, {7, 8, 9, 10, 11, 12});
TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2},
{2, 3, 2}, {1, 1, 1}, 6, 6, 0, 0, 0, {1, 3, 2},
{7, 8, 9, 10, 11, 12});
TEST_F(StridedSliceOpTest, TestStridedSliceByFirstAxis) {
TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
{1, 0, 0}, {2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 0, {1, 3, 2},
{7, 8, 9, 10, 11, 12});
TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
{1, 0, 0}, {2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 1, {3, 2},
{7, 8, 9, 10, 11, 12});
TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
{1, 1, 2}, {2, 3, 2}, {1, 1, 1}, 6, 6, 0, 0, 0, {1, 3, 2},
{7, 8, 9, 10, 11, 12});
}
TEST_F(StridedSliceOpTest, TestStridedSliceRank1) {
TestStridedSlice({4}, {1, 2, 3, 4}, {1}, {3}, {1}, 0, 0, 0, 0, 0, {2},
{2, 3});
TestStridedSlice({4}, {1, 2, 3, 4}, {-3}, {3}, {1}, 0, 0, 0, 0, 0, {2},
{2, 3});
TestStridedSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 0, 0, 0, 0, {2},
{3, 2});
TestStridedSlice({4}, {1, 2, 3, 4}, {-1}, {-4}, {-2}, 0, 0, 0, 0, 0, {2},
{4, 2});
TestStridedSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 0, 0, 0, 0, {3},
{4, 3, 2});
TestStridedSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 1, 0, 0, 0, {3},
{3, 2, 1});
TestStridedSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 1, 0, 0, 0, {4},
{4, 3, 2, 1});
TestStridedSlice({4}, {1, 2, 3, 4}, {2}, {4}, {2}, 1, 1, 0, 0, 0, {2},
{1, 3});
TestStridedSlice({4}, {1, 2, 3, 4}, {2}, {3}, {1}, 0, 0, 0, 0, 1, {}, {3});
}
TEST_F(StridedSliceOpTest, TestSliceRank1) {
TestSlice({4}, {1, 2, 3, 4}, {1}, {3}, {1}, 0, 0, 0, 0, 0, {2}, {2, 3});
TestSlice({4}, {1, 2, 3, 4}, {-3}, {3}, {1}, 0, 0, 0, 0, 0, {2}, {2, 3});
TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 0, 0, 0, 0, {2}, {3, 2});
TestSlice({4}, {1, 2, 3, 4}, {-1}, {-4}, {-2}, 0, 0, 0, 0, 0, {2}, {4, 2});
TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 0, 0, 0, 0, {3}, {4, 3, 2});
TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 1, 0, 0, 0, {3}, {3, 2, 1});
TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 1, 0, 0, 0, {4},
{4, 3, 2, 1});
TestSlice({4}, {1, 2, 3, 4}, {2}, {4}, {2}, 1, 1, 0, 0, 0, {2}, {1, 3});
TestSlice({4}, {1, 2, 3, 4}, {2}, {3}, {1}, 0, 0, 0, 0, 1, {}, {3});
TEST_F(StridedSliceOpTest, TestStridedSliceRank2) {
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 1}, 0, 0, 0,
0, 0, {2, 3}, {1, 2, 3, 4, 5, 6});
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1}, {2, 3}, {1, 1}, 0, 0, 0,
0, 0, {1, 2}, {5, 6});
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 2}, 0, 0, 0,
0, 0, {2, 2}, {1, 3, 4, 6});
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 0, 0,
0, 0, 0, {1, 2}, {6, 5});
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 3, 3,
0, 0, 0, {2, 3}, {6, 5, 4, 3, 2, 1});
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 0}, {2, 3}, {1, 1}, 0, 0, 0,
0, 1, {3}, {4, 5, 6});
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {2, 3}, {1, 1}, 0, 0, 0,
0, 3, {}, {6});
}
TEST_F(StridedSliceOpTest, TestSliceRank2) {
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 1}, 0, 0, 0, 0, 0,
{2, 3}, {1, 2, 3, 4, 5, 6});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1}, {2, 3}, {1, 1}, 0, 0, 0, 0, 0,
{1, 2}, {5, 6});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 2}, 0, 0, 0, 0, 0,
{2, 2}, {1, 3, 4, 6});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 0, 0, 0, 0, 0,
{1, 2}, {6, 5});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 3, 3, 0, 0, 0,
{2, 3}, {6, 5, 4, 3, 2, 1});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 0}, {2, 3}, {1, 1}, 0, 0, 0, 0, 1,
{3}, {4, 5, 6});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {2, 3}, {1, 1}, 0, 0, 0, 0, 3,
{}, {6});
TEST_F(StridedSliceOpTest, TestSlice) {
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3},
{1, 2, 3, 4, 5, 6});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 0}, {1, 2}, {1, 2}, {4, 5});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, -1}, {2, 3},
{1, 2, 3, 4, 5, 6});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 1}, {2, -1}, {2, 2}, {2, 3, 5, 6});
}
} // namespace test
......
......@@ -1155,5 +1155,6 @@ class Transformer(base_converter.ConverterInterface):
print("Final ops:")
for op in net.op:
print("%s (%s)" % (op.name, op.type))
print("%s (%s): %s" % (op.name, op.type, [
out_shape.dims for out_shape in op.output_shape]))
return False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册