diff --git a/mace/core/operator.cc b/mace/core/operator.cc index a0296d8722dbf09a5d55d707308de82b06383321..e9a5f1b698e916d44d1aeb4337f2bb7f226b2b7b 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -80,6 +80,8 @@ extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); extern void Register_MatMul(OperatorRegistry *op_registry); extern void Register_WinogradTransform(OperatorRegistry *op_registry); extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry); +extern void Register_Reshape(OperatorRegistry *op_registry); +extern void Register_Eltwise(OperatorRegistry *op_registry); OperatorRegistry::OperatorRegistry() { Register_Activation(this); @@ -103,6 +105,8 @@ OperatorRegistry::OperatorRegistry() { Register_MatMul(this); Register_WinogradTransform(this); Register_WinogradInverseTransform(this); + Register_Reshape(this); + Register_Eltwise(this); } } // namespace mace diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h new file mode 100644 index 0000000000000000000000000000000000000000..18f0604c15e4d68796ced9b3c5dfbc64bc879c70 --- /dev/null +++ b/mace/kernels/eltwise.h @@ -0,0 +1,105 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// +#ifndef MACE_KERNELS_ELTWISE_H_ +#define MACE_KERNELS_ELTWISE_H_ + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/core/runtime/opencl/cl2_header.h" + +namespace mace { +namespace kernels { + +enum EltwiseType{ + PROD = 0, + SUM = 1, + MAX = 2, + MIN = 3, +}; + +struct EltwiseFunctorBase { + EltwiseFunctorBase(const EltwiseType type, + const std::vector &coeff) + : type_(type), coeff_(coeff) {} + + EltwiseType type_; + std::vector coeff_; +}; + +template +struct EltwiseFunctor : EltwiseFunctorBase { + EltwiseFunctor(const EltwiseType type, + const std::vector &coeff) + : EltwiseFunctorBase(type, coeff) {} + + void operator()(const Tensor *input0, + const Tensor *input1, + Tensor *output, + StatsFuture *future) { + Tensor::MappingGuard input0_guard(input0); + Tensor::MappingGuard input1_guard(input1); + Tensor::MappingGuard output_guard(output); + + const T *input0_ptr = input0->data(); + const T *input1_ptr = input1->data(); + T *output_ptr = output->mutable_data(); + const index_t size = input0->size(); + + switch (type_) { + case PROD: +#pragma omp parallel for + for(index_t i = 0; i < size; ++i) { + output_ptr[i] = input0_ptr[i] * input1_ptr[i]; + } + break; + case SUM: + if (coeff_.empty()) { +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = input0_ptr[i] + input1_ptr[i]; + } + } else { +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = coeff_[0] * input0_ptr[i] + coeff_[1] * input1_ptr[i]; + } + } + break; + case MAX: +#pragma omp parallel for + for(index_t i = 0; i < size; ++i) { + output_ptr[i] = std::max(input0_ptr[i], input1_ptr[i]); + } + break; + case MIN: +#pragma omp parallel for + for(index_t i = 0; i < size; ++i) { + output_ptr[i] = std::min(input0_ptr[i], input1_ptr[i]); + } + break; + default: + LOG(FATAL) << "Eltwise op not support type " << type_; + } + } +}; + + +template +struct EltwiseFunctor: EltwiseFunctorBase { + EltwiseFunctor(const EltwiseType type, + const std::vector &coeff) + : EltwiseFunctorBase(type, coeff) {} + + void operator()(const Tensor *input0, + const Tensor *input1, + Tensor *output, + StatsFuture *future); + + cl::Kernel kernel_; +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_ELTWISE_H_ diff --git a/mace/kernels/opencl/cl/eltwise.cl b/mace/kernels/opencl/cl/eltwise.cl new file mode 100644 index 0000000000000000000000000000000000000000..735bc96e0149b5716230c092f5f3716598c53116 --- /dev/null +++ b/mace/kernels/opencl/cl/eltwise.cl @@ -0,0 +1,34 @@ +#include + +__kernel void eltwise(__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */ + __read_only image2d_t input1, +#ifdef COEFF_SUM + __private const float coeff0, + __private const float coeff1, +#endif + __write_only image2d_t output) { + const int w = get_global_id(0); + const int hb = get_global_id(1); + + DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(w, hb)); + DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(w, hb)); + DATA_TYPE4 out; +#if ELTWISE_TYPE == 0 + out = in0 * in1; +#elif ELTWISE_TYPE == 1 + +#ifdef COEFF_SUM + out = mad(coeff0, in0, mad(coeff1, in1, 0)); +#else + out = in0 + in1; +#endif + +#elif ELTWISE_TYPE == 2 + out = fmax(in0, in1); +#elif ELTWISE_TYPE == 3 + out = fmin(in0, in1); +#endif + + WRITE_IMAGET(output, (int2)(w, hb), out); +} + diff --git a/mace/kernels/opencl/eltwise_opencl.cc b/mace/kernels/opencl/eltwise_opencl.cc new file mode 100644 index 0000000000000000000000000000000000000000..43356df3ce72fed91a4f0146e9532bc62261a4b9 --- /dev/null +++ b/mace/kernels/opencl/eltwise_opencl.cc @@ -0,0 +1,69 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/kernels/eltwise.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/kernels/opencl/helper.h" +#include "mace/utils/tuner.h" + +namespace mace { +namespace kernels { + +template +void EltwiseFunctor::operator()(const Tensor *input0, + const Tensor *input1, + Tensor *output, + StatsFuture *future) { + + const index_t batch = input0->dim(0); + const index_t height = input0->dim(1); + const index_t width = input0->dim(2); + const index_t channels = input0->dim(3); + + const index_t channel_blocks = RoundUpDiv4(channels); + const index_t width_pixels = channel_blocks * width; + const index_t batch_height_pixels = batch * height; + + if (kernel_.get() == nullptr) { + auto runtime = OpenCLRuntime::Global(); + std::set built_options; + auto dt = DataTypeToEnum::value; + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("eltwise"); + built_options.emplace("-Deltwise=" + kernel_name); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); + built_options.emplace("-DELTWISE_TYPE=" + ToString(type_)); + if (!coeff_.empty()) built_options.emplace("-DCOEFF_SUM"); + kernel_ = runtime->BuildKernel("eltwise", kernel_name, built_options); + + uint32_t idx = 0; + kernel_.setArg(idx++, + *(static_cast(input0->buffer()))); + kernel_.setArg(idx++, + *(static_cast(input1->buffer()))); + if (!coeff_.empty()) { + kernel_.setArg(idx++, coeff_[0]); + kernel_.setArg(idx++, coeff_[1]); + } + kernel_.setArg(idx++, *(static_cast(output->buffer()))); + } + + const uint32_t gws[2] = { + static_cast(width_pixels), + static_cast(batch_height_pixels) + }; + const std::vector lws = {64, 16, 1}; + std::stringstream ss; + ss << "eltwise_opencl_kernel_" + << output->dim(0) << "_" + << output->dim(1) << "_" + << output->dim(2) << "_" + << output->dim(3); + TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future); +} + +template struct EltwiseFunctor; +template struct EltwiseFunctor; +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/opencl/winograd_transform.cc b/mace/kernels/opencl/winograd_transform.cc index a842ba719174505e115fd26bc8bd8cd7f2301898..f74eb191d885647a42bfef5ec81e44c01f09efed 100644 --- a/mace/kernels/opencl/winograd_transform.cc +++ b/mace/kernels/opencl/winograd_transform.cc @@ -54,8 +54,8 @@ void WinogradTransformFunctor::operator()(const Tensor *i kernel_.setArg(idx++, static_cast(paddings[1] / 2)); } - const uint32_t gws[2] = {static_cast(out_width), - static_cast(RoundUpDiv4(input_tensor->dim(3)))}; + const uint32_t gws[2] = {static_cast(out_width), + static_cast(RoundUpDiv4(input_tensor->dim(3)))}; const std::vector lws = {128, 8, 1}; std::stringstream ss; ss << "winograd_transform_kernel_" @@ -126,8 +126,8 @@ void WinogradInverseTransformFunctor::operator()(const Te kernel_.setArg(idx++, prelu_alpha_); } - const uint32_t gws[2] = {static_cast(input_tensor->dim(2)), - static_cast(RoundUpDiv4(input_tensor->dim(1)))}; + const uint32_t gws[2] = {static_cast(input_tensor->dim(2)), + static_cast(RoundUpDiv4(input_tensor->dim(1)))}; const std::vector lws = {128, 8, 1}; std::stringstream ss; diff --git a/mace/kernels/reshape.h b/mace/kernels/reshape.h new file mode 100644 index 0000000000000000000000000000000000000000..4d37a19974d683c3f916e2ccc170dfac11ed94c7 --- /dev/null +++ b/mace/kernels/reshape.h @@ -0,0 +1,32 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// +#ifndef MACE_KERNELS_RESHAPE_H_ +#define MACE_KERNELS_RESHAPE_H_ + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/core/runtime/opencl/cl2_header.h" + +namespace mace { +namespace kernels { + +template +struct ReshapeFunctor { + ReshapeFunctor() {} + + void operator()(const Tensor *input, + const std::vector &out_shape, + Tensor *output, + StatsFuture *future) { + output->Resize(out_shape); + // TODO copy on write to avoid this copy. + output->CopyBytes(input->raw_data(), input->size() * sizeof(T)); + } +}; + + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_RESHAPE_H_ diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc new file mode 100644 index 0000000000000000000000000000000000000000..0304ec1aa3d64a50a9ca6a1c1a08267578fe3dd6 --- /dev/null +++ b/mace/ops/eltwise.cc @@ -0,0 +1,29 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/eltwise.h" + +namespace mace { + +void Register_Eltwise(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + EltwiseOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + EltwiseOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + EltwiseOp); +} + +} // namespace mace diff --git a/mace/ops/eltwise.h b/mace/ops/eltwise.h new file mode 100644 index 0000000000000000000000000000000000000000..7d8e63ee0deca2f96e5b9681d3eea76cca4ca171 --- /dev/null +++ b/mace/ops/eltwise.h @@ -0,0 +1,47 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_RESHAPE_H_ +#define MACE_OPS_RESHAPE_H_ + +#include "mace/core/operator.h" +#include "mace/kernels/eltwise.h" + +namespace mace { + +template +class EltwiseOp : public Operator { + public: + EltwiseOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + functor_(static_cast( + OperatorBase::GetSingleArgument( + "type", static_cast(kernels::EltwiseType::SUM))), + OperatorBase::GetRepeatedArgument("coeff")){} + + bool Run(StatsFuture *future) override { + const Tensor *input0 = this->Input(0); + const Tensor *input1 = this->Input(1); + Tensor *output = this->Output(OUTPUT); + MACE_CHECK(input0->dim_size() == input1->dim_size()) << "Inputs of Eltwise op must be same shape"; + for(int i = 0; i < input0->dim_size(); ++i) { + MACE_CHECK(input0->dim(i) == input1->dim(i)) << "Inputs of Eltwise op must be same shape"; + } + + output->ResizeLike(input0); + + functor_(input0, input1, output, future); + return true; + } + + private: + kernels::EltwiseFunctor functor_; + + private: + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace mace + +#endif // MACE_OPS_RESHAPE_H_ diff --git a/mace/ops/eltwise_benchmark.cc b/mace/ops/eltwise_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..80a3f072b0c785bf26342408b6a91f9b98b63831 --- /dev/null +++ b/mace/ops/eltwise_benchmark.cc @@ -0,0 +1,79 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" +#include "mace/kernels/eltwise.h" + +namespace mace { +template +static void EltwiseBenchmark(int iters, kernels::EltwiseType type, int n, int h, int w, int c) { + mace::testing::StopTiming(); + + OpsTestNet net; + // Add input data + net.AddRandomInput("Input0", {n, h, w, c}); + net.AddRandomInput("Input1", {n, h, w, c}); + + if (D == DeviceType::OPENCL) { + BufferToImage(net, "Input0", "InputImg0", kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(net, "Input1", "InputImg1", kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder("Eltwise", "EltwiseTest") + .Input("InputImg0") + .Input("InputImg1") + .AddIntArg("type", static_cast(type)) + .AddFloatsArg("coeff", {1.2, 2.1}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Output("OutputImg") + .Finalize(net.NewOperatorDef()); + } else { + OpDefBuilder("Eltwise", "EltwiseTest") + .Input("Input0") + .Input("Input1") + .AddIntArg("type", static_cast(type)) + .AddFloatsArg("coeff", {1.2, 2.1}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Output("Output") + .Finalize(net.NewOperatorDef()); + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + net.Sync(); + } + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + net.Sync(); + } +} + +#define BM_ELTWISE_MACRO(ELT_TYPE, N, H, W, C, TYPE, DEVICE) \ + static void BM_ELTWISE_##ELT_TYPE##_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * H * W * C; \ + mace::testing::ItemsProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + EltwiseBenchmark(iters, static_cast(ELT_TYPE), N, H, W, C); \ + } \ + BENCHMARK(BM_ELTWISE_##ELT_TYPE##_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE) + +#define BM_ELTWISE(ELT_TYPE, N, H, W, C, ) \ + BM_ELTWISE_MACRO(ELT_TYPE, N, H, W, C, float, CPU); \ + BM_ELTWISE_MACRO(ELT_TYPE, N, H, W, C, float, OPENCL); \ + BM_ELTWISE_MACRO(ELT_TYPE, N, H, W, C, half, OPENCL); + +BM_ELTWISE(0, 1, 256, 256, 32); +BM_ELTWISE(0, 1, 128, 128, 32); +BM_ELTWISE(1, 1, 128, 128, 32); +BM_ELTWISE(2, 1, 128, 128, 32); +BM_ELTWISE(0, 1, 240, 240, 256); +BM_ELTWISE(1, 1, 240, 240, 256); +BM_ELTWISE(2, 1, 240, 240, 256); + +} // namespace mace diff --git a/mace/ops/eltwise_test.cc b/mace/ops/eltwise_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e3d336258511ff53ed51c05ca63624be1f52253 --- /dev/null +++ b/mace/ops/eltwise_test.cc @@ -0,0 +1,187 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" +#include "mace/kernels/eltwise.h" + +namespace mace { + +class EltwiseOpTest : public OpsTestBase {}; + +template +void Simple(const kernels::EltwiseType type, + const std::vector &shape, + const std::vector &input0, + const std::vector &input1, + const std::vector &output, + const std::vector coeff = {}) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input1", shape, input0); + net.AddInputFromArray("Input2", shape, input1); + + if (D == DeviceType::CPU) { + OpDefBuilder("Eltwise", "EltwiseTest") + .Input("Input1") + .Input("Input2") + .AddIntArg("type", static_cast(type)) + .AddFloatsArg("coeff", coeff) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + } else { + BufferToImage(net, "Input1", "InputImg1", kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(net, "Input2", "InputImg2", kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder("Eltwise", "EltwiseTest") + .Input("InputImg1") + .Input("InputImg2") + .AddIntArg("type", static_cast(type)) + .AddFloatsArg("coeff", coeff) + .Output("OutputImg") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + + ImageToBuffer(net, "OutputImg", "Output", kernels::BufferType::IN_OUT_CHANNEL); + } + + auto expected = CreateTensor(shape, output); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-3); +} + +TEST_F(EltwiseOpTest, CPUSimple) { + Simple(kernels::EltwiseType::PROD, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {1, 2, 3, 4, 5, 6}, + {1, 4, 9, 16, 25, 36}); + Simple(kernels::EltwiseType::SUM, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {1, 2, 3, 4, 5, 6}, + {2, 4, 6, 8, 10, 12}); + Simple(kernels::EltwiseType::SUM, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {1, 2, 3, 4, 5, 6}, + {3, 6, 9, 12, 15, 18}, + {2, 1}); + Simple(kernels::EltwiseType::MAX, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {1, 1, 3, 3, 6, 6}, + {1, 2, 3, 4, 6, 6}); + Simple(kernels::EltwiseType::MIN, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {1, 1, 3, 3, 6, 6}, + {1, 1, 3, 3, 5, 6}); +} + +TEST_F(EltwiseOpTest, GPUSimple) { + Simple(kernels::EltwiseType::PROD, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {1, 2, 3, 4, 5, 6}, + {1, 4, 9, 16, 25, 36}); + Simple(kernels::EltwiseType::SUM, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {1, 2, 3, 4, 5, 6}, + {2, 4, 6, 8, 10, 12}); + Simple(kernels::EltwiseType::SUM, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {1, 2, 3, 4, 5, 6}, + {3, 6, 9, 12, 15, 18}, + {2, 1}); + Simple(kernels::EltwiseType::MAX, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {1, 1, 3, 3, 6, 6}, + {1, 2, 3, 4, 6, 6}); + Simple(kernels::EltwiseType::MIN, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, + {1, 1, 3, 3, 6, 6}, + {1, 1, 3, 3, 5, 6}); +} + +template +void RandomTest(const kernels::EltwiseType type, + const std::vector &shape) { + testing::internal::LogToStderr(); + srand(time(NULL)); + + // Construct graph + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input1", shape); + net.AddRandomInput("Input2", shape); + + OpDefBuilder("Eltwise", "EltwiseTest") + .Input("Input1") + .Input("Input2") + .AddIntArg("type", static_cast(type)) + .AddFloatsArg("coeff", {1.2, 2.1}) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + BufferToImage(net, "Input1", "InputImg1", kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(net, "Input2", "InputImg2", kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder("Eltwise", "EltwiseTest") + .Input("InputImg1") + .Input("InputImg2") + .AddIntArg("type", static_cast(type)) + .AddFloatsArg("coeff", {1.2, 2.1}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Output("OutputImg") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + + ImageToBuffer(net, "OutputImg", "OPENCLOutput", kernels::BufferType::IN_OUT_CHANNEL); + + if (DataTypeToEnum::value == DT_FLOAT) { + ExpectTensorNear(*net.GetTensor("Output"), *net.GetOutput("OPENCLOutput"), 1e-3); + } else { + ExpectTensorNear(*net.GetTensor("Output"), *net.GetOutput("OPENCLOutput"), 1e-1); + } +} + +TEST_F(EltwiseOpTest, OPENCLRandomFloat) { + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}); + RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}); + RandomTest(kernels::EltwiseType::MIN, + {13, 32, 32, 64}); +} + +TEST_F(EltwiseOpTest, OPENCLRandomHalf) { + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}); + RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}); + RandomTest(kernels::EltwiseType::MIN, + {13, 32, 32, 64}); +} + +} // namespace mace diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc new file mode 100644 index 0000000000000000000000000000000000000000..d72052713c13d76bee21d9a44291e2197b71ffd5 --- /dev/null +++ b/mace/ops/reshape.cc @@ -0,0 +1,17 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/reshape.h" + +namespace mace { + +void Register_Reshape(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ReshapeOp); +} + +} // namespace mace diff --git a/mace/ops/reshape.h b/mace/ops/reshape.h new file mode 100644 index 0000000000000000000000000000000000000000..2dea3b9ae85193223ebc5afc2f97c8219427fef5 --- /dev/null +++ b/mace/ops/reshape.h @@ -0,0 +1,64 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_RESHAPE_H_ +#define MACE_OPS_RESHAPE_H_ + +#include "mace/core/operator.h" +#include "mace/kernels/reshape.h" + +namespace mace { + +template +class ReshapeOp : public Operator { + public: + ReshapeOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + shape_(OperatorBase::GetRepeatedArgument("shape")){} + + bool Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + const index_t num_dims = shape_.size(); + int unknown_idx = -1; + index_t product = 1; + std::vector out_shape; + + for (int i = 0; i < num_dims; ++i) { + if (shape_[i] == -1) { + MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1"; + unknown_idx = i; + out_shape.push_back(1); + } else if (shape_[i] < 0) { + VLOG(ERROR) << "Shape must be non-negative"; + } else { + out_shape.push_back(shape_[i]); + product *= shape_[i]; + } + } + + if (unknown_idx != -1) { + MACE_CHECK(product != 0) << "Cannot infer shape if there is zero shape size."; + const index_t missing = input->size() / product; + MACE_CHECK(missing * product == input->size()) << "Input size not match reshaped tensor size"; + out_shape[unknown_idx] = missing; + } + + Tensor *output = this->Output(OUTPUT); + + functor_(input, out_shape, output, future); + return true; + } + + private: + std::vector shape_; + kernels::ReshapeFunctor functor_; + + private: + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace mace + +#endif // MACE_OPS_RESHAPE_H_ diff --git a/mace/ops/reshape_test.cc b/mace/ops/reshape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ab3c13a0c36fb789fde031efdc1f35d7100b76a1 --- /dev/null +++ b/mace/ops/reshape_test.cc @@ -0,0 +1,56 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "gmock/gmock.h" +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +using namespace mace; + +class ReshapeTest : public OpsTestBase {}; + +void TestReshape(const std::vector &org_shape, + const std::vector &output_shape, + const std::vector &res_shape) { + + // Construct graph + OpsTestNet net; + OpDefBuilder("Reshape", "ReshapeTest") + .Input("Input") + .Output("Output") + .AddIntsArg("shape", output_shape) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", org_shape); + + // Run + net.RunOp(); + + auto input = net.GetTensor("Input"); + auto output = net.GetTensor("Output"); + + EXPECT_THAT(output->shape(), ::testing::ContainerEq(res_shape)); + + const float *input_ptr = input->data(); + const float *output_ptr = output->data(); + const int size = output->size(); + for (int i = 0; i < size; ++i) { + ASSERT_EQ(input_ptr[i], output_ptr[i]); + } +} + +TEST_F(ReshapeTest, Simple) { + TestReshape({1, 2, 3, 4}, {1, 2, -1, 4}, {1, 2, 3, 4}); + TestReshape({1, 2, 3, 4}, {1, 2, -1, 2}, {1, 2, 6, 2}); + TestReshape({1, 2, 3, 4}, {1, -1, 3, 2}, {1, 4, 3, 2}); + TestReshape({1, 2, 3, 4}, {2, 2, 3, 2}, {2, 2, 3, 2}); +} + +TEST_F(ReshapeTest, Complex) { + TestReshape({1, 2, 3, 4}, {-1}, {24}); + TestReshape({1, 2, 3, 4}, {1, -1}, {1, 24}); + TestReshape({1, 2, 3, 4}, {-1, 1}, {24, 1}); + TestReshape({1, 2, 3, 4}, {1, 3, 8}, {1, 3, 8}); +}