/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/operators/recurrent_op.h"

#include <glog/logging.h>
#include <cstring>
#include <sstream>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace operators {

namespace rnn {

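// Split each external input sequence along the time dimension and expose the
// j-th slice as the internal inlink variable in the j-th step scope. In
// infer-shape mode only the shapes are set; no data is touched.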
void SegmentInputs(const std::vector<Scope*>& step_scopes,
                   const std::vector<Link>& inlinks, const size_t seq_len,
                   bool infer_shape_mode) {
  PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
  for (size_t i = 0; i < inlinks.size(); ++i) {
    auto input_var = step_scopes[0]->FindVar(inlinks[i].external);
    PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
                   inlinks[i].external);
    Tensor* input = input_var->GetMutable<Tensor>();
    framework::DDim dims = input->dims();
    PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                   "all the inlinks must have the same length");
    framework::DDim step_dims = slice_ddim(dims, 1, dims.size());
    for (size_t j = 0; j < seq_len; j++) {
      Tensor* step_input =
          step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
      if (!infer_shape_mode) {
        *step_input = input->Slice<float>(j, j + 1);
      }
      step_input->Resize(step_dims);
    }
  }
}

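// Gather the internal outlink tensors from every step scope back into one
// sequence tensor in the parent scope. In infer-shape mode only the output
// shape [seq_len, ...] is inferred; otherwise each step's output is copied
// into its slice.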
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                   const std::vector<Link>& outlinks, const size_t seq_len,
                   bool infer_shape_mode) {
  for (size_t i = 0; i < outlinks.size(); i++) {
    auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
    PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
                   outlinks[i].external);
    Tensor* output = output_var->GetMutable<Tensor>();
    if (infer_shape_mode) {
      framework::DDim step_dims = step_scopes[0]
                                      ->FindVar(outlinks[i].internal)
                                      ->GetMutable<Tensor>()
                                      ->dims();
      std::vector<int> dims_vec = vectorize(step_dims);
      dims_vec.insert(dims_vec.begin(), seq_len);
      output->Resize(framework::make_ddim(dims_vec));
    } else {
      output->mutable_data<float>(platform::CPUPlace());
      for (size_t j = 0; j < seq_len; j++) {
        Tensor* step_output =
            step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
        // TODO(luotao02) data type and platform::DeviceContext() should be set
        // correctly
        (output->Slice<float>(j, j + 1))
            .CopyFrom<float>(*step_output, platform::CPUPlace());
      }
    }
  }
}

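// Connect each memory's pre-state in scopes[step_id] to its state in
// scopes[step_id + offset]: shapes only in infer-shape mode, shared data
// otherwise.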
void LinkMemories(const std::vector<Scope*>& scopes,
                  const std::vector<rnn::MemoryAttr>& memories,
                  const size_t step_id, const int offset,
                  bool infer_shape_mode) {
  PADDLE_ENFORCE(step_id < scopes.size(),
                 "step [%d] is out of range of step scopes' size [%d]", step_id,
                 scopes.size());
  PADDLE_ENFORCE(static_cast<int>(step_id) + offset >= 0,
                 "offset [%d] must be larger than -[%d]", offset, step_id);
  PADDLE_ENFORCE(step_id + offset < scopes.size(),
                 "offset [%d] is out of range, it must be less than (%d - %d)",
                 offset, scopes.size(), step_id);
  auto scope = scopes[step_id];
  auto linked_scope = scopes[step_id + offset];
  for (auto& attr : memories) {
    auto mem = scope->FindVar(attr.pre_var)->GetMutable<Tensor>();
    auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<Tensor>();
    if (infer_shape_mode) {
      mem->Resize(linked_mem->dims());
    } else {
      mem->ShareDataWith<float>(*linked_mem);
    }
  }
}

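// Fill an rnn::Argument from the operator's inputs, outputs, and attributes:
// the step net, step scopes, inlink/outlink names with their aliases, and the
// (memory, pre-memory, boot-memory) triples.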
void InitArgument(const ArgumentName& name, Argument* arg,
                  const OperatorBase& op) {
  arg->step_net = op.Input(name.step_net);
  arg->step_scopes = op.Output(name.step_scopes);

  auto inlinks = op.Inputs(name.inlinks);
  auto inlink_alias = op.GetAttr<std::vector<std::string>>(name.inlink_alias);
  PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(),
                 "the sizes of inlinks and inlink_alias don't match:%d,%d",
                 inlinks.size(), inlink_alias.size());
  for (size_t i = 0; i < inlinks.size(); ++i) {
    rnn::Link link;
    link.external = inlinks[i];
    link.internal = inlink_alias[i];
    (arg->inlinks).push_back(link);
  }

  auto outlinks = op.Outputs(name.outlinks);
  auto outlink_alias = op.GetAttr<std::vector<std::string>>(name.outlink_alias);
  PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(),
                 "the sizes of outlinks and outlink_alias don't match:%d,%d",
                 outlinks.size(), outlink_alias.size());
  for (size_t i = 0; i < outlinks.size(); ++i) {
    rnn::Link link;
    link.external = outlinks[i];
    link.internal = outlink_alias[i];
    (arg->outlinks).push_back(link);
  }

  auto boot_memories = op.Inputs(name.boot_memories);

  // attributes
  auto memories = op.GetAttr<std::vector<std::string>>(name.memories);
  auto pre_memories = op.GetAttr<std::vector<std::string>>(name.pre_memories);

  PADDLE_ENFORCE(memories.size() == boot_memories.size(),
                 "the sizes of memories and boot_memories don't match:%d,%d",
                 memories.size(), boot_memories.size());
  PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(),
                 "the sizes of pre_memories and boot_memories don't match:%d,%d",
                 pre_memories.size(), boot_memories.size());
  PADDLE_ENFORCE(memories.size() > 0, "at least one memory should be set");

  for (size_t i = 0; i < memories.size(); ++i) {
    rnn::MemoryAttr mem_attr;
    mem_attr.var = memories[i];
    mem_attr.pre_var = pre_memories[i];
    mem_attr.boot_var = boot_memories[i];
    (arg->memories).push_back(mem_attr);
  }
}

}  // namespace rnn

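// Create the step scopes, segment the inlinks, link memories between adjacent
// steps, run the step net's InferShape for every step, and infer the
// concatenated outlink shapes.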
void RecurrentAlgorithm::InferShape(const Scope& scope) const {
  seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
                 ->GetMutable<Tensor>()
                 ->dims()[0];
  CreateScopes(scope);
  auto step_scopes = GetStepScopes(scope);
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                     true /*infer_shape_mode*/);
  InitMemories(step_scopes[0], true /*infer_shape_mode*/);
  Variable* net = scope.FindVar(arg_->step_net);
  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
  for (size_t i = 0; i < seq_len_; i++) {
    if (i > 0) {
      rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
                        true /*infer_shape_mode*/);
    }
    net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
  }
  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                     true /*infer_shape_mode*/);
}

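// Forward pass: run the step net once per time step, linking each step's
// memories to the previous step, then concatenate the per-step outputs.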
void RecurrentAlgorithm::Run(const Scope& scope,
                             const platform::DeviceContext& dev_ctx) const {
  auto step_scopes = GetStepScopes(scope);
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                     false /*infer_shape_mode*/);
  InitMemories(step_scopes[0], false /*infer_shape_mode*/);
  Variable* net = scope.FindVar(arg_->step_net);

  for (size_t step_id = 0; step_id < seq_len_; step_id++) {
    if (step_id > 0) {
      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
                        false /*infer_shape_mode*/);
    }
    net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
  }
  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                     false /*infer_shape_mode*/);
}

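// Lazily create one scope per time step and cache them in the parent scope's
// step_scopes variable. Each step scope gets the step net's output variables
// and any input variables not already visible from the parent scope.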
void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
  // TODO(xxx) Only two scopes are needed for inference, this case will be
  // supported later.
  auto step_scopes =
      scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();

  if (seq_len_ > step_scopes->size()) {
    for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
      auto& step_scope = scope.NewScope();

      // Now all variables in scope must be created outside of op.
      auto net_op = scope.FindVar(arg_->step_net)->GetMutable<NetOp>();
      for (auto& input : net_op->inputs_) {
        // the weights are located in the parent scope
        if (!step_scope.FindVar(input)) step_scope.NewVar(input);
      }
      for (auto& output : net_op->outputs_) {
        step_scope.NewVar(output);
      }
      step_scopes->emplace_back(&step_scope);
    }
  }
}

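// Initialize each memory's pre-state in the first step scope from its boot
// variable: shape only in infer-shape mode, shared data otherwise.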
void RecurrentAlgorithm::InitMemories(Scope* step_scope,
                                      bool infer_shape_mode) const {
  for (auto& attr : arg_->memories) {
    Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
    PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
                   "memory [%s]'s boot variable [%s] does not exist", attr.var,
                   attr.boot_var);
    Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable<Tensor>();
    if (infer_shape_mode) {
      pre_mem->Resize(boot_mem->dims());
    } else {
      pre_mem->ShareDataWith<float>(*boot_mem);
    }
  }
}

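// Argument names looked up by rnn::InitArgument for the forward and gradient
// recurrent operators.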
const rnn::ArgumentName RecurrentOp::kArgName{
    "step_net", "step_scopes",  "inlinks",
    "outlinks", "inlink_alias", "outlink_alias",
    "memories", "pre_memories", "boot_memories"};

const rnn::ArgumentName RecurrentGradientOp::kArgName{
    "step_net",    "step_scopes",  "outlink@grad",
    "inlink@grad", "inlink_alias", "outlink_alias",
    "memories",    "pre_memories", "boot_memories@grad"};

void RecurrentOp::Init() {
  OperatorBase::Init();
  std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
  rnn::InitArgument(kArgName, arg.get(), *this);
  alg_.Init(std::move(arg));
}

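// Declares the RecurrentOp's inputs, outputs, and attributes for the op
// registry.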
class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto,
                                         OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    const auto& name = RecurrentOp::kArgName;
    // inputs and outputs stored in proto
    AddInput(name.inlinks,
             "the inputs that need to be segmented for each step.")
        .SetMultiple();
    AddInput(name.boot_memories, "variables to initialize memories.")
        .SetMultiple();
    AddInput(name.step_net, "network shared by all steps.");

    AddOutput(name.outlinks,
              "the outputs that need to be concatenated for all steps.")
        .SetMultiple();
    AddOutput(name.step_scopes, "step scopes");

    // Attributes stored in AttributeMap
    AddAttr<std::vector<std::string>>(name.inlink_alias, "alias of inlinks");
    AddAttr<std::vector<std::string>>(name.outlink_alias, "alias of outlinks");
    AddAttr<std::vector<std::string>>(name.pre_memories,
                                      "names of pre-memories");
    AddAttr<std::vector<std::string>>(name.memories, "names of memories");

    AddComment("This is a recurrent group operator.");
  }
};

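// Backward pass: run the step net from the last time step to the first,
// linking memory gradients to the following step, then propagate the first
// step's memory gradients back to the boot memories and concatenate the
// per-step outputs.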
void RecurrentGradientAlgorithm::Run(
    const Scope& scope, const platform::DeviceContext& dev_ctx) const {
  auto step_scopes = GetStepScopes(scope);
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                     false /*infer_shape_mode*/);
  Variable* net = scope.FindVar(arg_->step_net);
  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
  for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
    if (static_cast<size_t>(step_id) != seq_len_ - 1) {
      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                        false /*infer_shape_mode*/);
    }
    net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
  }
  LinkBootMemoryGradients(step_scopes[0], false /*infer_shape_mode*/);
  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                     false /*infer_shape_mode*/);
}

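// Share each memory's gradient in the first step scope with its boot-memory
// gradient variable (shape only in infer-shape mode).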
void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
    Scope* step_scope, bool infer_shape_mode) const {
  for (auto& attr : arg_->memories) {
    PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
                   "memory variable [%s] does not exist", attr.var);
    PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
                   "boot variable [%s] does not exist", attr.boot_var);
    Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
    Tensor* boot_mem_grad =
        step_scope->NewVar(attr.boot_var)->GetMutable<Tensor>();
    if (infer_shape_mode) {
      boot_mem_grad->Resize(mem_grad->dims());
    } else {
      boot_mem_grad->ShareDataWith<float>(*mem_grad);
    }
  }
}

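// Shape inference for the backward pass, mirroring Run but calling the step
// net's InferShape in reverse time order.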
void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
  seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
                 ->GetMutable<Tensor>()
                 ->dims()[0];
  auto step_scopes = GetStepScopes(scope);
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                     true /*infer_shape_mode*/);
  Variable* net = scope.FindVar(arg_->step_net);
  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
  for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
    if (static_cast<size_t>(step_id) != seq_len_ - 1) {
      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                        true /*infer_shape_mode*/);
    }
    net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
  }
  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                     true /*infer_shape_mode*/);
  LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
}

void RecurrentGradientOp::Init() {
  OperatorBase::Init();
  std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
  rnn::InitArgument(kArgName, arg.get(), *this);
  alg_.Init(std::move(arg));
}

}  // namespace operators
}  // namespace paddle

REGISTER_OP(recurrent_op, paddle::operators::RecurrentOp,
            paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker);