/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h"

namespace paddle {
namespace operators {
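
// Infers the output shape of concat: the sizes along `axis` are summed and
// every other dimension must match across all inputs. At compile time
// (is_runtime == false) an unknown (-1) size along `axis` makes the output
// size along that axis unknown as well.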
static inline framework::DDim ComputeAndCheckShape(
    const bool is_runtime, const std::vector<framework::DDim>& inputs_dims,
    const int axis) {
  const size_t n = inputs_dims.size();
  auto out_dims = inputs_dims[0];
  size_t in_zero_dims_size = out_dims.size();
  for (size_t i = 1; i < n; i++) {
    for (size_t j = 0; j < in_zero_dims_size; j++) {
      if (j == static_cast<size_t>(axis)) {
        if (is_runtime) {
          out_dims[axis] += inputs_dims[i][j];
        } else {
          if (inputs_dims[i][j] == -1) {
            out_dims[axis] = -1;
          } else {
            out_dims[axis] += inputs_dims[i][j];
          }
        }
      } else {
        bool check_shape =
            is_runtime || (out_dims[j] > 0 && inputs_dims[i][j] > 0);
        if (check_shape) {
          // always check at run time; at compile time only check dimensions
          // that are already known (> 0)
          PADDLE_ENFORCE_EQ(
              inputs_dims[0][j], inputs_dims[i][j],
              "ShapeError: Dimension %d in inputs' shapes must be equal. "
              "But recevied input[0]'s shape = "
              "[%s], input[%d]'s shape = [%s].",
              j, inputs_dims[0], i, inputs_dims[i]);
        }
      }
    }
  }
  return out_dims;
}

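// Normalizes `axis`: a negative value counts from the end of the dimensions
// (axis + rank); the result is clamped to be non-negative.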
static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
  if (axis < 0) {
    axis = axis + rank;
  }
  return axis > 0 ? axis : 0;
}

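// Concatenates all input tensors X along `axis` into Out. The axis may also
// be provided at run time through the optional AxisTensor input.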
template <typename DeviceContext, typename T>
class ConcatKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
    PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null.");
    auto axis = ctx.Attr<int>("axis");
    bool need_resize_out_dims = false;
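    // When the axis comes in as a tensor, its value is only known at run
    // time, so the output shape has to be re-inferred below.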
    if (ctx.HasInput("AxisTensor")) {
      auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
      axis = GetDataFromTensor<int>(axis_tensor)[0];
      need_resize_out_dims = true;
    }
    axis = ComputeAxis(static_cast<int64_t>(axis),
                       static_cast<int64_t>(ins[0]->dims().size()));

    if (need_resize_out_dims) {
      const size_t n = ins.size();
      std::vector<framework::DDim> ins_dims(n);
      for (size_t i = 0; i < n; i++) {
        ins_dims[i] = ins[i]->dims();
      }

      framework::DDim out_dims = ComputeAndCheckShape(true, ins_dims, axis);
      out->Resize(out_dims);
    }
    auto place = ctx.GetPlace();
    out->mutable_data<T>(place);

    // Sometimes direct copies will be faster; this heuristic may need deeper
    // analysis.
    if (axis == 0 && ins.size() < 10) {
      size_t output_offset = 0;
      for (auto* in : ins) {
        if (!in || in->numel() == 0UL) {
          continue;
        }
        auto in_stride = framework::stride_numel(in->dims());
        auto out_stride = framework::stride_numel(out->dims());
        StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
                                    out->data<T>() + output_offset, out_stride,
                                    in->data<T>(), in_stride, in_stride[axis]);
        output_offset += in_stride[axis];
      }
    } else {
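      // General path: gather the non-empty inputs and let ConcatFunctor do
      // the concatenation.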
      std::vector<framework::Tensor> inputs;
      for (size_t j = 0; j < ins.size(); ++j) {
        if (ins[j] && ins[j]->numel() > 0) {
          inputs.push_back(*ins[j]);
        }
      }
      auto& dev_ctx = ctx.template device_context<DeviceContext>();
      paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
      concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
    }
  }
};

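// Splits the gradient of Out along `axis` back into the gradients of the
// inputs X.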
template <typename DeviceContext, typename T>
class ConcatGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
    auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
    auto outs =
        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));

    {
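      // Propagate each forward input's LoD to the corresponding gradient
      // output.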
      auto dx = outs;
      auto x = ins;
      for (size_t i = 0; i < dx.size(); ++i) {
        if (dx[i] != nullptr) {
          dx[i]->set_lod(x[i]->lod());
        }
      }
    }
    PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null.");

    auto axis = ctx.Attr<int>("axis");
    if (ctx.HasInput("AxisTensor")) {
      auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
      axis = GetDataFromTensor<int>(axis_tensor)[0];
    }
    axis = ComputeAxis(static_cast<int64_t>(axis),
                       static_cast<int64_t>(ins[0]->dims().size()));
    // Collect the gradient outputs that actually need to be written; slots
    // whose variable name is kEmptyVarName or whose gradient is empty are
    // filled with nullptr so positions stay aligned with the inputs.
    std::vector<framework::Tensor*> outputs;
    for (size_t j = 0; j < outs.size(); ++j) {
      if (out_var_names[j] != framework::kEmptyVarName &&
          outs[j]->numel() != 0UL) {
        outs[j]->mutable_data<T>(ctx.GetPlace());
        outputs.push_back(outs[j]);
      } else {
        outputs.push_back(nullptr);
      }
    }
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    // Sometimes direct copies will be faster; this heuristic may need deeper
    // analysis.
    if (axis == 0 && outs.size() < 10) {
      std::vector<const framework::Tensor*> ref_shape;
      ref_shape.insert(ref_shape.begin(), ins.begin(), ins.end());
      StridedMemcpyWithAxis0<T>(dev_ctx, *out_grad, ref_shape, &outputs);
    } else {
      math::SplitFunctor<DeviceContext, T> split_functor;
      split_functor(dev_ctx, *out_grad, ctx.MultiInput<framework::Tensor>("X"),
                    static_cast<int>(axis), &outputs);
    }
  }
};

}  // namespace operators
}  // namespace paddle