concat.h 1.3 KB
Newer Older
L
liuqi 已提交
1 2 3 4 5 6 7 8 9 10 11
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_OPS_CONCAT_H_
#define MACE_OPS_CONCAT_H_

#include "mace/core/operator.h"
#include "mace/kernels/concat.h"
namespace mace {

L
Liangliang He 已提交
12
template <DeviceType D, typename T>
L
liuqi 已提交
13 14 15
class ConcatOp : public Operator<D, T> {
 public:
  ConcatOp(const OperatorDef &op_def, Workspace *ws)
16
      : Operator<D, T>(op_def, ws),
17
        functor_(OperatorBase::GetSingleArgument<int>("axis", 3)) {}
L
liuqi 已提交
18

19
  bool Run(StatsFuture *future) override {
20 21
    MACE_CHECK(this->InputSize() >= 2)
        << "There must be at least two inputs to concat";
22
    const std::vector<const Tensor *> input_list = this->Inputs();
23
    const int32_t concat_axis = OperatorBase::GetSingleArgument<int>("axis", 3);
24
    const int32_t input_dims = input_list[0]->dim_size();
L
Liangliang He 已提交
25 26 27 28 29
    const int32_t axis =
        concat_axis < 0 ? concat_axis + input_dims : concat_axis;
    MACE_CHECK((0 <= axis && axis < input_dims),
               "Expected concatenating axis in the range [", -input_dims, ", ",
               input_dims, "], but got", concat_axis);
L
liuqi 已提交
30 31 32

    Tensor *output = this->Output(OUTPUT);

33
    functor_(input_list, output, future);
L
liuqi 已提交
34 35
    return true;
  }
L
Liangliang He 已提交
36

L
liuqi 已提交
37 38 39 40 41 42 43
 private:
  kernels::ConcatFunctor<D, T> functor_;

 private:
  OP_OUTPUT_TAGS(OUTPUT);
};

L
Liangliang He 已提交
44
}  // namespace mace
L
liuqi 已提交
45

L
Liangliang He 已提交
46
#endif  // MACE_OPS_CONCAT_H_