提交 5eb5eaea，作者：李寅

Add strided slice & stack & shape ops.

上级 fe5e6be0
......@@ -100,8 +100,11 @@ extern void Register_Quantize(OperatorRegistry *op_registry);
extern void Register_Requantize(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_Shape(OperatorRegistry *op_registry);
extern void Register_Slice(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_Stack(OperatorRegistry *op_registry);
extern void Register_StridedSlice(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_SpaceToDepth(OperatorRegistry *op_registry);
extern void Register_Transpose(OperatorRegistry *op_registry);
......@@ -142,8 +145,11 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_Requantize(this);
ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this);
ops::Register_Shape(this);
ops::Register_Slice(this);
ops::Register_Softmax(this);
ops::Register_Stack(this);
ops::Register_StridedSlice(this);
ops::Register_SpaceToBatchND(this);
ops::Register_SpaceToDepth(this);
ops::Register_Transpose(this);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_STACK_H_
#define MACE_KERNELS_STACK_H_
#include <algorithm>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>

#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct StackFunctor {
  explicit StackFunctor(int axis) : axis_(axis) {}

  // Stacks a list of rank-R tensors (all sharing one shape) into a single
  // rank-(R+1) tensor along `axis_`, like numpy.stack / tf.stack.
  //
  // inputs: non-empty list of tensors with identical shapes.
  // output: resized to the common input shape with `inputs.size()`
  //         inserted at the (normalized) axis.
  // future: unused; the CPU implementation runs synchronously.
  MaceStatus operator()(const std::vector<const Tensor *> &inputs,
                        Tensor *output,
                        StatsFuture *future) {
    MACE_UNUSED(future);
    MACE_CHECK(!inputs.empty(), "stack inputs are empty.");
    std::vector<index_t> input_shape = inputs[0]->shape();
    MACE_CHECK(axis_ >= -(inputs[0]->dim_size() + 1) &&
                   axis_ < inputs[0]->dim_size() + 1,
               "axis out of bound.");
    // Normalize a negative axis into a local variable. The previous code
    // mutated axis_ in place, so a second invocation of the same functor
    // with a negative axis resolved against the already-normalized value.
    int axis = axis_;
    if (axis < 0) {
      axis += inputs[0]->dim_size() + 1;
    }
    std::vector<index_t> output_shape = input_shape;
    output_shape.insert(output_shape.begin() + axis, inputs.size());
    MACE_RETURN_IF_ERROR(output->Resize(output_shape));

    // On host, no need to map data.
    T *output_data = output->mutable_data<T>();
    std::vector<const T *> input_data(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
      input_data[i] = inputs[i]->data<T>();
    }

    // Element counts above the stack axis and in each contiguous run
    // below it (both 1 for the degenerate edges).
    index_t high_dim_elem_size =
        std::accumulate(input_shape.begin(), input_shape.begin() + axis, 1,
                        std::multiplies<index_t>());
    index_t low_dim_elem_size =
        std::accumulate(input_shape.begin() + axis, input_shape.end(), 1,
                        std::multiplies<index_t>());
    // For every high-dim step, interleave one low-dim run from each input.
    for (index_t h = 0; h < high_dim_elem_size; ++h) {
      for (size_t i = 0; i < inputs.size(); ++i) {
        memcpy(output_data, input_data[i] + h * low_dim_elem_size,
               sizeof(T) * low_dim_elem_size);
        output_data += low_dim_elem_size;
      }
    }

    return MACE_SUCCESS;
  }

  int axis_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_STACK_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_STRIDED_SLICE_H_
#define MACE_KERNELS_STRIDED_SLICE_H_
#include <algorithm>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct StridedSliceFunctor {
  // CPU implementation of TensorFlow-style strided slice.
  // Each *_mask argument is a bit field: bit d refers to input dimension d.
  StridedSliceFunctor(int begin_mask,
                      int end_mask,
                      int ellipsis_mask,
                      int new_axis_mask,
                      int shrink_axis_mask)
      : begin_mask_(begin_mask),
        end_mask_(end_mask),
        ellipsis_mask_(ellipsis_mask),
        new_axis_mask_(new_axis_mask),
        shrink_axis_mask_(shrink_axis_mask) {}

  // Slices `input` using the per-dimension int32 tensors `begin_indices`,
  // `end_indices` and `strides` (each read with one entry per input
  // dimension) and writes the result into `output`.
  // ellipsis_mask / new_axis_mask are rejected; the non-contiguous path
  // only supports rank-1 and rank-2 inputs.
  MaceStatus operator()(const Tensor *input,
                        const Tensor *begin_indices,
                        const Tensor *end_indices,
                        const Tensor *strides,
                        Tensor *output,
                        StatsFuture *future) {
    MACE_UNUSED(future);
    MACE_CHECK(ellipsis_mask_ == 0 && new_axis_mask_ == 0,
               "ellipsis_mask and new_axis_mask are not supported yet.");
    Tensor::MappingGuard input_guard(input);
    Tensor::MappingGuard begin_indices_guard(begin_indices);
    Tensor::MappingGuard end_indices_guard(end_indices);
    Tensor::MappingGuard strides_guard(strides);
    const T *input_data = input->data<T>();
    const int32_t *begin_indices_data = begin_indices->data<int32_t>();
    const int32_t *end_indices_data = end_indices->data<int32_t>();
    const int32_t *strides_data = strides->data<int32_t>();
    // Resolve begin/end into concrete indices (end may be -1 as the
    // exclusive bound when stepping backwards) and derive the output shape.
    std::vector<index_t> output_shape;
    std::vector<index_t> real_begin_indices(input->dim_size(), 0);
    std::vector<index_t> real_end_indices(input->dim_size(), 0);
    for (index_t d = 0; d < input->dim_size(); ++d) {
      index_t dim_len = input->dim(d);
      if (begin_mask_ & (1 << d)) {
        // Masked begin: start at the edge matching the stride direction.
        real_begin_indices[d] = strides_data[d] > 0 ? 0 : dim_len - 1;
      } else {
        // Wrap negative begin values. NOTE(review): this assumes
        // -dim_len <= begin < dim_len; begin == dim_len wraps to 0 and
        // begin < -dim_len stays negative — confirm callers never pass those.
        real_begin_indices[d] = (begin_indices_data[d] + dim_len) % dim_len;
      }
      if (end_mask_ & (1 << d)) {
        // Masked end: run to the far edge (exclusive).
        real_end_indices[d] = strides_data[d] > 0 ? dim_len : -1;
      } else {
        // Clamp end into [-1, dim_len], wrapping negative values once.
        real_end_indices[d] =
            end_indices_data[d] < -dim_len
                ? -1
                : (end_indices_data[d] < 0
                       ? (end_indices_data[d] + dim_len)
                       : std::min(static_cast<index_t>(end_indices_data[d]),
                                  dim_len));
      }
      // ceil((end - begin) / stride), floored at 0 for empty slices.
      int32_t out_dim_len = std::max(
          0.f, std::ceil((real_end_indices[d] - real_begin_indices[d]) /
                         static_cast<float>(strides_data[d])));
      if (!(shrink_axis_mask_ & (1 << d))) {
        output_shape.push_back(out_dim_len);
      } else {
        // A shrunk axis is dropped from the output and must select exactly
        // one element.
        MACE_CHECK(out_dim_len == 1,
                   "cannot shrink axis that has len > 1, dim(", d, "): [",
                   real_begin_indices[d], ", ", real_end_indices[d], "]");
      }
    }
    // Row-major element strides of the input (innermost dim has stride 1).
    std::vector<index_t> dim_stride(input->dim_size(), 1);
    for (index_t d = input->dim_size() - 2; d >= 0; --d) {
      dim_stride[d] = dim_stride[d + 1] * input->dim(d + 1);
    }
    MACE_RETURN_IF_ERROR(output->Resize(output_shape));
    Tensor::MappingGuard output_guard(output);
    T *output_data = output->mutable_data<T>();
    // Fast path: all strides are 1 and every dimension after the first is
    // taken whole, so the slice is one contiguous memory block.
    bool slice_by_first_axis = true;
    if (strides_data[0] != 1) {
      slice_by_first_axis = false;
    } else {
      for (index_t d = 1; d < input->dim_size(); ++d) {
        if (strides_data[d] != 1 || real_begin_indices[d] != 0 ||
            real_end_indices[d] != input->dim(d)) {
          slice_by_first_axis = false;
          break;
        }
      }
    }
    if (slice_by_first_axis) {
      memcpy(output_data, input_data + real_begin_indices[0] * dim_stride[0],
             sizeof(T) * (real_end_indices[0] - real_begin_indices[0]) *
                 dim_stride[0]);
    } else {
      // General strided path; the loop condition flips with the stride sign.
      if (input->dim_size() == 1) {
        for (index_t i = real_begin_indices[0];
             strides_data[0] > 0 ? i < real_end_indices[0]
                                 : i > real_end_indices[0];
             i += strides_data[0]) {
          *output_data++ = input_data[i];
        }
      } else if (input->dim_size() == 2) {
        for (index_t i = real_begin_indices[0];
             strides_data[0] > 0 ? i < real_end_indices[0]
                                 : i > real_end_indices[0];
             i += strides_data[0]) {
          for (index_t j = real_begin_indices[1];
               strides_data[1] > 0 ? j < real_end_indices[1]
                                   : j > real_end_indices[1];
               j += strides_data[1]) {
            *output_data++ = input_data[i * dim_stride[0] + j];
          }
        }
      } else {
        // Rank > 2 with non-trivial strides is not implemented.
        MACE_NOT_IMPLEMENTED;
      }
    }
    return MACE_SUCCESS;
  }

  int begin_mask_;
  int end_mask_;
  int ellipsis_mask_;
  int new_axis_mask_;
  int shrink_axis_mask_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_STRIDED_SLICE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/shape.h"
namespace mace {
namespace ops {
// Registers the CPU/float implementation of the Shape op.
void Register_Shape(OperatorRegistry *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Shape")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ShapeOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_SHAPE_H_
#define MACE_OPS_SHAPE_H_
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class ShapeOp : public Operator<D, T> {
 public:
  ShapeOp(const OperatorDef &op_def, Workspace *ws)
      : Operator<D, T>(op_def, ws) {}

  // Writes the shape of the input tensor into a rank-1 int32 output
  // (a rank-0 output for a scalar input).
  MaceStatus Run(StatsFuture *future) override {
    MACE_UNUSED(future);  // shape extraction is synchronous
    const Tensor *input = this->Input(INPUT);
    Tensor *output = this->Output(OUTPUT);
    if (input->dim_size() > 0) {
      MACE_RETURN_IF_ERROR(output->Resize({input->dim_size()}));
    } else {
      // Scalar input -> scalar output. Previously the Resize status was
      // silently dropped on this branch; propagate it like the other one.
      MACE_RETURN_IF_ERROR(output->Resize({}));
    }
    Tensor::MappingGuard output_guard(output);
    int32_t *output_data = output->mutable_data<int32_t>();
    for (index_t i = 0; i < input->dim_size(); ++i) {
      // index_t may be wider than 32 bits; the op contract is int32.
      output_data[i] = static_cast<int32_t>(input->dim(i));
    }
    return MACE_SUCCESS;
  }

 private:
  MACE_OP_INPUT_TAGS(INPUT);
  MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_SHAPE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ShapeOpTest : public OpsTestBase {};
namespace {
// Runs a Shape op over a random float tensor of `input_shape` and checks
// that the int32 output equals the shape itself. A scalar input produces a
// scalar expected tensor, which the test net fills with the single value 0.
void TestShapeOp(const std::vector<index_t> &input_shape) {
  OpsTestNet net;
  net.AddRandomInput<CPU, float>("Input", input_shape);

  OpDefBuilder("Shape", "ShapeOpTest")
      .Input("Input")
      .Output("Output")
      .OutputType({DataTypeToEnum<int32_t>::v()})
      .Finalize(net.NewOperatorDef());
  net.RunOp();

  // The expected tensor carries the dims themselves, narrowed to int32.
  const std::vector<int32_t> dims_as_int32(input_shape.begin(),
                                           input_shape.end());
  if (dims_as_int32.empty()) {
    net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {}, {0});
  } else {
    net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {input_shape.size()},
                                        dims_as_int32);
  }
  ExpectTensorNear<int32_t>(*net.GetOutput("ExpectedOutput"),
                            *net.GetOutput("Output"));
}
} // namespace
// Covers rank 3 down to rank 0 (scalar) inputs.
TEST_F(ShapeOpTest, TestShape) {
TestShapeOp({1, 2, 3});
TestShapeOp({2, 3});
TestShapeOp({3});
TestShapeOp({});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/stack.h"
namespace mace {
namespace ops {
// Registers the CPU/float implementation of the Stack op.
void Register_Stack(OperatorRegistry *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Stack")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
StackOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_STACK_H_
#define MACE_OPS_STACK_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/stack.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class StackOp : public Operator<D, T> {
 public:
  // Stacks all of the op's inputs along the "axis" argument (default 0);
  // the actual work is delegated to kernels::StackFunctor.
  StackOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(OperatorBase::GetOptionalArg<int>("axis", 0)) {}

  MaceStatus Run(StatsFuture *future) override {
    // Forward every input straight through to the functor.
    return functor_(this->Inputs(), this->Output(OUTPUT), future);
  }

 private:
  kernels::StackFunctor<D, T> functor_;

 protected:
  MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_STACK_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class StackOpTest : public OpsTestBase {};
namespace {
void TestStack(const std::vector<index_t> &input_shape,
const std::vector<std::vector<float>> &inputs,
int axis,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
OpsTestNet net;
for (int i = 0; i < inputs.size(); ++i) {
net.AddInputFromArray<CPU, float>(MakeString("Input", i), input_shape,
inputs[i]);
}
auto op_builder = OpDefBuilder("Stack", "StackOpTest")
.Output("Output")
.AddIntArg("axis", axis);
for (int i = 0; i < inputs.size(); ++i) {
op_builder.Input(MakeString("Input", i));
}
op_builder.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
// Stacking three scalars yields a length-3 vector.
TEST_F(StackOpTest, TestStackScalar) {
TestStack({}, {{1}, {2}, {3}}, 0, {3}, {1, 2, 3});
}

// Rank-1 inputs with positive, negative, and trailing axis values
// (axis -2 is equivalent to axis 0 for rank-1 inputs).
TEST_F(StackOpTest, TestStackVector) {
TestStack({2}, {{1, 4}, {2, 5}, {3, 6}}, 0, {3, 2}, {1, 4, 2, 5, 3, 6});
TestStack({2}, {{1, 4}, {2, 5}, {3, 6}}, -2, {3, 2}, {1, 4, 2, 5, 3, 6});
TestStack({2}, {{1, 4}, {2, 5}, {3, 6}}, 1, {2, 3}, {1, 2, 3, 4, 5, 6});
}

// Rank-2 inputs stacked along each possible axis (-3 == 0 here).
TEST_F(StackOpTest, TestStackHighRank) {
TestStack({2, 3}, {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}, -3, {2, 2, 3},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
TestStack({2, 3}, {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}, 1, {2, 2, 3},
{1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12});
TestStack({2, 3}, {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}, 2, {2, 3, 2},
{1, 7, 2, 8, 3, 9, 4, 10, 5, 11, 6, 12});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/strided_slice.h"
namespace mace {
namespace ops {
// Registers the CPU/float implementation of the StridedSlice op.
void Register_StridedSlice(OperatorRegistry *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("StridedSlice")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
StridedSliceOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_STRIDED_SLICE_H_
#define MACE_OPS_STRIDED_SLICE_H_
#include "mace/core/operator.h"
#include "mace/kernels/strided_slice.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class StridedSliceOp : public Operator<D, T> {
 public:
  // Wires the five TensorFlow-style mask arguments (each defaulting to 0)
  // into kernels::StridedSliceFunctor.
  StridedSliceOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(OperatorBase::GetOptionalArg<int>("begin_mask", 0),
                 OperatorBase::GetOptionalArg<int>("end_mask", 0),
                 OperatorBase::GetOptionalArg<int>("ellipsis_mask", 0),
                 OperatorBase::GetOptionalArg<int>("new_axis_mask", 0),
                 OperatorBase::GetOptionalArg<int>("shrink_axis_mask", 0)) {}

  MaceStatus Run(StatsFuture *future) override {
    // Inputs: the tensor to slice plus the begin/end/stride index tensors.
    const Tensor *src = this->Input(INPUT);
    const Tensor *begin = this->Input(BEGIN);
    const Tensor *end = this->Input(END);
    const Tensor *step = this->Input(STRIDES);
    return functor_(src, begin, end, step, this->Output(OUTPUT), future);
  }

 private:
  kernels::StridedSliceFunctor<D, T> functor_;

 protected:
  MACE_OP_INPUT_TAGS(INPUT, BEGIN, END, STRIDES);
  MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_STRIDED_SLICE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class StridedSliceOpTest : public OpsTestBase {};
namespace {
void TestSlice(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int32_t> &begin_indices,
const std::vector<int32_t> &end_indices,
const std::vector<int32_t> &strides,
const int begin_mask,
const int end_mask,
const int ellipsis_mask,
const int new_axis_mask,
const int shrink_axis_mask,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>("BeginIndices", {input_shape.size()},
begin_indices);
net.AddInputFromArray<CPU, int32_t>("EndIndices", {input_shape.size()},
end_indices);
net.AddInputFromArray<CPU, int32_t>("Strides", {input_shape.size()}, strides);
OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("Input")
.Input("BeginIndices")
.Input("EndIndices")
.Input("Strides")
.Output("Output")
.AddIntArg("begin_mask", begin_mask)
.AddIntArg("end_mask", end_mask)
.AddIntArg("ellipsis_mask", ellipsis_mask)
.AddIntArg("new_axis_mask", new_axis_mask)
.AddIntArg("shrink_axis_mask", shrink_axis_mask)
.Finalize(net.NewOperatorDef());
net.RunOp();
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output"));
}
} // namespace
// Contiguous slices along axis 0 (the memcpy fast path), including a
// shrink_axis case and masked begin/end bits on the trailing dims.
TEST_F(StridedSliceOpTest, TestSliceByFirstAxis) {
TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 0, 0},
{2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 0, {1, 3, 2},
{7, 8, 9, 10, 11, 12});
TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 0, 0},
{2, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 1, {3, 2}, {7, 8, 9, 10, 11, 12});
TestSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2},
{2, 3, 2}, {1, 1, 1}, 6, 6, 0, 0, 0, {1, 3, 2},
{7, 8, 9, 10, 11, 12});
}

// Rank-1 slices: negative indices, negative strides, begin/end masks,
// and a scalar result via shrink_axis_mask.
TEST_F(StridedSliceOpTest, TestSliceRank1) {
TestSlice({4}, {1, 2, 3, 4}, {1}, {3}, {1}, 0, 0, 0, 0, 0, {2}, {2, 3});
TestSlice({4}, {1, 2, 3, 4}, {-3}, {3}, {1}, 0, 0, 0, 0, 0, {2}, {2, 3});
TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 0, 0, 0, 0, {2}, {3, 2});
TestSlice({4}, {1, 2, 3, 4}, {-1}, {-4}, {-2}, 0, 0, 0, 0, 0, {2}, {4, 2});
TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 0, 0, 0, 0, {3}, {4, 3, 2});
TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 0, 1, 0, 0, 0, {3}, {3, 2, 1});
TestSlice({4}, {1, 2, 3, 4}, {-2}, {-4}, {-1}, 1, 1, 0, 0, 0, {4},
{4, 3, 2, 1});
TestSlice({4}, {1, 2, 3, 4}, {2}, {4}, {2}, 1, 1, 0, 0, 0, {2}, {1, 3});
TestSlice({4}, {1, 2, 3, 4}, {2}, {3}, {1}, 0, 0, 0, 0, 1, {}, {3});
}

// Rank-2 slices: sub-rectangles, strided columns, reversed traversal,
// and shrinking one or both axes.
TEST_F(StridedSliceOpTest, TestSliceRank2) {
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 1}, 0, 0, 0, 0, 0,
{2, 3}, {1, 2, 3, 4, 5, 6});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1}, {2, 3}, {1, 1}, 0, 0, 0, 0, 0,
{1, 2}, {5, 6});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 2}, 0, 0, 0, 0, 0,
{2, 2}, {1, 3, 4, 6});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 0, 0, 0, 0, 0,
{1, 2}, {6, 5});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {0, 0}, {-1, -1}, 3, 3, 0, 0, 0,
{2, 3}, {6, 5, 4, 3, 2, 1});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 0}, {2, 3}, {1, 1}, 0, 0, 0, 0, 1,
{3}, {4, 5, 6});
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2}, {2, 3}, {1, 1}, 0, 0, 0, 0, 3,
{}, {6});
}
} // namespace test
} // namespace ops
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册