/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

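/*
 * A minimal sketch of the forward expansion performed by the kernel below.
 * The values a and b are assumed placeholders, chosen so that the example
 * mirrors the gradient example further down in this file:
 *
 *    X.data = [a, b],    X.lod = [[0, 1, 2]]
 *    Y.lod  = [[0,             2],
 *              [0,      3,     6]],   ref_level = -1 (i.e. the last level)
 * Then each sequence of X is repeated according to the ref_level lod of Y:
 *    Out.data = [a, a, a, b, b, b]
 *    Out.lod (the level built by this kernel) = [[0, 1, 2, 3, 4, 5, 6]]
 */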
template <typename DeviceContext, typename T>
class SequenceExpandKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<LoDTensor>("X");
    auto* y = context.Input<LoDTensor>("Y");
    auto* out = context.Output<LoDTensor>("Out");

    int ref_level = context.Attr<int>("ref_level");
    auto& x_lod = x->lod();
    auto& y_lod = y->lod();
    PADDLE_ENFORCE_GT(y_lod.size(), 0,
                      "Level number of `Y`'s lod should be greater than 0.");
    PADDLE_ENFORCE(
        ref_level == -1 || (ref_level >= 0 && ref_level < y_lod.size()),
        "Invlid `ref_level`, which should be either equal to -1 "
        "or in [0, %d)",
        y_lod.size());

    if (ref_level == -1) ref_level = y_lod.size() - 1;

    out->mutable_data<T>(context.GetPlace());

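    // Nothing to expand: the reference lod level of Y holds no sequence
    // offsets, so Out is just a copy of X.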
    if (y_lod[ref_level].size() <= 1) {
      framework::TensorCopy(*x, context.GetPlace(), out);
      return;
    }

    auto& out_lod = *out->mutable_lod();
    if (x_lod.size() == 1) {
      out_lod.resize(1);
      out_lod[0] = {0};
    }

    int out_offset = 0;
    auto& eigen_place =
        *context.template device_context<DeviceContext>().eigen_device();
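    // For each sequence in Y at ref_level, broadcast the corresponding
    // sequence of X repeat_num times into Out. If X carries a lod of its
    // own, whole sequences are repeated; otherwise every row of X is
    // treated as a sequence of length one.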
    for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
      int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
      int x_start = i - 1;
      int x_end = i;
      if (x_lod.size() == 1) {
        x_start = x_lod[0][i - 1];
        x_end = x_lod[0][i];
      }
      int x_seq_len = x_end - x_start;
      if (repeat_num > 0) {
        auto x_sub_tensor = x->Slice(x_start, x_end);
        x_sub_tensor.Resize({1, x_sub_tensor.numel()});
        int out_start = out_offset;
        if (x_lod.size() == 1) {
          out_start = out_lod[0][out_offset];
        }
        auto out_sub_tensor =
            out->Slice(out_start, out_start + x_seq_len * repeat_num);
        out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
        EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
            EigenMatrix<T>::From(x_sub_tensor)
                .broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
      }
      for (int j = 0; j < repeat_num; ++j) {
        if (x_lod.size() == 1) {
          out_lod[0].push_back(out_lod[0].back() + x_seq_len);
        }
        out_offset++;
      }
    }
  }
};

/*
 * Given Grad(Out)
 *
 *    Grad(Out).lod = [[0,                            2],
 *                     [0,              3,            6]]
 *    Grad(Out).data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
 * Then
 *    Grad(X).data = [(0.1 + 0.2 + 0.3), (0.4 + 0.5 + 0.6)]
 *                 = [0.6, 1.5]
 *    Grad(X).lod = Input(X).lod
 *
 */
template <typename DeviceContext, typename T>
class SequenceExpandGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* g_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* x = context.Input<LoDTensor>("X");
    auto* y = context.Input<LoDTensor>("Y");
    auto* g_x = context.Output<LoDTensor>(framework::GradVarName("X"));
    int ref_level = context.Attr<int>("ref_level");

    g_x->mutable_data<T>(context.GetPlace());
    g_x->set_lod(x->lod());

    auto& x_lod = x->lod();
    auto& y_lod = y->lod();

    if (ref_level == -1) ref_level = y_lod.size() - 1;

    // just copy the gradient
    if (y_lod[ref_level].size() <= 1) {
      framework::TensorCopy(*g_out, context.GetPlace(), g_x);
      return;
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();

    math::SetConstant<DeviceContext, T> set_zero;
    set_zero(dev_ctx, g_x, static_cast<T>(0));

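    // Walk Grad(Out) in the same order as the forward pass and column-sum
    // the repeat_num repeated copies of each sequence back into Grad(X).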
    int g_out_offset = 0;
    for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
      int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
      if (repeat_num > 0) {
        int x_start = i - 1;
        int x_end = i;
        if (x_lod.size() == 1) {
          x_start = x_lod[0][i - 1];
          x_end = x_lod[0][i];
        }
        int x_seq_len = x_end - x_start;
        auto g_x_sub = g_x->Slice(x_start, x_end);
        g_x_sub.Resize(flatten_to_1d(g_x_sub.dims()));
        int g_out_end = g_out_offset + repeat_num * x_seq_len;
        auto g_out_sub = g_out->Slice(g_out_offset, g_out_end);
        g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]});
        math::ColwiseSum<DeviceContext, T> col_sum;
        col_sum(dev_ctx, g_out_sub, &g_x_sub);
        g_out_offset += repeat_num * x_seq_len;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle