concat_op.h 2.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>
Y
Yi Wang 已提交
18 19
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h"
20 21 22 23

namespace paddle {
namespace operators {

Q
QI JUN 已提交
24
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
25
class ConcatKernel : public framework::OpKernel<T> {
26 27 28 29 30
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    auto* out = ctx.Output<framework::Tensor>("Out");
    int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
Y
Yancey1989 已提交
31 32 33 34 35
    auto place = ctx.GetPlace();
    out->mutable_data<T>(place);

    auto out_stride = framework::stride_numel(out->dims());

36
    size_t output_offset = 0;
Y
Yancey1989 已提交
37 38
    for (auto* in : ins) {
      auto in_stride = framework::stride_numel(in->dims());
Y
fix ci  
Yancey1989 已提交
39 40
      StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
                                  out->data<T>() + output_offset, out_stride,
T
typhoonzero 已提交
41
                                  in->data<T>(), in_stride, in_stride[axis]);
Y
Yancey1989 已提交
42
      output_offset += in_stride[axis];
43 44 45 46
    }
  }
};

Q
QI JUN 已提交
47
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
48
class ConcatGradKernel : public framework::OpKernel<T> {
49 50 51 52 53 54
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
    int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
    size_t input_offset = 0;
Y
Yancey1989 已提交
55 56 57
    auto in_stride = framework::stride_numel(in->dims());

    for (auto& out : outs) {
58
      out->mutable_data<T>(ctx.GetPlace());
Y
Yancey1989 已提交
59
      auto out_stride = framework::stride_numel(out->dims());
Y
fix ci  
Yancey1989 已提交
60 61
      StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                  out_stride, in->data<T>() + input_offset,
T
typhoonzero 已提交
62
                                  in_stride, out_stride[axis]);
Y
Yancey1989 已提交
63
      input_offset += out_stride[axis];
64 65 66 67 68 69
    }
  }
};

}  // namespace operators
}  // namespace paddle