/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/array_operator.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

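// Shrinks the RNN step memory X of a dynamic RNN. Using the LoDRankTable, it
// determines how many sequences are still active at the step given by input
// "I" and copies only the rows of those sequences from X into Out.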
class ShrinkRNNMemoryOp : public ArrayOp {
 public:
  ShrinkRNNMemoryOp(const std::string &type,
                    const framework::VariableNameMap &inputs,
                    const framework::VariableNameMap &outputs,
                    const framework::AttributeMap &attrs)
      : ArrayOp(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto *x_var = scope.FindVar(Input("X"));
    PADDLE_ENFORCE_NOT_NULL(x_var,
                            platform::errors::NotFound(
                                "Input(X) of ShrinkRNNMemoryOp is not found."));
    auto &x_tensor = x_var->Get<framework::LoDTensor>();
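    // The current time-step index, read from input "I".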
    size_t offset = this->GetOffset(scope, place);
    auto *rank_table_var = scope.FindVar(Input("RankTable"));
    PADDLE_ENFORCE_NOT_NULL(
        rank_table_var,
        platform::errors::NotFound(
            "Input(RankTable) of ShrinkRNNMemoryOp is not found."));
    auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();

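    // rank_items is sorted by sequence length in descending order, so
    // lower_bound finds the first sequence that has already finished at step
    // `offset`; its index is the number of sequences still active.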
    auto &rank_items = rank_table.items();
    int dst_num_rows =
        std::lower_bound(rank_items.begin(), rank_items.end(), offset,
                         [](const framework::LoDRankTable::TableItem &a,
                            size_t b) { return a.length > b; }) -
        rank_items.begin();

    auto *out_var = scope.FindVar(Output("Out"));
    PADDLE_ENFORCE_NOT_NULL(
        out_var, platform::errors::NotFound(
                     "Output(Out) of ShrinkRNNMemoryOp is not found."));
    auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();

    // Number of rows to copy into Out; recomputed from the LoD below when X
    // carries LoD information.
    size_t height = dst_num_rows;

    // Also shrink the top-level LoD, keeping only the still-active sequences.
    if (x_tensor.lod().size() > 0 &&
        x_tensor.lod()[0].size() > static_cast<size_t>(dst_num_rows)) {
      auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(x_tensor.lod(), 0,
                                                              dst_num_rows, 0);
      height = lod_offset.second.second;
      auto out_lod = out_tensor.mutable_lod();
      framework::AppendLoD(out_lod, lod_offset.first);
    }

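    // Copy the rows of the still-active sequences into Out; when every
    // sequence has finished, Out is left empty.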
    if (dst_num_rows != 0) {
      out_tensor.mutable_data(place, x_tensor.type());
      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      framework::TensorCopy(x_tensor.Slice(0, height), place, *dev_ctx,
                            &out_tensor);
    }
  }
};

class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(LoDTensor) The RNN step memory to be shrank.");
    AddInput("RankTable", "(LoDRankTable) The lod_rank_table of dynamic RNN.");
    AddInput("I",
             "(LoDTensor) The step index. The RNN step memory 'X' will be "
             "shrank to match the size of the input of the index'th step.");
    AddOutput("Out", "(LoDTensor) The shrank RNN step memory.");
    AddComment(R"DOC(
This operator is used to shrink output batch of memory defined in dynamic RNN.

Dynamic RNN is able to handle variable-length sequences, in which, sequences in
a mini-batch are sorted by their lengths first. After that, the longest sequence
becomes the first one in the sorted batch, followed by the second longest, the
third longest, and so on. Dynamic RNN then slices a batch input timestep by
timestep from the sorted input. Once any sequence in the input batch reaches its
end, memory defined in dynamicRNN has to shrink its outputs to adapt to the input
batch size for the next time step.
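
For example, if a mini-batch contains three sequences of lengths 3, 2 and 1,
the memory has 3 rows at step 0, is shrunk to 2 rows at step 1 (the shortest
sequence has ended), and to 1 row at step 2.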
)DOC");
  }
};

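// Compile-time shape inference: Out has the same shape as X.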
class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "ShrinkRNNMemory");
    OP_INOUT_CHECK(context->HasInput("I"), "Input", "I", "ShrinkRNNMemory");
    OP_INOUT_CHECK(context->HasInput("RankTable"), "Input", "RankTable",
                   "ShrinkRNNMemory");
    context->SetOutputDim("Out", context->GetInputDim("X"));
    // At runtime, the output's LoD is computed from the input's LoD with the
    // finished sequences removed; this is done in the kernel implementation.
    if (!context->IsRuntime()) {
      context->ShareLoD("X", /*->*/ "Out");
    }
  }
};

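// Backward of shrink_rnn_memory: copies Out@GRAD into the first rows of
// X@GRAD and fills the rows of already-finished sequences with zeros.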
class ShrinkRNNMemoryGradOp : public ArrayOp {
 public:
  ShrinkRNNMemoryGradOp(const std::string &type,
                        const framework::VariableNameMap &inputs,
                        const framework::VariableNameMap &outputs,
                        const framework::AttributeMap &attrs)
      : ArrayOp(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
    auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
    PADDLE_ENFORCE_NOT_NULL(
        dx_var, platform::errors::NotFound(
                    "Output(X@GRAD) of ShrinkRNNMemoryGradOp is not found."));
    auto *x_var = scope.FindVar(Input("X"));
    PADDLE_ENFORCE_NOT_NULL(
        x_var, platform::errors::NotFound(
                   "Input(X) of ShrinkRNNMemoryGradOp is not found."));
    auto &x_tensor = x_var->Get<framework::LoDTensor>();
    auto &dx_tensor = *dx_var->GetMutable<framework::LoDTensor>();
    dx_tensor.Resize(x_tensor.dims());
    dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());

    // Get the device context from the pool.
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);

    if (dout_var == nullptr) {  // no Out@GRAD: fill dx_tensor with zeros
      math::set_constant(dev_ctx, &dx_tensor, 0.0f);
    } else {
      auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
      auto height = dout_tensor.dims()[0];
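      // Out@GRAD covers only the rows that were still active in the forward
      // pass; copy it into the top rows of X@GRAD.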
      auto slice = dx_tensor.Slice(0, static_cast<int>(height));
      framework::TensorCopy(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
      if (dx_tensor.dims()[0] > height) {
        auto rest_tensor = dx_tensor.Slice(
            static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
        math::set_constant(dev_ctx, &rest_tensor, 0.0f);
      }
    }
    dx_tensor.set_lod(x_tensor.lod());
  }
};

class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "ShrinkRNNMemoryGrad");
    OP_INOUT_CHECK(context->HasOutput(framework::GradVarName("X")), "Output",
                   "X", "ShrinkRNNMemoryGrad");

    context->ShareDim("X", /*->*/ framework::GradVarName("X"));
    context->ShareLoD("X", /*->*/ framework::GradVarName("X"));
  }
};

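// Declares the backward op: it takes X and Out@GRAD and produces X@GRAD.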
template <typename T>
class ShrinkRNNGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("shrink_rnn_memory_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(shrink_rnn_memory, ops::ShrinkRNNMemoryOp,
                  ops::ShrinkRNNMemoryInferShape,
                  ops::ShrinkRNNMemoryOpProtoMaker,
                  ops::ShrinkRNNGradOpMaker<paddle::framework::OpDesc>,
                  ops::ShrinkRNNGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(shrink_rnn_memory_grad, ops::ShrinkRNNMemoryGradOp,
                  ops::ShrinkRNNMemoryGradInferShape);