// Copyright 2018 Xiaomi, Inc.  All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MACE_OPS_SPACE_TO_BATCH_H_
#define MACE_OPS_SPACE_TO_BATCH_H_

#include <memory>
#include <vector>

#include "mace/core/operator.h"
#include "mace/kernels/space_to_batch.h"

namespace mace {
L
liutuo 已提交
25
namespace ops {
L
liuqi 已提交
26

27
template <DeviceType D, typename T>
L
liuqi 已提交
28 29 30
class SpaceToBatchNDOp : public Operator<D, T> {
 public:
  SpaceToBatchNDOp(const OperatorDef &op_def, Workspace *ws)
L
liuqi 已提交
31 32 33 34 35
      : Operator<D, T>(op_def, ws),
        functor_(
            OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0}),
            OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}),
            false) {}
L
liuqi 已提交
36

37
  bool Run(StatsFuture *future) override {
38 39
    const Tensor *space_tensor = this->Input(INPUT);
    Tensor *batch_tensor = this->Output(OUTPUT);
L
liuqi 已提交
40

L
liuqi 已提交
41
    std::vector<index_t> output_shape(4, 0);
42
    CalculateOutputShape(space_tensor, batch_tensor, output_shape.data());
43 44
    functor_(const_cast<Tensor *>(space_tensor), output_shape, batch_tensor,
             future);
L
liuqi 已提交
45 46 47
    return true;
  }

L
liuqi 已提交
48
 private:
49 50 51
  inline void CalculateOutputShape(const Tensor *input_tensor,
                                   Tensor *output,
                                   index_t *output_shape) {
52 53 54 55
    auto paddings =
        OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0});
    auto block_shape =
        OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
L
liuqi 已提交
56 57 58 59 60 61 62
    MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
    MACE_CHECK(block_shape.size() == 2, "Block's shape should be 1D");
    MACE_CHECK(paddings.size() == 4, "Paddings' shape should be 2D");

    const index_t block_dims = block_shape.size();
    index_t block_shape_product = 1;
    for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
63 64
      MACE_CHECK(block_shape[block_dim] > 1,
                 "block_shape's value should be great to 1");
65
      const index_t block_shape_value = block_shape[block_dim];
66 67 68 69 70
      const index_t padded_input_size = input_tensor->dim(block_dim + 1) +
                                        paddings[block_dim * 2] +
                                        paddings[block_dim * 2 + 1];
      MACE_CHECK(padded_input_size % block_shape_value == 0, "padded input ",
                 padded_input_size, " is not divisible by block_shape");
71 72
      block_shape_product *= block_shape_value;
      output_shape[block_dim + 1] = padded_input_size / block_shape_value;
L
liuqi 已提交
73 74 75 76 77
    }
    output_shape[0] = input_tensor->dim(0) * block_shape_product;
    output_shape[3] = input_tensor->dim(3);
  }

L
liuqi 已提交
78 79 80 81
 private:
  kernels::SpaceToBatchFunctor<D, T> functor_;

 protected:
L
liuqi 已提交
82
  OP_INPUT_TAGS(INPUT);
L
liuqi 已提交
83 84 85
  OP_OUTPUT_TAGS(OUTPUT);
};

L
liutuo 已提交
86
}  // namespace ops
L
liuqi 已提交
87 88 89
}  // namespace mace

#endif  // MACE_OPS_SPACE_TO_BATCH_H_