concat.h 1.9 KB
Newer Older
L
Liangliang He 已提交
1
// Copyright 2018 Xiaomi, Inc.  All rights reserved.
L
liuqi 已提交
2
//
L
Liangliang He 已提交
3 4 5
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
L
liuqi 已提交
6
//
L
Liangliang He 已提交
7 8 9 10 11 12 13
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
L
liuqi 已提交
14 15 16 17

#ifndef MACE_OPS_CONCAT_H_
#define MACE_OPS_CONCAT_H_

L
liutuo 已提交
18 19
#include <memory>
#include <vector>

L
liuqi 已提交
20 21
#include "mace/core/operator.h"
#include "mace/kernels/concat.h"
L
liutuo 已提交
22

L
liuqi 已提交
23
namespace mace {
L
liutuo 已提交
24
namespace ops {
L
liuqi 已提交
25

L
Liangliang He 已提交
26
template <DeviceType D, typename T>
L
liuqi 已提交
27 28 29
class ConcatOp : public Operator<D, T> {
 public:
  ConcatOp(const OperatorDef &op_def, Workspace *ws)
30
      : Operator<D, T>(op_def, ws),
31
        functor_(OperatorBase::GetSingleArgument<int>("axis", 3)) {}
L
liuqi 已提交
32

33
  bool Run(StatsFuture *future) override {
34 35
    MACE_CHECK(this->InputSize() >= 2)
        << "There must be at least two inputs to concat";
36
    const std::vector<const Tensor *> input_list = this->Inputs();
37
    const int32_t concat_axis = OperatorBase::GetSingleArgument<int>("axis", 3);
38
    const int32_t input_dims = input_list[0]->dim_size();
L
Liangliang He 已提交
39 40 41 42 43
    const int32_t axis =
        concat_axis < 0 ? concat_axis + input_dims : concat_axis;
    MACE_CHECK((0 <= axis && axis < input_dims),
               "Expected concatenating axis in the range [", -input_dims, ", ",
               input_dims, "], but got", concat_axis);
L
liuqi 已提交
44 45 46

    Tensor *output = this->Output(OUTPUT);

47
    functor_(input_list, output, future);
L
liuqi 已提交
48 49
    return true;
  }
L
Liangliang He 已提交
50

L
liuqi 已提交
51 52 53 54 55 56 57
 private:
  kernels::ConcatFunctor<D, T> functor_;

 private:
  OP_OUTPUT_TAGS(OUTPUT);
};

L
liutuo 已提交
58
}  // namespace ops
L
Liangliang He 已提交
59
}  // namespace mace
L
liuqi 已提交
60

L
Liangliang He 已提交
61
#endif  // MACE_OPS_CONCAT_H_