/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/array_operator.h"
#include "paddle/fluid/operators/math/math_function.h"

#include "paddle/pten/core/lod_utils.h"
// Forward declarations to avoid pulling in heavy framework headers; only
// pointers/references to these types are needed by the grad-op maker below.
namespace paddle {
namespace framework {
class OpDesc;
class Scope;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle

Y
Yang Yu 已提交
29 30 31
namespace paddle {
namespace operators {

Y
Yang Yu 已提交
32
class ShrinkRNNMemoryOp : public ArrayOp {
Y
Yang Yu 已提交
33
 public:
Y
Yang Yu 已提交
34 35 36 37
  ShrinkRNNMemoryOp(const std::string &type,
                    const framework::VariableNameMap &inputs,
                    const framework::VariableNameMap &outputs,
                    const framework::AttributeMap &attrs)
Y
Yang Yu 已提交
38 39
      : ArrayOp(type, inputs, outputs, attrs) {}

40 41 42
 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
Y
Yang Yu 已提交
43
    auto *x_var = scope.FindVar(Input("X"));
44 45 46
    PADDLE_ENFORCE_NOT_NULL(x_var,
                            platform::errors::NotFound(
                                "Input(X) of ShrinkRNNMemoryOp is not found."));
Y
Yang Yu 已提交
47
    auto &x_tensor = x_var->Get<framework::LoDTensor>();
D
dzhwinter 已提交
48
    size_t offset = this->GetOffset(scope, place);
Y
Yang Yu 已提交
49
    auto *rank_table_var = scope.FindVar(Input("RankTable"));
50 51 52 53
    PADDLE_ENFORCE_NOT_NULL(
        rank_table_var,
        platform::errors::NotFound(
            "Input(RankTable) of ShrinkRNNMemoryOp is not found."));
Y
Yang Yu 已提交
54 55
    auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();

Y
Yang Yu 已提交
56 57 58 59 60 61
    auto &rank_items = rank_table.items();
    int dst_num_rows =
        std::lower_bound(rank_items.begin(), rank_items.end(), offset,
                         [](const framework::LoDRankTable::TableItem &a,
                            size_t b) { return a.length > b; }) -
        rank_items.begin();
Y
Yang Yu 已提交
62 63

    auto *out_var = scope.FindVar(Output("Out"));
64 65 66
    PADDLE_ENFORCE_NOT_NULL(
        out_var, platform::errors::NotFound(
                     "Output(Out) of ShrinkRNNMemoryOp is not found."));
Y
Yang Yu 已提交
67
    auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
Y
yangyaming 已提交
68 69

    size_t height = dst_num_rows;
Y
yangyaming 已提交
70

71 72 73
    // do shrink for the top level LoD
    if (x_tensor.lod().size() > 0 &&
        x_tensor.lod()[0].size() > static_cast<size_t>(dst_num_rows)) {
74 75 76 77
      auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(x_tensor.lod(), 0,
                                                              dst_num_rows, 0);
      height = lod_offset.second.second;
      auto out_lod = out_tensor.mutable_lod();
78
      pten::AppendLoD(out_lod, lod_offset.first);
Y
yangyaming 已提交
79 80
    }

81
    if (dst_num_rows != 0) {
D
dzhwinter 已提交
82 83 84 85
      out_tensor.mutable_data(place, x_tensor.type());
      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      framework::TensorCopy(x_tensor.Slice(0, height), place, *dev_ctx,
                            &out_tensor);
Y
Yang Yu 已提交
86 87 88 89
    }
  }
};

Y
Yang Yu 已提交
90
class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
Y
Yang Yu 已提交
91
 public:
Y
Yu Yang 已提交
92
  void Make() override {
T
tianshuo78520a 已提交
93
    AddInput("X", "(LoDTensor) The RNN step memory to be shrank.");
94 95 96
    AddInput("RankTable", "(LoDRankTable) The lod_rank_table of dynamic RNN.");
    AddInput("I",
             "(LoDTensor) The step index. The RNN step memory 'X' will be "
T
tianshuo78520a 已提交
97 98
             "shrank to match the size of the input of the index'th step.");
    AddOutput("Out", "(LoDTensor) The shrank RNN step memory.");
99 100 101 102 103 104 105 106 107 108 109
    AddComment(R"DOC(
This operator is used to shrink output batch of memory defined in dynamic RNN.

Dynamic RNN is able to handle variable-length sequences, in which, sequences in
a mini-batch are sorted by their lengths first. After that, the longest sequence
becomes the first one in the sorted batch, followed by the second longest, the
third longest, and so on. Dynamic RNN then slices a batch input timestep by
timestep from the sorted input. Once any sequence in the input batch reaches its
end, memory defined in dynamicRNN has to shrink its outputs to adapt to the input
batch size for the next time step.
)DOC");
Y
Yang Yu 已提交
110 111 112
  }
};

Y
Yang Yu 已提交
113
class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
Y
Yang Yu 已提交
114 115
 public:
  void operator()(framework::InferShapeContext *context) const override {
116 117 118 119
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "ShrinkRNNMemory");
    OP_INOUT_CHECK(context->HasInput("I"), "Input", "I", "ShrinkRNNMemory");
    OP_INOUT_CHECK(context->HasInput("RankTable"), "Input", "RankTable",
                   "ShrinkRNNMemory");
Y
Yang Yu 已提交
120
    context->SetOutputDim("Out", context->GetInputDim("X"));
121 122
    // For runtime, output's lod is computed according to input's lod, but
    // remove the finished sequence. It is set in detail kernel implementation.
C
chengduo 已提交
123
    if (!context->IsRuntime()) {
124
      context->ShareLoD("X", /*->*/ "Out");
C
chengduo 已提交
125
    }
Y
Yang Yu 已提交
126 127 128
  }
};

Y
Yang Yu 已提交
129
class ShrinkRNNMemoryGradOp : public ArrayOp {
Y
Yang Yu 已提交
130
 public:
Y
Yang Yu 已提交
131 132 133 134
  ShrinkRNNMemoryGradOp(const std::string &type,
                        const framework::VariableNameMap &inputs,
                        const framework::VariableNameMap &outputs,
                        const framework::AttributeMap &attrs)
Y
Yang Yu 已提交
135 136
      : ArrayOp(type, inputs, outputs, attrs) {}

137 138 139
 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
Y
Yang Yu 已提交
140
    auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
Y
Yang Yu 已提交
141
    auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
142 143 144
    PADDLE_ENFORCE_NOT_NULL(
        dx_var, platform::errors::NotFound(
                    "Input(X@GRAD) of ShrinkRNNMemoryGradOp is not found."));
Y
Yang Yu 已提交
145
    auto *x_var = scope.FindVar(Input("X"));
146 147 148
    PADDLE_ENFORCE_NOT_NULL(
        x_var, platform::errors::NotFound(
                   "Input(x) of ShrinkRNNMemoryGradOp is not found."));
Y
Yang Yu 已提交
149 150 151 152 153
    auto &x_tensor = x_var->Get<framework::LoDTensor>();
    auto &dx_tensor = *dx_var->GetMutable<framework::LoDTensor>();
    dx_tensor.Resize(x_tensor.dims());
    dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());

D
dzhwinter 已提交
154
    // get device context from pool
Y
Yang Yu 已提交
155 156
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);
D
dzhwinter 已提交
157

Y
Yang Yu 已提交
158 159 160 161 162
    if (dout_var == nullptr) {  // dx_tensor fill zero
      math::set_constant(dev_ctx, &dx_tensor, 0.0f);
    } else {
      auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
      auto height = dout_tensor.dims()[0];
163 164
      auto slice = dx_tensor.Slice(0, static_cast<int>(height));
      framework::TensorCopy(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
Y
Refine  
Yang Yu 已提交
165
      if (dx_tensor.dims()[0] > height) {
Y
Yang Yu 已提交
166
        auto rest_tensor = dx_tensor.Slice(
Y
Refine  
Yang Yu 已提交
167
            static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
Y
Yang Yu 已提交
168 169 170
        math::set_constant(dev_ctx, &rest_tensor, 0.0f);
      }
    }
171
    dx_tensor.set_lod(x_tensor.lod());
Y
Yang Yu 已提交
172 173 174
  }
};

Y
Yang Yu 已提交
175
class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
Y
Yang Yu 已提交
176 177
 public:
  void operator()(framework::InferShapeContext *context) const override {
178 179 180
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "ShrinkRNNMemoryGrad");
    OP_INOUT_CHECK(context->HasOutput(framework::GradVarName("X")), "Output",
                   "X", "ShrinkRNNMemoryGrad");
181 182 183

    context->ShareDim("X", /*->*/ framework::GradVarName("X"));
    context->ShareLoD("X", /*->*/ framework::GradVarName("X"));
Y
Yang Yu 已提交
184 185 186
  }
};

H
hong 已提交
187 188
template <typename T>
class ShrinkRNNGradOpMaker : public framework::SingleGradOpMaker<T> {
Y
Yang Yu 已提交
189
 public:
H
hong 已提交
190
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
Y
Yang Yu 已提交
191 192

 protected:
193
  void Apply(GradOpPtr<T> op) const override {
Y
Yang Yu 已提交
194
    op->SetType("shrink_rnn_memory_grad");
H
hong 已提交
195 196 197 198
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
Y
Yang Yu 已提交
199 200 201 202 203 204 205
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Y
Yang Yu 已提交
206 207
REGISTER_OPERATOR(shrink_rnn_memory, ops::ShrinkRNNMemoryOp,
                  ops::ShrinkRNNMemoryInferShape,
H
hong 已提交
208 209 210
                  ops::ShrinkRNNMemoryOpProtoMaker,
                  ops::ShrinkRNNGradOpMaker<paddle::framework::OpDesc>,
                  ops::ShrinkRNNGradOpMaker<paddle::imperative::OpBase>);
Y
Yang Yu 已提交
211 212
REGISTER_OPERATOR(shrink_rnn_memory_grad, ops::ShrinkRNNMemoryGradOp,
                  ops::ShrinkRNNMemoryGradInferShape);