shrink_rnn_memory_op.cc 7.0 KB
Newer Older
L
Luo Tao 已提交
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yang Yu 已提交
14 15 16 17 18 19 20
#include "paddle/framework/lod_rank_table.h"
#include "paddle/operators/array_operator.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {

Y
Yang Yu 已提交
21
class ShrinkRNNMemoryOp : public ArrayOp {
Y
Yang Yu 已提交
22
 public:
Y
Yang Yu 已提交
23 24 25 26
  ShrinkRNNMemoryOp(const std::string &type,
                    const framework::VariableNameMap &inputs,
                    const framework::VariableNameMap &outputs,
                    const framework::AttributeMap &attrs)
Y
Yang Yu 已提交
27 28 29
      : ArrayOp(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope &scope,
D
dzhwinter 已提交
30
           const platform::Place &place) const override {
Y
Yang Yu 已提交
31 32 33
    auto *x_var = scope.FindVar(Input("X"));
    PADDLE_ENFORCE(x_var != nullptr, "Input X must be set");
    auto &x_tensor = x_var->Get<framework::LoDTensor>();
D
dzhwinter 已提交
34
    size_t offset = this->GetOffset(scope, place);
Y
Yang Yu 已提交
35 36 37 38
    auto *rank_table_var = scope.FindVar(Input("RankTable"));
    PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
    auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();

Y
Yang Yu 已提交
39 40 41 42 43 44
    auto &rank_items = rank_table.items();
    int dst_num_rows =
        std::lower_bound(rank_items.begin(), rank_items.end(), offset,
                         [](const framework::LoDRankTable::TableItem &a,
                            size_t b) { return a.length > b; }) -
        rank_items.begin();
Y
Yang Yu 已提交
45 46 47 48

    auto *out_var = scope.FindVar(Output("Out"));
    PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
    auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
Y
yangyaming 已提交
49 50 51 52 53 54 55 56 57 58 59

    // should consider multiple levels
    size_t height = dst_num_rows;
    auto lod_level = lod_rank_table.level();
    if (x_tensor.lod().size() > lod_level &&
        x_tensor.lod()[lod_level].size() < dst_num_rows) {
      auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(
          x_tensor.lod(), 0, dst_num_rows + 1, lod_level);
      height = lod_offset.second.second;
    }

Y
Yang Yu 已提交
60
    if (dst_num_rows != 0) {
Y
yangyaming 已提交
61
      out_tensor.ShareDataWith(x_tensor.Slice(0, height));
Y
Yang Yu 已提交
62 63 64 65
    }
  }
};

Y
Yang Yu 已提交
66
class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
Y
Yang Yu 已提交
67
 public:
68
  ShrinkRNNMemoryOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
Y
Yang Yu 已提交
69
      : OpProtoAndCheckerMaker(proto, op_checker) {
70 71 72 73 74 75 76 77
    AddInput("X", "(LoDTensor) The RNN step memory to be shrinked.");
    AddInput("RankTable", "(LoDRankTable) The lod_rank_table of dynamic RNN.");
    AddInput("I",
             "(LoDTensor) The step index. The RNN step memory 'X' will be "
             "shrinked to match the size of the input of the index'th step.");
    AddOutput("Out", "(LoDTensor) The shrinked RNN step memory.");
    AddComment(
        R"DOC(
Y
yangyaming 已提交
78 79
        In dynamic RNN, we are able to handle sequences of different lengths.
        Because of the multiple lengths, the size of each step input can be
80
        different, which may lead to a mismatching between the input of
Y
yangyaming 已提交
81 82
        the current step and the memory generated by the previous one. This
        operator shrinks memory according to the size of the next step input,
83 84
        to make sure that they can match each other.
        )DOC");
Y
Yang Yu 已提交
85 86 87
  }
};

Y
Yang Yu 已提交
88
class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
Y
Yang Yu 已提交
89 90 91 92 93 94 95 96 97
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInput("X"));
    PADDLE_ENFORCE(context->HasInput("I"));
    PADDLE_ENFORCE(context->HasInput("RankTable"));
    context->SetOutputDim("Out", context->GetInputDim("X"));
  }
};

Y
Yang Yu 已提交
98
class ShrinkRNNMemoryGradOp : public ArrayOp {
Y
Yang Yu 已提交
99
 public:
Y
Yang Yu 已提交
100 101 102 103
  ShrinkRNNMemoryGradOp(const std::string &type,
                        const framework::VariableNameMap &inputs,
                        const framework::VariableNameMap &outputs,
                        const framework::AttributeMap &attrs)
Y
Yang Yu 已提交
104 105 106
      : ArrayOp(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope &scope,
D
dzhwinter 已提交
107
           const platform::Place &place) const override {
Y
Yang Yu 已提交
108
    auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
Y
Yang Yu 已提交
109
    auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
Y
Yang Yu 已提交
110 111 112 113 114 115 116 117 118
    PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
    auto *x_var = scope.FindVar(Input("X"));
    PADDLE_ENFORCE(x_var != nullptr);

    auto &x_tensor = x_var->Get<framework::LoDTensor>();
    auto &dx_tensor = *dx_var->GetMutable<framework::LoDTensor>();
    dx_tensor.Resize(x_tensor.dims());
    dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());

D
dzhwinter 已提交
119
    // get device context from pool
Y
Yang Yu 已提交
120 121
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);
D
dzhwinter 已提交
122

Y
Yang Yu 已提交
123 124 125 126 127
    if (dout_var == nullptr) {  // dx_tensor fill zero
      math::set_constant(dev_ctx, &dx_tensor, 0.0f);
    } else {
      auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
      auto height = dout_tensor.dims()[0];
D
dzhwinter 已提交
128 129
      auto slice = dx_tensor.Slice(0, static_cast<int>(height));
      framework::CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
Y
Refine  
Yang Yu 已提交
130
      if (dx_tensor.dims()[0] > height) {
Y
Yang Yu 已提交
131
        auto rest_tensor = dx_tensor.Slice(
Y
Refine  
Yang Yu 已提交
132
            static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
Y
Yang Yu 已提交
133 134 135 136 137 138
        math::set_constant(dev_ctx, &rest_tensor, 0.0f);
      }
    }
  }
};

Y
Yang Yu 已提交
139
class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
Y
Yang Yu 已提交
140 141 142 143 144 145 146 147 148
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInput("X"));
    PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X")));
    context->SetOutputDim(framework::GradVarName("X"),
                          context->GetInputDim("X"));
  }
};

Y
Yang Yu 已提交
149
class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker {
Y
Yang Yu 已提交
150 151 152 153
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
Y
Yu Yang 已提交
154 155
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto *op = new framework::OpDesc();
Y
Yang Yu 已提交
156
    op->SetType("shrink_rnn_memory_grad");
Y
Yang Yu 已提交
157 158 159 160
    op->SetInput("X", Input("X"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
Y
Yu Yang 已提交
161
    return std::unique_ptr<framework::OpDesc>(op);
Y
Yang Yu 已提交
162 163 164 165 166 167 168
  }
};

}  // namespace operators
}  // namespace paddle

// Short alias for the operator namespace, used by the registrations below.
namespace ops = paddle::operators;
Y
Yang Yu 已提交
169 170 171 172 173
// Register the forward operator together with its shape inference, proto
// maker and grad-op maker.
REGISTER_OPERATOR(shrink_rnn_memory, ops::ShrinkRNNMemoryOp,
                  ops::ShrinkRNNMemoryInferShape,
                  ops::ShrinkRNNMemoryOpProtoMaker, ops::ShrinkRNNGradOpMaker);
// Register the backward operator together with its shape inference.
REGISTER_OPERATOR(shrink_rnn_memory_grad, ops::ShrinkRNNMemoryGradOp,
                  ops::ShrinkRNNMemoryGradInferShape);