//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_OPS_SPACE_TO_BATCH_H_
#define MACE_OPS_SPACE_TO_BATCH_H_

#include <memory>

#include "mace/core/operator.h"
#include "mace/kernels/space_to_batch.h"

namespace mace {

template<DeviceType D, typename T>
L
liuqi 已提交
16 17 18
class SpaceToBatchNDOp : public Operator<D, T> {
 public:
  SpaceToBatchNDOp(const OperatorDef &op_def, Workspace *ws)
L
liuqi 已提交
19 20 21 22 23
      : Operator<D, T>(op_def, ws),
        functor_(
            OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0}),
            OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}),
            false) {}
L
liuqi 已提交
24

25
  bool Run(StatsFuture *future) override {
26 27
    const Tensor *space_tensor = this->Input(INPUT);
    Tensor *batch_tensor = this->Output(OUTPUT);
L
liuqi 已提交
28

L
liuqi 已提交
29
    std::vector<index_t> output_shape(4, 0);
30
    CalculateOutputShape(space_tensor, batch_tensor, output_shape.data());
L
liuqi 已提交
31
    functor_(const_cast<Tensor *>(space_tensor), output_shape, batch_tensor, future);
L
liuqi 已提交
32 33 34
    return true;
  }

L
liuqi 已提交
35 36
 private:

37 38 39
  inline void CalculateOutputShape(const Tensor *input_tensor,
                                   Tensor *output,
                                   index_t *output_shape) {
L
liuqi 已提交
40 41 42 43 44 45 46 47 48
    auto paddings = OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0});
    auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
    MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
    MACE_CHECK(block_shape.size() == 2, "Block's shape should be 1D");
    MACE_CHECK(paddings.size() == 4, "Paddings' shape should be 2D");

    const index_t block_dims = block_shape.size();
    index_t block_shape_product = 1;
    for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
49 50 51 52 53 54 55 56 57
      MACE_CHECK(block_shape[block_dim] > 1, "block_shape's value should be great to 1");
      const index_t block_shape_value = block_shape[block_dim];
      const index_t padded_input_size = input_tensor->dim(block_dim + 1)
          + paddings[block_dim * 2]
          + paddings[block_dim * 2 + 1];
      MACE_CHECK(padded_input_size % block_shape_value == 0,
                 "padded input ", padded_input_size, " is not divisible by block_shape");
      block_shape_product *= block_shape_value;
      output_shape[block_dim + 1] = padded_input_size / block_shape_value;
L
liuqi 已提交
58 59 60 61 62
    }
    output_shape[0] = input_tensor->dim(0) * block_shape_product;
    output_shape[3] = input_tensor->dim(3);
  }

L
liuqi 已提交
63 64 65 66
 private:
  kernels::SpaceToBatchFunctor<D, T> functor_;

 protected:
L
liuqi 已提交
67
  OP_INPUT_TAGS(INPUT);
L
liuqi 已提交
68 69 70 71 72 73
  OP_OUTPUT_TAGS(OUTPUT);
};

}  // namespace mace

#endif  // MACE_OPS_SPACE_TO_BATCH_H_