From 9a17b40bd67f87e55d034d59b009c0a907da64f2 Mon Sep 17 00:00:00 2001 From: Unknown Date: Mon, 19 Mar 2018 19:00:56 +0800 Subject: [PATCH] Add ops/depth_to_space files --- mace/ops/depth_to_space.cc | 24 ++++++++++++ mace/ops/depth_to_space.h | 76 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 mace/ops/depth_to_space.cc create mode 100644 mace/ops/depth_to_space.h diff --git a/mace/ops/depth_to_space.cc b/mace/ops/depth_to_space.cc new file mode 100644 index 00000000..e34d5871 --- /dev/null +++ b/mace/ops/depth_to_space.cc @@ -0,0 +1,24 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/depth_to_space.h" + +namespace mace { +namespace ops { + +void Register_DepthToSpace(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + DepthToSpaceOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + DepthToSpaceOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/depth_to_space.h b/mace/ops/depth_to_space.h new file mode 100644 index 00000000..e05152a8 --- /dev/null +++ b/mace/ops/depth_to_space.h @@ -0,0 +1,76 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_DEPTH_TO_SPACE_H_ +#define MACE_OPS_DEPTH_TO_SPACE_H_ + +#include +#include + +#include "mace/core/operator.h" +#include "mace/kernels/depth_to_space.h" + +namespace mace { +namespace ops { + +template +class DepthToSpaceOp : public Operator { + public: + DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + functor_(OperatorBase::GetRepeatedArgument("crops", {0, 0, 0, 0}), + OperatorBase::GetRepeatedArgument("block_shape", {1, 1}), + true) {} + + bool Run(StatsFuture *future) override { + const Tensor *batch_tensor = this->Input(INPUT); + Tensor *space_tensor = this->Output(OUTPUT); + + std::vector output_shape(4, 0); + CalculateOutputShape(batch_tensor, space_tensor, output_shape.data()); + functor_(space_tensor, output_shape, const_cast(batch_tensor), + future); + return true; + } + + private: + inline void CalculateOutputShape(const Tensor *input_tensor, + Tensor *output, + index_t *output_shape) { + auto crops = OperatorBase::GetRepeatedArgument("crops", {0, 0, 0, 0}); + auto block_shape = + OperatorBase::GetRepeatedArgument("block_shape", {1, 1}); + MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D"); + MACE_CHECK(block_shape.size() == 2, "Block's shape should be 1D"); + MACE_CHECK(crops.size() == 4, "Crops' shape should be 2D"); + + const index_t block_dims = block_shape.size(); + index_t block_shape_product = 1; + for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) { + MACE_CHECK(block_shape[block_dim] > 1, + "block_shape's value should be great to 1"); + const index_t block_shape_value = block_shape[block_dim]; + const index_t cropped_input_size = + input_tensor->dim(block_dim + 1) * block_shape_value - + crops[block_dim * 2] - crops[block_dim * 2 + 1]; + MACE_CHECK(cropped_input_size >= 0, "cropped size must be non-negative"); + block_shape_product *= block_shape_value; + output_shape[block_dim + 1] = cropped_input_size; + } + output_shape[0] = input_tensor->dim(0) / block_shape_product; + output_shape[3] = input_tensor->dim(3); + } + + private: + kernels::DepthToSpaceOpFunctor functor_; + + protected: + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_DEPTH_TO_SPACE_H_ -- GitLab