From d1c3fef48f938319ddf7cd1e2a3cb4d755fbc65d Mon Sep 17 00:00:00 2001 From: liuqi Date: Mon, 25 Dec 2017 15:33:55 +0800 Subject: [PATCH] Rename space_to_batch helper function name for readability. --- mace/ops/batch_to_space.h | 10 +++++----- mace/ops/space_to_batch.h | 31 +++++++++++++++---------------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/mace/ops/batch_to_space.h b/mace/ops/batch_to_space.h index 911fc4b6..59f8e03b 100644 --- a/mace/ops/batch_to_space.h +++ b/mace/ops/batch_to_space.h @@ -24,18 +24,18 @@ class BatchToSpaceNDOp : public Operator { bool Run(StatsFuture *future) override { const Tensor *batch_tensor = this->Input(INPUT); - Tensor *space_tensor= this->Output(OUTPUT); + Tensor *space_tensor = this->Output(OUTPUT); std::vector output_shape(4, 0); - BatchToSpaceHelper(batch_tensor, space_tensor, output_shape); + CalculateOutputShape(batch_tensor, space_tensor, output_shape.data()); functor_(space_tensor, output_shape, const_cast(batch_tensor), future); return true; } private: - inline void BatchToSpaceHelper(const Tensor *input_tensor, - Tensor *output, - std::vector &output_shape) { + inline void CalculateOutputShape(const Tensor *input_tensor, + Tensor *output, + index_t *output_shape) { auto crops = OperatorBase::GetRepeatedArgument("crops", {0, 0, 0, 0}); auto block_shape = OperatorBase::GetRepeatedArgument("block_shape", {1, 1}); MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D"); diff --git a/mace/ops/space_to_batch.h b/mace/ops/space_to_batch.h index 58b90bbd..787b82e6 100644 --- a/mace/ops/space_to_batch.h +++ b/mace/ops/space_to_batch.h @@ -12,7 +12,6 @@ namespace mace { - template class SpaceToBatchNDOp : public Operator { public: @@ -24,20 +23,20 @@ class SpaceToBatchNDOp : public Operator { false) {} bool Run(StatsFuture *future) override { - const Tensor *space_tensor= this->Input(INPUT); - Tensor *batch_tensor= this->Output(OUTPUT); + const Tensor *space_tensor = this->Input(INPUT); + Tensor *batch_tensor = this->Output(OUTPUT); std::vector output_shape(4, 0); - SpaceToBatchHelper(space_tensor, batch_tensor, output_shape); + CalculateOutputShape(space_tensor, batch_tensor, output_shape.data()); functor_(const_cast(space_tensor), output_shape, batch_tensor, future); return true; } private: - inline void SpaceToBatchHelper(const Tensor *input_tensor, - Tensor *output, - std::vector &output_shape) { + inline void CalculateOutputShape(const Tensor *input_tensor, + Tensor *output, + index_t *output_shape) { auto paddings = OperatorBase::GetRepeatedArgument("paddings", {0, 0, 0, 0}); auto block_shape = OperatorBase::GetRepeatedArgument("block_shape", {1, 1}); MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D"); @@ -47,15 +46,15 @@ class SpaceToBatchNDOp : public Operator { const index_t block_dims = block_shape.size(); index_t block_shape_product = 1; for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) { - MACE_CHECK(block_shape[block_dim] > 1, "block_shape's value should be great to 1"); - const index_t block_shape_value = block_shape[block_dim]; - const index_t padded_input_size = input_tensor->dim(block_dim + 1) - + paddings[block_dim * 2] - + paddings[block_dim * 2 + 1]; - MACE_CHECK(padded_input_size % block_shape_value == 0, - "padded input ", padded_input_size, " is not divisible by block_shape"); - block_shape_product *= block_shape_value; - output_shape[block_dim + 1] = padded_input_size / block_shape_value; + MACE_CHECK(block_shape[block_dim] > 1, "block_shape's value should be great to 1"); + const index_t block_shape_value = block_shape[block_dim]; + const index_t padded_input_size = input_tensor->dim(block_dim + 1) + + paddings[block_dim * 2] + + paddings[block_dim * 2 + 1]; + MACE_CHECK(padded_input_size % block_shape_value == 0, + "padded input ", padded_input_size, " is not divisible by block_shape"); + block_shape_product *= block_shape_value; + output_shape[block_dim + 1] = padded_input_size / block_shape_value; } output_shape[0] = input_tensor->dim(0) * block_shape_product; output_shape[3] = input_tensor->dim(3); -- GitLab