concat.h 2.4 KB
Newer Older
L
liuqi 已提交
1 2 3 4 5 6 7 8 9
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_OPS_CONCAT_H_
#define MACE_OPS_CONCAT_H_

#include "mace/core/operator.h"
#include "mace/kernels/concat.h"
L
Liangliang He 已提交
10
#include "mace/proto/mace.pb.h"
L
liuqi 已提交
11 12
namespace mace {

L
Liangliang He 已提交
13
template <DeviceType D, typename T>
L
liuqi 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27
class ConcatOp : public Operator<D, T> {
 public:
  ConcatOp(const OperatorDef &op_def, Workspace *ws)
      : Operator<D, T>(op_def, ws) {}

  bool Run() override {
    int32_t values_count = this->InputSize() - 1;
    const Tensor *input0 = this->Input(0);
    const Tensor *axis_tensor = this->Input(values_count);
    MACE_CHECK(axis_tensor->dim_size() == 0,
               "axis should be a scalar integer, but got shape: ",
               axis_tensor->dim_size());
    const int32_t concat_axis = *(axis_tensor->data<int32_t>());
    const int32_t input_dims = input0->dim_size();
L
Liangliang He 已提交
28 29 30 31 32
    const int32_t axis =
        concat_axis < 0 ? concat_axis + input_dims : concat_axis;
    MACE_CHECK((0 <= axis && axis < input_dims),
               "Expected concatenating axis in the range [", -input_dims, ", ",
               input_dims, "], but got", concat_axis);
L
liuqi 已提交
33 34 35 36 37 38 39 40 41 42 43 44
    std::vector<index_t> output_shape(input0->shape());
    index_t inner_size = 1;
    for (int i = 0; i < axis; ++i) {
      inner_size *= output_shape[i];
    }
    std::vector<index_t> outer_sizes(values_count, 0);
    std::vector<const T *> input_list(values_count, nullptr);
    input_list[0] = input0->data<T>();
    outer_sizes[0] = input0->size() / inner_size;
    const Tensor *input = nullptr;
    for (int i = 1; i < values_count; ++i) {
      input = this->Input(i);
L
Liangliang He 已提交
45 46
      MACE_CHECK(input->dim_size() == input0->dim_size(),
                 "Ranks of all input tensors must be same.");
L
liuqi 已提交
47
      for (int j = 0; j < axis_tensor->dim_size(); ++j) {
L
Liangliang He 已提交
48 49 50 51 52
        if (j == axis) {
          continue;
        }
        MACE_CHECK(input->dim(j) == input0->dim(j),
                   "Dimensions of inputs should equal except axis.");
L
liuqi 已提交
53 54 55 56 57 58 59 60 61
      }
      input_list[i] = input->data<T>();
      outer_sizes[i] = input->size() / inner_size;
      output_shape[axis] += input->dim(axis);
    }

    Tensor *output = this->Output(OUTPUT);
    output->Resize(output_shape);

L
Liangliang He 已提交
62 63
    functor_(input_list, inner_size, outer_sizes.data(),
             output->mutable_data<T>());
L
liuqi 已提交
64 65
    return true;
  }
L
Liangliang He 已提交
66

L
liuqi 已提交
67 68 69 70 71 72 73
 private:
  kernels::ConcatFunctor<D, T> functor_;

 private:
  OP_OUTPUT_TAGS(OUTPUT);
};

L
Liangliang He 已提交
74
}  //  namespace mace
L
liuqi 已提交
75

L
Liangliang He 已提交
76
#endif  //  MACE_OPS_CONCAT_H_