/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/recurrent_op.h"

#include <algorithm>
#include "paddle/fluid/string/string_helper.h"
Y
Yan Chunwei 已提交
19 20 21 22

namespace paddle {
namespace operators {

Y
Yu Yang 已提交
23 24
// List of per-time-step scopes. Each entry is a child scope created with
// parent.NewScope() and later released through the parent's DeleteScope().
using StepScopeVar = std::vector<framework::Scope *>;

// Slot and attribute names shared by the `recurrent` and `recurrent_grad`
// operators.
const char RecurrentBase::kInputs[] = "inputs";
const char RecurrentBase::kInitialStates[] = "initial_states";
const char RecurrentBase::kParameters[] = "parameters";
const char RecurrentBase::kOutputs[] = "outputs";
const char RecurrentBase::kStepScopes[] = "step_scopes";
const char RecurrentBase::kHasStates[] = "has_states";
const char RecurrentBase::kExStates[] = "ex_states";
const char RecurrentBase::kStates[] = "states";
const char RecurrentBase::kStepBlock[] = "sub_block";
const char RecurrentBase::kReverse[] = "reverse";
const char RecurrentBase::kIsTrain[] = "is_train";
const char RecurrentBase::kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";

// Gradient slot names: the forward slot name followed by the "@GRAD" suffix.
// The helper macro is undefined right after use so that it does not leak into
// the rest of this translation unit.
#define GRAD_SUFFIX "@GRAD"
const char RecurrentBase::kInputGrads[] = "inputs" GRAD_SUFFIX;
const char RecurrentBase::kOutputGrads[] = "outputs" GRAD_SUFFIX;
const char RecurrentBase::kParamGrads[] = "parameters" GRAD_SUFFIX;
const char RecurrentBase::kInitStateGrads[] = "initial_states" GRAD_SUFFIX;
#undef GRAD_SUFFIX

43 44 45 46 47 48 49 50
// Deletes every scope recorded in `step_scopes` from `parent_scope`, then
// empties the list. The device context is waited on before anything is
// deleted (presumably so in-flight device work no longer touches the scopes).
// Recorded scopes that are no longer kids of `parent_scope` are skipped.
// Does nothing when the list is already empty.
static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
                            framework::Scope *parent_scope,
                            StepScopeVar *step_scopes) {
  if (!step_scopes->empty()) {
    dev_ctx.Wait();
    for (framework::Scope *kid : *step_scopes) {
      if (!parent_scope->HasKid(kid)) continue;
      parent_scope->DeleteScope(kid);
    }
    step_scopes->clear();
  }
}

59 60 61 62 63 64 65 66
// Prepares the step scopes for one run of a recurrent operator.
//
// In training mode one scope per time step is used (seq_len scopes).
// Otherwise only two scopes exist and are reused alternately (see GetScope).
// The forward pass drops any previously recorded scopes and creates fresh
// children of `parent`. The backward pass (which must be in training mode)
// reuses the scopes already stored in `scopes`, starting at the last step.
StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
                       const framework::Scope &parent, StepScopeVar *scopes,
                       bool is_train, size_t seq_len, bool is_backward)
    : counter_(is_backward ? seq_len - 1 : 0UL),
      scopes_(scopes),
      is_train_(is_train),
      is_backward_(is_backward) {
  PADDLE_ENFORCE_EQ(is_train || !is_backward, true,
                    "Cannot backward when is not training");
  if (is_backward_) {
    return;
  }

  const size_t scope_count = is_train ? seq_len : 2;
  ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
  scopes->reserve(scope_count);
  for (size_t n = 0; n < scope_count; ++n) {
    scopes->emplace_back(&parent.NewScope());
  }
}

// Scope of the time step currently being processed.
framework::Scope &StepScopes::CurScope() {
  const size_t current = counter_;
  return GetScope(current);
}
Y
Yu Yang 已提交
79

80 81 82 83
// Scope of the previously processed time step. This is counter_ - 1 when
// walking forward and counter_ + 1 when walking backward.
framework::Scope &StepScopes::ExScope() {
  const size_t previous = is_backward_ ? counter_ + 1 : counter_ - 1;
  return GetScope(previous);
}
Y
Yu Yang 已提交
84

85 86 87 88 89 90 91 92
// Moves the backward walk one time step earlier.
//
// When the scope at index counter_ + 1 (the step processed just before the
// current one) is the last stored scope, it is deleted from `parent_scope`
// and removed from the list. `dev_ctx` is not used here.
void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
                              framework::Scope *parent_scope) {
  PADDLE_ENFORCE_EQ(is_backward_, true,
                    "Cannot get backward next scope when is forward");
  const size_t stale = counter_ + 1;
  if (stale + 1 == scopes_->size()) {
    parent_scope->DeleteScope(scopes_->back());
    scopes_->pop_back();
    VLOG(3) << "Deleted scope at " << stale;
  }
  --counter_;
}

// Moves the forward walk to the next time step.
void StepScopes::ForwardNext() {
  PADDLE_ENFORCE_EQ(is_backward_, false,
                    "Cannot get forward next scope when is backward");
  counter_ += 1;
}
Y
Yu Yang 已提交
102

103 104 105
// Returns the scope used for time step `scope_id`. Outside of training only
// two scopes exist, so the id is folded modulo 2 (even and odd steps share).
framework::Scope &StepScopes::GetScope(size_t scope_id) const {
  scope_id = is_train_ ? scope_id : scope_id % 2;
  PADDLE_ENFORCE_LT(scope_id, scopes_->size());
  framework::Scope *selected = (*scopes_)[scope_id];
  return *selected;
}
Y
Yu Yang 已提交
110

111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
// Construction is forwarded unchanged to OperatorBase. The shared recurrent
// logic lives in RecurrentBase's helper methods.
RecurrentBase::RecurrentBase(const std::string &type,
                             const framework::VariableNameMap &inputs,
                             const framework::VariableNameMap &outputs,
                             const framework::AttributeMap &attrs)
    : OperatorBase(type, inputs, outputs, attrs) {
  // Nothing else to initialize.
}

// Returns the sequence length shared by all inputs listed under `kInputs`.
//
// The sequence length is the size of dimension 0 of each input, so inputs
// are expected to be shaped [SEQ_LEN, ...]. The remaining dimensions (for
// example the batch size or a nested sequence length) are not inspected here.
// Every input must be a LoDTensor present in `scope`, and all inputs must
// agree on SEQ_LEN. Otherwise an enforce error naming the offending input is
// raised.
int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const {
  // Dim format SEQ_LEN, BATCH_SIZE, ...
  int64_t seq_len = -1;
  auto &all_inputs = Inputs(kInputs);
  PADDLE_ENFORCE_EQ(all_inputs.empty(), false,
                    "RecurrentOp gets empty input.");
  for (auto &iname : all_inputs) {
    auto *var = scope.FindVar(iname);
    PADDLE_ENFORCE_NOT_NULL(var, "RecurrentOp cannot find var %s in scope.",
                            iname);
    PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
                      "The input %s of RecurrentOp should be a LoDTensor.",
                      iname);
    auto &dim = var->Get<framework::LoDTensor>().dims();
    if (seq_len == -1) {
      seq_len = dim[0];
    } else {
      PADDLE_ENFORCE_EQ(seq_len, dim[0],
                        "Sequence length of input %s in RecurrentOp is not "
                        "equal to the sequence length of the previous inputs.",
                        iname);
    }
  }
  return seq_len;
}
Y
Yu Yang 已提交
140

141 142 143 144 145 146 147 148 149 150 151 152 153
// Makes each destination tensor share memory with its source tensor,
// pairing the names positionally:
//   for src, dst in zip(src_vars, dst_vars):
//     dst_scope[dst].ShareDataWith(src_scope[src])
// Looking up / creating the variables is delegated to LinkTensorWithCallback.
void RecurrentBase::LinkTensor(const framework::Scope &src_scope,
                               const std::vector<std::string> &src_vars,
                               framework::Scope *dst_scope,
                               const std::vector<std::string> &dst_vars) {
  auto share = [](const framework::Tensor &from, framework::Tensor *to) {
    to->ShareDataWith(from);
  };
  LinkTensorWithCallback(src_scope, src_vars, dst_scope, dst_vars, share);
}
Y
Yu Yang 已提交
154

155 156 157 158 159 160 161
// Returns `src` with `seq_len` inserted as a new leading dimension, i.e.
// [seq_len] + list(src).
framework::DDim RecurrentBase::PrependDims(size_t seq_len,
                                           const framework::DDim &src) {
  auto src_dims = framework::vectorize(src);
  decltype(src_dims) result;
  result.reserve(src_dims.size() + 1);
  result.push_back(static_cast<int64_t>(seq_len));
  result.insert(result.end(), src_dims.begin(), src_dims.end());
  return framework::make_ddim(result);
}
Y
Yu Yang 已提交
162

163 164 165 166 167
// Construction is delegated entirely to RecurrentBase.
RecurrentOp::RecurrentOp(const std::string &type,
                         const framework::VariableNameMap &inputs,
                         const framework::VariableNameMap &outputs,
                         const framework::AttributeMap &attrs)
    : RecurrentBase(type, inputs, outputs, attrs) {
  // Nothing else to initialize.
}
Y
Yu Yang 已提交
168

169 170 171 172
// Forward pass of the static-length RNN.
//
// The time steps are walked in order, or backwards when `reverse` is set.
// For each step, the step block runs inside its own step scope:
//   * inputs are linked to the seq_offset-th slice of the outer inputs, with
//     the leading (time) dimension dropped;
//   * if the RNN has states, the ex-states are linked to the initial states
//     (first step) or to the states of the previous step's scope;
//   * outputs end up in the seq_offset-th slice of the outer outputs. The
//     outer outputs are allocated after the first step, once the per-step
//     shape is known, and that step's result is copied in. Later steps share
//     memory with their slice before running.
void RecurrentOp::RunImpl(const framework::Scope &scope,
                          const platform::Place &place) const {
  bool has_state = Attr<bool>(kHasStates);
  auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));

  // get device context from pool
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(place);

  VLOG(3) << "Static RNN input sequence length = " << seq_len;
  StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
  auto reverse = Attr<bool>(kReverse);

  framework::Executor executor(place);
  auto *block = Attr<framework::BlockDesc *>(kStepBlock);

  auto *program = block->Program();
  auto ctx = executor.Prepare(
      *program, block->ID(), Attr<std::vector<std::string>>(
                                 kSkipEagerDeletionVars) /*skip_ref_cnt_vars*/);

  // These name lists do not change between time steps, so they are looked up
  // once instead of on every iteration.
  const auto &input_names = Inputs(kInputs);
  const auto &output_names = Outputs(kOutputs);
  std::vector<std::string> state_names;
  std::vector<std::string> ex_state_names;
  if (has_state) {
    ex_state_names = Attr<std::vector<std::string>>(kExStates);
    state_names = Attr<std::vector<std::string>>(kStates);
  }

  for (size_t i = 0; i < seq_len; ++i) {
    size_t seq_offset = reverse ? seq_len - i - 1 : i;
    VLOG(3) << "Recurrent operate at the time step " << seq_offset;

    auto &cur_scope = scopes.CurScope();

    // Link outside::input --> inside::input
    //   inside::input = outside::input[seq_offset: seq_offset+1]
    // with the leading dimension (of size 1) removed.
    LinkTensorWithCallback(
        scope, input_names, &cur_scope, input_names,
        [&seq_offset](const framework::Tensor &outside,
                      framework::Tensor *inside) {
          inside->ShareDataWith(outside.Slice(seq_offset, seq_offset + 1));
          auto dims = framework::vectorize(inside->dims());
          dims.erase(dims.begin());
          inside->Resize(framework::make_ddim(dims));
        });

    if (has_state) {
      if (i == 0) {
        // Link initial states  --> ex_states
        LinkTensor(scope, Inputs(kInitialStates), &cur_scope, ex_state_names);
      } else {
        auto &ex_scope = scopes.ExScope();
        // Link ex_scope::state --> cur_scope::ex_state
        LinkTensor(ex_scope, state_names, &cur_scope, ex_state_names);
      }
    }

    // Create the step block's variables in cur_scope. From the second step
    // on, the outer outputs are already allocated, so each inner output
    // shares memory with its slice of the outer output:
    //   outside::output[seq_offset: seq_offset + 1] = inside::output
    executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_);
    if (i > 0) {
      LinkTensorWithCallback(scope, output_names, cur_scope, output_names,
                             [&](const framework::LoDTensor &src_tensor,
                                 framework::LoDTensor *dst_tensor) {
                               framework::Tensor src_slice =
                                   src_tensor.Slice(seq_offset, seq_offset + 1);
                               dst_tensor->ShareDataWith(src_slice);
                             });
    }

    // Linked now, execute!
    executor.RunPreparedContext(ctx.get(), &cur_scope,
                                false /*create_local_scope*/,
                                false /*create_vars*/, true /* keep_kids */);
    if (i == 0) {
      // The per-step output shape is only known after the first step has run.
      // Allocate the full [seq_len, ...] outer outputs now and copy this
      // step's result into its slice.
      LinkTensorWithCallback(
          cur_scope, output_names, scope, output_names,
          [&](const framework::LoDTensor &src_tensor,
              framework::LoDTensor *dst_tensor) {
            // create output tensor at begin
            dst_tensor->Resize(PrependDims(seq_len, src_tensor.dims()));
            dst_tensor->mutable_data(place, src_tensor.type());

            auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
            // Explicit copy output since the local RNN scope can be destroyed
            // early.
            framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
          });
    }

    scopes.ForwardNext();
  }
}
Y
Yu Yang 已提交
258

259 260 261 262
// Builds fresh forward-pass step scopes. They are stored in the step-scopes
// variable named by the `kStepScopes` output, so the backward op can reuse
// them.
StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx,
                                         const framework::Scope &scope,
                                         size_t seq_len) const {
  auto *var = scope.FindVar(Output(kStepScopes));
  PADDLE_ENFORCE_NOT_NULL(var);
  auto *storage = var->GetMutable<StepScopeVar>();
  const bool is_train = Attr<bool>(kIsTrain);
  return StepScopes(dev_ctx, scope, storage, is_train, seq_len);
}
Y
Yu Yang 已提交
267

268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334
// Construction is delegated entirely to RecurrentBase.
RecurrentGradOp::RecurrentGradOp(const std::string &type,
                                 const framework::VariableNameMap &inputs,
                                 const framework::VariableNameMap &outputs,
                                 const framework::AttributeMap &attrs)
    : RecurrentBase(type, inputs, outputs, attrs) {
  // Nothing else to initialize.
}

// Backward pass of the static-length RNN.
//
// Walks the time steps in the opposite order of the forward pass, reusing the
// step scopes the forward pass kept (see StepScopes). For every step it:
//   * links the step's slice of the outer output gradients into the scope;
//   * if the RNN has states, makes the current state gradients share memory
//     with the ex-state gradients of the previously processed step scope;
//   * runs the gradient step block;
//   * accumulates the parameter gradients into the outer parameter gradients
//     (which are zero-filled at the first processed step);
//   * stores the input gradients in their slice of the outer input gradients;
//   * at the last processed step, copies the ex-state gradients out as the
//     initial-state gradients.
// Finally, all remaining step scopes are deleted.
void RecurrentGradOp::RunImpl(const framework::Scope &scope,
                              const platform::Place &place) const {
  bool has_state = Attr<bool>(kHasStates);
  const size_t seq_len = static_cast<size_t>(GetSequenceLength(scope));

  // get device context from pool
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(place);

  StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
  auto reverse = Attr<bool>(kReverse);

  framework::Executor executor(place);
  auto *block = Attr<framework::BlockDesc *>(kStepBlock);
  auto *program = block->Program();
  auto ctx = executor.Prepare(
      *program, block->ID(), Attr<std::vector<std::string>>(
                                 kSkipEagerDeletionVars) /*skip_ref_cnt_vars*/);

  // These gradient-name lists are the same for every time step. Build them
  // once here instead of allocating fresh vectors on each iteration.
  const std::vector<std::string> inside_input_grads =
      GradVarLists(Inputs(kInputs));
  std::vector<std::string> ex_state_grads;
  std::vector<std::string> cur_state_grads;
  if (has_state) {
    ex_state_grads = GradVarLists(Attr<std::vector<std::string>>(kExStates));
    cur_state_grads = GradVarLists(Attr<std::vector<std::string>>(kStates));
  }

  for (size_t step_id = 0; step_id < seq_len; ++step_id) {
    size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
    VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
    auto &cur_scope = scopes.CurScope();

    // Link outside::output_grads --> inside::output_grads
    //   inside::output_grad = outside::output_grad[seq_offset:seq_offset+1]
    // with the leading dimension (of size 1) removed.
    LinkTensorWithCallback(
        scope, Inputs(kOutputGrads), &cur_scope, Inputs(kOutputGrads),
        [&](const framework::Tensor &outside, framework::Tensor *inside) {
          inside->ShareDataWith(outside.Slice(seq_offset, seq_offset + 1));
          auto dims = framework::vectorize(inside->dims());
          dims.erase(dims.begin());
          inside->Resize(framework::make_ddim(dims));
        },
        true /*is_backward*/);

    // The name set is only needed for this debug log, so only build it when
    // the log level is enabled.
    if (VLOG_IS_ON(10)) {
      auto og_set = List2Set(Inputs(kOutputGrads));
      std::ostringstream sout;
      std::copy(og_set.begin(), og_set.end(),
                std::ostream_iterator<std::string>(sout, ","));
      VLOG(10) << " RNN output gradients = [" << sout.str() << "]";
    }

    if (has_state) {
      // Link states
      //   ex_scope::ex_state_grad --> cur_scope::cur_state_grad
      // The current state gradient shares memory with the ex-state gradient
      // of the previously processed step. No accumulation is done here.
      if (step_id != 0) {  // not at beginning
        auto &ex_scope = scopes.ExScope();
        PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size());
        for (size_t i = 0; i < ex_state_grads.size(); ++i) {
          auto &cur_grad = cur_state_grads[i];
          auto &ex_grad = ex_state_grads[i];
          auto *ex_grad_var = ex_scope.FindVar(ex_grad);
          PADDLE_ENFORCE_NOT_NULL(
              ex_grad_var,
              "Cannot find the RNN ex-state gradient %s in the previous step "
              "scope.",
              ex_grad);
          auto &ex_grad_tensor = ex_grad_var->Get<framework::LoDTensor>();

          VLOG(10) << " RNN link " << cur_grad << " from " << ex_grad;
          auto *cur_grad_var = cur_scope.Var(cur_grad);
          framework::LoDTensor *cur_grad_tensor =
              cur_grad_var->GetMutable<framework::LoDTensor>();
          cur_grad_tensor->ShareDataWith(ex_grad_tensor);
        }
      }
    }

    // Create the gradient block's variables in cur_scope. From the second
    // processed step on, the outer input gradients are already allocated, so
    // each inner input gradient shares memory with its slice:
    //   outside::input_grad[seq_offset: seq_offset + 1] = inside::input_grad
    executor.CreateVariables(ctx->prog_, &cur_scope, ctx->block_id_);
    if (step_id > 0) {
      LinkTensorWithCallback(scope, Outputs(kInputGrads), cur_scope,
                             inside_input_grads,
                             [&](const framework::LoDTensor &src_tensor,
                                 framework::LoDTensor *dst_tensor) {
                               if (src_tensor.memory_size() ==
                                   0) {  // Inside Gradient is not created.
                                 return;
                               }
                               framework::Tensor src_slice =
                                   src_tensor.Slice(seq_offset, seq_offset + 1);
                               dst_tensor->ShareDataWith(src_slice);
                             },
                             true /*is_backward*/);
    }

    VLOG(5) << "Recurrent memory linking finished ";
    // Run step block with cur_scope
    executor.RunPreparedContext(ctx.get(), &cur_scope,
                                false /*create_local_scope*/,
                                false /*create_vars*/, true /* keep_kids */);

    VLOG(5) << "executor.Run finished ";

    auto local_var_names = LocalVarNames(cur_scope);

    // Accumulate params
    //   if (step == 0):
    //      outside::param_grad = 0.0
    //   outside::param_grad += inside::param_grad
    {
      auto &pg_names = Outputs(kParamGrads);
      auto &p_names = Inputs(kParameters);
      PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());

      for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
        auto inside_grad_name = framework::GradVarName(p_names[param_id]);

        // If does not compute gradient of that variable inside rnn, just
        // continue
        if (local_var_names.find(inside_grad_name) == local_var_names.end()) {
          continue;
        }

        // zero gradient variable in step 0
        if (step_id == 0) {
          auto &inside_tensor =
              cur_scope.FindVar(inside_grad_name)->Get<framework::LoDTensor>();
          framework::AttributeMap attrs;
          attrs["dtype"] = inside_tensor.type();
          attrs["shape"] = framework::vectorize<int>(inside_tensor.dims());
          attrs["value"] = 0.0f;

          auto zero_op = framework::OpRegistry::CreateOp(
              "fill_constant", framework::VariableNameMap{},
              {{"Out", {pg_names[param_id]}}}, attrs);
          zero_op->Run(scope, place);
        }

        // Temporarily rename the inside gradient so that the sum op can read
        // both it and the outer gradient of the same parameter.
        auto new_inside_name = cur_scope.Rename(inside_grad_name);

        // sum gradient
        auto sum_op = framework::OpRegistry::CreateOp(
            "sum", {{"X", {pg_names[param_id], new_inside_name}}},
            {{"Out", {pg_names[param_id]}}},
            framework::AttributeMap{{"use_mkldnn", {false}}});
        sum_op->Run(cur_scope, place);

        cur_scope.Rename(new_inside_name, inside_grad_name);
      }
    }
    VLOG(5) << "Accumulate Parameter finished ";

    // Copy input gradient from inside to outside
    //   outside::input_grad[seq_offset: seq_offset + 1] = inside::input_grad
    // At the first processed step the outer gradients are allocated here.
    if (step_id == 0) {
      LinkTensorWithCallback(
          cur_scope, inside_input_grads, scope, Outputs(kInputGrads),
          [&](const framework::LoDTensor &inside,
              framework::LoDTensor *outside) {
            if (inside.memory_size() == 0) {  // IG is not created.
              return;
            }
            // Alloc outside memory
            outside->Resize(PrependDims(seq_len, inside.dims()));
            outside->mutable_data(place, inside.type());

            auto dst = outside->Slice(seq_offset, seq_offset + 1);
            framework::TensorCopy(inside, place, dev_ctx, &dst);
          },
          true /*is_backward*/);
    }
    VLOG(5) << "Link outside gradient finished ";

    if (has_state) {
      if (step_id + 1 == seq_len) {  // at_end
        // copy initialize states gradient from inside to outside
        LinkTensorWithCallback(
            cur_scope, ex_state_grads, scope, Outputs(kInitStateGrads),
            [&](const framework::LoDTensor &inside,
                framework::LoDTensor *outside) {
              outside->Resize(inside.dims());
              outside->mutable_data(place, inside.type());
              framework::TensorCopy(inside, place, dev_ctx, outside);
            },
            true /*is_backward*/);
        VLOG(5) << "Link initialize state gradient finished ";
      }
    }
    scopes.BackwardNext(dev_ctx, const_cast<framework::Scope *>(&scope));
  }
  // Delete the scope of StepScopes
  auto *var = scope.FindVar(Input(kStepScopes));
  PADDLE_ENFORCE_NOT_NULL(var);
  auto *step_scopes = var->GetMutable<StepScopeVar>();
  ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope), step_scopes);
}
Y
Yu Yang 已提交
468

469 470 471 472
// Wraps the step scopes recorded by the forward op (the `kStepScopes` input)
// for a backward walk that starts at the last time step.
StepScopes RecurrentGradOp::CreateStepScopes(
    const platform::DeviceContext &dev_ctx, const framework::Scope &scope,
    size_t seq_len) const {
  auto *var = scope.FindVar(Input(kStepScopes));
  PADDLE_ENFORCE_NOT_NULL(var);
  auto *recorded = var->GetMutable<StepScopeVar>();
  const bool is_train = Attr<bool>(kIsTrain);
  constexpr bool kWalkBackward = true;
  return StepScopes(dev_ctx, scope, recorded, is_train, seq_len,
                    kWalkBackward);
}
Y
Yu Yang 已提交
477

478 479 480 481 482 483
// Collects the names in `list` into a hash set; duplicates collapse.
std::unordered_set<std::string> RecurrentGradOp::List2Set(
    const std::vector<std::string> &list) const {
  std::unordered_set<std::string> names;
  names.reserve(list.size());
  names.insert(list.begin(), list.end());
  return names;
}
Y
Yu Yang 已提交
487

488 489 490 491
// Returns the names reported by `scope.LocalVarNames()` as a set.
std::unordered_set<std::string> RecurrentGradOp::LocalVarNames(
    const framework::Scope &scope) const {
  const auto names = scope.LocalVarNames();
  return List2Set(names);
}
492

493 494 495 496 497 498 499 500
// Maps every variable name to its gradient name (framework::GradVarName),
// preserving order.
std::vector<std::string> RecurrentGradOp::GradVarLists(
    const std::vector<std::string> &var_names) {
  std::vector<std::string> grad_names;
  grad_names.reserve(var_names.size());
  for (const auto &name : var_names) {
    grad_names.push_back(framework::GradVarName(name));
  }
  return grad_names;
}
Y
Yu Yang 已提交
501 502

class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
503
 public:
Y
Yu Yang 已提交
504
  void Make() override {
505 506 507 508
    AddInput(RecurrentBase::kInputs, "rnn inputs").AsDuplicable();
    AddInput(RecurrentBase::kInitialStates, "rnn initial states")
        .AsDuplicable();
    AddInput(RecurrentBase::kParameters,
Y
Yu Yang 已提交
509
             "Parameters are used by step block as its input. However, the "
K
kexinzhao 已提交
510 511
             "input is not a sequence tensor. Every time step, each operator "
             "in step block just use the parameter directly.")
Y
Yu Yang 已提交
512
        .AsDuplicable();
513
    AddOutput(RecurrentBase::kOutputs,
K
kexinzhao 已提交
514
              "The output sequence of RNN. The sequence length must be same.")
Y
Yu Yang 已提交
515
        .AsDuplicable();
516
    AddOutput(RecurrentBase::kStepScopes,
K
kexinzhao 已提交
517
              "StepScopes contain all local variables in each time step.");
518 519 520 521 522 523
    AddAttr<bool>(RecurrentBase::kHasStates, "Whether has states.")
        .SetDefault(false);
    AddAttr<std::vector<std::string>>(
        RecurrentBase::kExStates,
        string::Sprintf(
            R"DOC(The ex-state variable names.
Y
Yu Yang 已提交
524 525
The ex-state means the state value in the ex-timestep or the previous time step
[%s, %s, %s] must be the same order)DOC",
526 527
            RecurrentBase::kExStates, RecurrentBase::kStates,
            RecurrentBase::kInitStateGrads));
Y
Yu Yang 已提交
528
    AddAttr<std::vector<std::string>>(
529
        RecurrentBase::kStates,
Y
Yu Yang 已提交
530 531
        string::Sprintf(
            "The state variable names. [%s, %s, %s] must be the same order",
532 533 534 535 536
            RecurrentBase::kExStates, RecurrentBase::kStates,
            RecurrentBase::kInitStateGrads));
    AddAttr<framework::BlockDesc *>(RecurrentBase::kStepBlock,
                                    "The step block inside RNN");
    AddAttr<bool>(RecurrentBase::kReverse, R"DOC(Calculate RNN reversely or not.
Y
Yu Yang 已提交
537
By default reverse=False
Y
Yan Chunwei 已提交
538

Y
Yu Yang 已提交
539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560
Assume the input data is [A, B, C, D]

if reverse is False:
  the computation of RNN is like
      A          B          C         D
      |          |          |         |
      v          v          v         v
     rnn -----> rnn -----> rnn ----> rnn
      |          |          |         |
      v          v          v         v
      o          o          o         o

if reverse is True
  the computation of RNN is like
      A          B          C         D
      |          |          |         |
      v          v          v         v
     rnn <----- rnn <----- rnn <---- rnn
      |          |          |         |
      v          v          v         v
      o          o          o         o
)DOC").SetDefault(false);
561 562 563 564 565 566
    AddAttr<bool>(RecurrentBase::kIsTrain, "").SetDefault(true);
    AddAttr<std::vector<std::string>>(RecurrentBase::kSkipEagerDeletionVars,
                                      "Vars that would skip eager deletion."
                                      "Users should not set this manually.")
        .SetDefault(std::vector<std::string>());

K
kexinzhao 已提交
567 568 569 570 571
    AddComment(R"DOC(
Static Length Recurrent Operator.

The static length recurrent operator can only operate on fixed size sequence
data, i.e. in each mini-batch, the sequence length of all inputs are the same.
Y
Yu Yang 已提交
572 573 574 575 576

)DOC");
  }
};

H
hong 已提交
577 578
// Builds the `recurrent_grad` op description for a `recurrent` op:
//   * every forward input becomes an input of the grad op, and its gradient
//     becomes an output;
//   * every forward output becomes an input together with its gradient. For
//     `step_scopes`, the forward variable itself is passed under the
//     gradient name;
//   * all attributes are copied, and the step block attribute is replaced by
//     the gradient block.
template <typename T>
class RecurrentGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  std::unique_ptr<T> Apply() const override {
    // Own the description immediately so it is released if a call below
    // throws.
    std::unique_ptr<T> grad(new T());
    grad->SetType("recurrent_grad");
    for (auto &input_param : this->InputNames()) {
      grad->SetInput(input_param, this->Input(input_param));
      grad->SetOutput(framework::GradVarName(input_param),
                      this->InputGrad(input_param, false));
    }

    for (auto &output_param : this->OutputNames()) {
      grad->SetInput(output_param, this->Output(output_param));
      if (output_param == RecurrentBase::kStepScopes) {
        grad->SetInput(framework::GradVarName(output_param),
                       this->Output(output_param));
      } else {
        grad->SetInput(framework::GradVarName(output_param),
                       this->OutputGrad(output_param));
      }
    }
    grad->SetAttrMap(this->Attrs());
    grad->SetBlockAttr(RecurrentBase::kStepBlock, this->grad_block_[0]);

    return grad;
  }
};

Y
Yu Yang 已提交
610 611 612
// Shape inference for `recurrent_grad`. Each gradient output takes the shape
// of the matching forward input (inputs, initial states, parameters).
class RecurrentGradOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    // In some cases kInitialStates is empty. If it is, all the state
    // attributes must be empty as well.
    if (!ctx->HasInputs(RecurrentBase::kInitialStates)) {
      PADDLE_ENFORCE_EQ(
          ctx->Attrs()
              .Get<std::vector<std::string>>(RecurrentBase::kExStates)
              .empty(),
          true, "The Attr(%s) should be empty.", RecurrentBase::kExStates);
      PADDLE_ENFORCE_EQ(
          ctx->Attrs()
              .Get<std::vector<std::string>>(RecurrentBase::kStates)
              .empty(),
          true, "The Attr(%s) should be empty.", RecurrentBase::kStates);
    }

    PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kInputs), true,
                      "The input(%s) should not be empty.",
                      RecurrentBase::kInputs);
    PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kOutputs), true,
                      "The input(%s) should not be empty.",
                      RecurrentBase::kOutputs);

    // In some cases kInitialStates is empty; only infer its gradient shape
    // when it is present.
    if (ctx->HasInputs(RecurrentBase::kInitialStates)) {
      PADDLE_ENFORCE_EQ(ctx->HasOutputs(framework::GradVarName(
                            RecurrentBase::kInitialStates)),
                        true, "The output(%s) should not be empty.",
                        framework::GradVarName(RecurrentBase::kInitialStates));
      ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInitialStates),
                         ctx->GetInputsDim(RecurrentBase::kInitialStates));
    }

    PADDLE_ENFORCE_EQ(
        ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), true,
        "The output(%s) should not be empty.",
        framework::GradVarName(RecurrentBase::kInputs));
    ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs),
                       ctx->GetInputsDim(RecurrentBase::kInputs));

    // In some cases kParameters is empty.
    if (ctx->HasInputs(RecurrentBase::kParameters)) {
      PADDLE_ENFORCE_EQ(
          ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)),
          true, "The output(%s) should not be empty.",
          framework::GradVarName(RecurrentBase::kParameters));
      ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters),
                         ctx->GetInputsDim(RecurrentBase::kParameters));
    }
  }
};
Y
Yan Chunwei 已提交
665 666 667 668

}  // namespace operators
}  // namespace paddle

H
hong 已提交
669 670 671 672
REGISTER_OPERATOR(
    recurrent, paddle::operators::RecurrentOp,
    paddle::operators::RecurrentOpProtoMaker,
    paddle::operators::RecurrentGradOpMaker<paddle::framework::OpDesc>);
Y
Yu Yang 已提交
673 674
REGISTER_OPERATOR(recurrent_grad, paddle::operators::RecurrentGradOp,
                  paddle::operators::RecurrentGradOpShapeInference);