From 4b6fa1c90d9a105841e529b4f746c86a25fad5a5 Mon Sep 17 00:00:00 2001 From: liuqi Date: Sun, 11 Feb 2018 10:48:36 +0800 Subject: [PATCH] Add reshape op (CPU version only) --- mace/core/operator.cc | 2 ++ mace/kernels/reshape.h | 31 +++++++++++++++++++ mace/ops/reshape.cc | 17 +++++++++++ mace/ops/reshape.h | 64 ++++++++++++++++++++++++++++++++++++++++ mace/ops/reshape_test.cc | 56 +++++++++++++++++++++++++++++++++++ 5 files changed, 170 insertions(+) create mode 100644 mace/kernels/reshape.h create mode 100644 mace/ops/reshape.cc create mode 100644 mace/ops/reshape.h create mode 100644 mace/ops/reshape_test.cc diff --git a/mace/core/operator.cc b/mace/core/operator.cc index a0296d87..e0dc1143 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -80,6 +80,7 @@ extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); extern void Register_MatMul(OperatorRegistry *op_registry); extern void Register_WinogradTransform(OperatorRegistry *op_registry); extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry); +extern void Register_Reshape(OperatorRegistry *op_registry); OperatorRegistry::OperatorRegistry() { Register_Activation(this); @@ -103,6 +104,7 @@ OperatorRegistry::OperatorRegistry() { Register_MatMul(this); Register_WinogradTransform(this); Register_WinogradInverseTransform(this); + Register_Reshape(this); } } // namespace mace diff --git a/mace/kernels/reshape.h b/mace/kernels/reshape.h new file mode 100644 index 00000000..a185567f --- /dev/null +++ b/mace/kernels/reshape.h @@ -0,0 +1,31 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// +#ifndef MACE_KERNELS_RESHAPE_H_ +#define MACE_KERNELS_RESHAPE_H_ + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/core/runtime/opencl/cl2_header.h" + +namespace mace { +namespace kernels { + +template +struct ReshapeFunctor { + ReshapeFunctor() {} + + void operator()(const Tensor *input, + const std::vector &out_shape, + Tensor *output, + StatsFuture *future) { + output->Resize(out_shape); + output->CopyBytes(input->raw_data(), input->size() * sizeof(T)); + } +}; + + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_RESHAPE_H_ diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc new file mode 100644 index 00000000..c6ec12d6 --- /dev/null +++ b/mace/ops/reshape.cc @@ -0,0 +1,17 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/reshape.h" + +namespace mace { + +void Register_Reshape(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ReshapeOp); +} + +} // namespace mace diff --git a/mace/ops/reshape.h b/mace/ops/reshape.h new file mode 100644 index 00000000..2dea3b9a --- /dev/null +++ b/mace/ops/reshape.h @@ -0,0 +1,64 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_RESHAPE_H_ +#define MACE_OPS_RESHAPE_H_ + +#include "mace/core/operator.h" +#include "mace/kernels/reshape.h" + +namespace mace { + +template +class ReshapeOp : public Operator { + public: + ReshapeOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + shape_(OperatorBase::GetRepeatedArgument("shape")){} + + bool Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + const index_t num_dims = shape_.size(); + int unknown_idx = -1; + index_t product = 1; + std::vector out_shape; + + for (int i = 0; i < num_dims; ++i) { + if (shape_[i] == -1) { + MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1"; + unknown_idx = i; + out_shape.push_back(1); + } else if (shape_[i] < 0) { + VLOG(ERROR) << "Shape must be non-negative"; + } else { + out_shape.push_back(shape_[i]); + product *= shape_[i]; + } + } + + if (unknown_idx != -1) { + MACE_CHECK(product != 0) << "Cannot infer shape if there is zero shape size."; + const index_t missing = input->size() / product; + MACE_CHECK(missing * product == input->size()) << "Input size not match reshaped tensor size"; + out_shape[unknown_idx] = missing; + } + + Tensor *output = this->Output(OUTPUT); + + functor_(input, out_shape, output, future); + return true; + } + + private: + std::vector shape_; + kernels::ReshapeFunctor functor_; + + private: + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace mace + +#endif // MACE_OPS_RESHAPE_H_ diff --git a/mace/ops/reshape_test.cc b/mace/ops/reshape_test.cc new file mode 100644 index 00000000..ab3c13a0 --- /dev/null +++ b/mace/ops/reshape_test.cc @@ -0,0 +1,56 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "gmock/gmock.h" +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +using namespace mace; + +class ReshapeTest : public OpsTestBase {}; + +void TestReshape(const std::vector &org_shape, + const std::vector &output_shape, + const std::vector &res_shape) { + + // Construct graph + OpsTestNet net; + OpDefBuilder("Reshape", "ReshapeTest") + .Input("Input") + .Output("Output") + .AddIntsArg("shape", output_shape) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", org_shape); + + // Run + net.RunOp(); + + auto input = net.GetTensor("Input"); + auto output = net.GetTensor("Output"); + + EXPECT_THAT(output->shape(), ::testing::ContainerEq(res_shape)); + + const float *input_ptr = input->data(); + const float *output_ptr = output->data(); + const int size = output->size(); + for (int i = 0; i < size; ++i) { + ASSERT_EQ(input_ptr[i], output_ptr[i]); + } +} + +TEST_F(ReshapeTest, Simple) { + TestReshape({1, 2, 3, 4}, {1, 2, -1, 4}, {1, 2, 3, 4}); + TestReshape({1, 2, 3, 4}, {1, 2, -1, 2}, {1, 2, 6, 2}); + TestReshape({1, 2, 3, 4}, {1, -1, 3, 2}, {1, 4, 3, 2}); + TestReshape({1, 2, 3, 4}, {2, 2, 3, 2}, {2, 2, 3, 2}); +} + +TEST_F(ReshapeTest, Complex) { + TestReshape({1, 2, 3, 4}, {-1}, {24}); + TestReshape({1, 2, 3, 4}, {1, -1}, {1, 24}); + TestReshape({1, 2, 3, 4}, {-1, 1}, {24, 1}); + TestReshape({1, 2, 3, 4}, {1, 3, 8}, {1, 3, 8}); +} -- GitLab