sequence_concat_op.h 6.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Y
Yancey1989 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Y
Yi Wang 已提交
16 17
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h"
Y
Yancey1989 已提交
18 19 20 21 22 23 24 25 26

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;

template <typename T>
27
LoD ConcatLoD(const std::vector<const T*> ins, const size_t level) {
Y
Yancey1989 已提交
28
  auto out_lod = ins[0]->lod();
29
  auto numLevels = ins[0]->NumLevels();
Y
Yancey1989 已提交
30
  const size_t n = ins.size();
31 32 33 34 35 36
  const size_t level_idx = ins[0]->NumLevels() - 1 - level;
  for (size_t i = 1; i < n; ++i) {
    for (size_t j = 0; j < ins[i]->lod()[level_idx].size(); ++j) {
      out_lod[level_idx][j] += ins[i]->lod()[level_idx][j];
    }
  }
Y
update  
Yancey1989 已提交
37

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
  for (size_t i = level_idx; i < numLevels - 1; ++i) {
    size_t lod_len = 1;
    for (size_t j = 0; j < n; ++j) {
      lod_len += ins[j]->lod()[i + 1].size() - 1;
    }
    out_lod[i + 1].clear();
    out_lod[i + 1].resize(lod_len);

    size_t idx = 1;
    for (size_t j = 0; j < ins[0]->lod()[i].size() - 1; ++j) {
      for (size_t k = 0; k < n; ++k) {
        for (size_t m = ins[k]->lod()[i][j]; m < ins[k]->lod()[i][j + 1]; ++m) {
          out_lod[i + 1][idx] = out_lod[i + 1][idx - 1] +
                                ins[k]->lod()[i + 1][m + 1] -
                                ins[k]->lod()[i + 1][m];
          idx++;
Y
Yancey1989 已提交
54 55 56 57
        }
      }
    }
  }
58

Y
Yancey1989 已提交
59 60 61
  return out_lod;
}

Q
QI JUN 已提交
62
template <typename DeviceContext, typename T>
Y
Yancey1989 已提交
63
class SequenceConcatOpKernel : public framework::OpKernel<T> {
Y
Yancey1989 已提交
64 65 66 67 68 69 70
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<LoDTensor>("X");
    auto* out = ctx.Output<LoDTensor>("Out");
    const size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
    const size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
    const size_t n = ins.size();
Y
Yancey1989 已提交
71 72 73

    for (size_t i = 1; i < n; ++i) {
      PADDLE_ENFORCE_EQ(ins[0]->NumLevels(), ins[i]->NumLevels(),
Y
update  
Yancey1989 已提交
74
                        "The levels of all the input LoDTensors "
Y
Yancey1989 已提交
75 76
                        "should be the same.");
      PADDLE_ENFORCE_EQ(ins[0]->dims().size(), ins[i]->dims().size(),
Y
Yancey1989 已提交
77
                        "The dimension size of all the input LoDTensors "
Y
Yancey1989 已提交
78 79 80 81 82 83
                        "should be the same.");

      const size_t dims_size = ins[i]->dims().size();
      for (size_t j = 0; j < dims_size; ++j) {
        if (j == axis) continue;
        PADDLE_ENFORCE_EQ(ins[0]->dims()[j], ins[i]->dims()[j],
Y
Yancey1989 已提交
84 85 86 87
                          "Except for the dimension of the specified "
                          "axis along which all the inputs are concatenated, "
                          "dimensions of all the other axises of the input "
                          "LoDTensors should be the same.");
Y
Yancey1989 已提交
88 89
      }
    }
Y
Yancey1989 已提交
90 91 92
    PADDLE_ENFORCE_GT(ins[0]->NumLevels(), level,
                      "The levels of all the input LoDTensors "
                      "should be greater than the specify level");
Y
Yancey1989 已提交
93

Y
Yancey1989 已提交
94
    out->mutable_data<T>(ctx.GetPlace());
95 96 97 98
    auto out_lod = ins[0]->lod();
    if (axis == 0) {
      out_lod = ConcatLoD<LoDTensor>(ins, level);
    }
Y
Yancey1989 已提交
99 100
    out->set_lod(out_lod);

101 102
    const size_t level_idx = out_lod.size() - level - 1;
    auto out_lod_level = framework::ToAbsOffset(out_lod)[level_idx];
Y
Yancey1989 已提交
103
    for (size_t i = 0; i < out_lod_level.size() - 1; ++i) {
104 105
      Tensor out_t = out->Slice(static_cast<int>(out_lod_level[i]),
                                static_cast<int>(out_lod_level[i + 1]));
Y
Yancey1989 已提交
106 107
      auto out_stride = framework::stride(out_t.dims());
      size_t offset = 0;
Y
Yancey1989 已提交
108
      for (size_t j = 0; j < n; ++j) {
109
        auto in_lod_level = framework::ToAbsOffset(ins[j]->lod())[level_idx];
Y
Yancey1989 已提交
110
        auto in_stride = framework::stride(ins[j]->dims());
111 112
        Tensor in_t = ins[j]->Slice(static_cast<int>(in_lod_level[i]),
                                    static_cast<int>(in_lod_level[i + 1]));
Y
Yancey1989 已提交
113 114 115 116 117 118 119 120 121
        size_t axis_dim = in_t.dims()[axis];
        StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(), in_stride,
                         in_t.dims(), out_stride, out_t.data<T>() + offset);
        offset += axis_dim * in_stride[axis];
      }
    }
  }
};

Q
QI JUN 已提交
122
template <typename DeviceContext, typename T>
Y
Yancey1989 已提交
123
class SequenceConcatGradOpKernel : public framework::OpKernel<T> {
Y
Yancey1989 已提交
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
    auto* out_grad =
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto x_grads =
        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
    size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
    size_t level = static_cast<size_t>(ctx.Attr<int>("level"));
    const size_t n = x_grads.size();

    // Set Grad(X) LoD as X
    for (size_t i = 0; i < n; i++) {
      x_grads[i]->set_lod(ins[i]->lod());
      x_grads[i]->mutable_data<T>(ctx.GetPlace());
    }
140 141 142 143 144 145
    auto out_lod = ins[0]->lod();
    if (axis == 0UL) {
      out_lod = ConcatLoD<LoDTensor>(ins, level);
    }
    const size_t level_idx = out_lod.size() - level - 1;
    auto out_lod_level = framework::ToAbsOffset(out_lod)[level_idx];
Y
Yancey1989 已提交
146

Y
Yancey1989 已提交
147
    for (size_t i = 0; i < out_lod_level.size() - 1; ++i) {
Y
Yancey1989 已提交
148
      Tensor out_grad_t =
149 150
          out_grad->Slice(static_cast<int>(out_lod_level[i]),
                          static_cast<int>(out_lod_level[i + 1]));
Y
Yancey1989 已提交
151 152 153
      auto out_grad_stride = framework::stride(out_grad_t.dims());
      size_t offset = 0;

Y
Yancey1989 已提交
154
      for (size_t j = 0; j < n; ++j) {
155 156
        auto x_grad_lod_level =
            framework::ToAbsOffset(x_grads[j]->lod())[level_idx];
Y
Yancey1989 已提交
157 158
        auto x_grad_stride = framework::stride(x_grads[j]->dims());
        Tensor x_grad_t =
159 160
            x_grads[j]->Slice(static_cast<int>(x_grad_lod_level[i]),
                              static_cast<int>(x_grad_lod_level[i + 1]));
Y
Yancey1989 已提交
161 162 163 164 165 166 167 168 169 170 171 172
        size_t axis_dim = x_grad_t.dims()[axis];
        StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>() + offset,
                         out_grad_stride, out_grad_t.dims(), x_grad_stride,
                         x_grad_t.data<T>());
        offset += axis_dim * out_grad_stride[axis];
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle