while_op.cc 23.1 KB
Newer Older
C
chengduo 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Y
Yang Yang(Tony) 已提交
14

Y
Yi Wang 已提交
15 16 17
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
S
sneaxiy 已提交
18
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
Y
Yang Yang(Tony) 已提交
19

20 21 22
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
W
wanghuancoder 已提交
23 24 25 26 27 28 29 30
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class VarDesc;
}  // namespace framework
}  // namespace paddle

Y
Yang Yang(Tony) 已提交
31 32 33 34 35 36
namespace paddle {
namespace operators {

using StepScopeVar = std::vector<framework::Scope *>;
using LoDTensor = framework::LoDTensor;

S
sneaxiy 已提交
37 38 39 40 41 42 43 44 45 46 47
namespace {  // NOLINT
static std::string GetSkipEagerDeletionVarsDebugString(
    const std::vector<std::string> &vars) {
  std::string str = "Skip " + std::to_string(vars.size()) +
                    " var(s) in eager deletion mode: ";
  for (auto &var : vars) {
    str.append(var);
    str.push_back(' ');
  }
  return str;
}
48
}  // namespace
Y
Yang Yang(Tony) 已提交
49 50 51

class WhileOp : public framework::OperatorBase {
 public:
52 53
  WhileOp(const std::string &type,
          const framework::VariableNameMap &inputs,
Y
Yang Yang(Tony) 已提交
54 55 56 57
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

58 59 60
 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
61 62 63
    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)),
                            platform::errors::NotFound(
                                "Input(Condition) of WhileOp is not found."));
64

Y
Yang Yang(Tony) 已提交
65
    auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
66
    PADDLE_ENFORCE_EQ(
67 68
        cond.dims(),
        phi::make_ddim({1}),
69 70 71
        platform::errors::InvalidArgument(
            "The shape of Input(Condition) of WhileOp must be 1. But now "
            "the Condition's shape is ",
72 73
            cond.dims().to_str(),
            ".\n"));
Y
Yang Yang(Tony) 已提交
74

75 76 77 78 79 80
#ifdef PADDLE_WITH_MKLDNN
    // (jczaja) Executor on being destroyed clears oneDNN cache and
    // resets registered model data layout. This is unwanted for nested
    // Executors (executors declared inside control ops)
    platform::DontClearMKLDNNCache(dev_place);
#endif
D
dzhwinter 已提交
81
    framework::Executor executor(dev_place);
Y
Yu Yang 已提交
82
    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
D
dzhwinter 已提交
83

Y
Yang Yang(Tony) 已提交
84
    auto *program = block->Program();
85 86 87 88
    bool is_test = Attr<bool>("is_test");

    std::set<std::string> no_copy_var_names;
    if (!is_test) {
89 90 91 92 93 94 95 96 97 98 99 100 101
      // set all persistable parameters into no_copy_var_names.
      auto *global_block = block;

      while (global_block->ID() != 0)
        global_block = global_block->ParentBlock();
      auto all_vars = global_block->AllVars();
      std::for_each(all_vars.begin(),
                    all_vars.end(),
                    [&no_copy_var_names](framework::VarDesc *var) {
                      if (var->IsParameter())
                        no_copy_var_names.insert(var->Name());
                    });

102 103 104 105 106 107 108 109 110 111 112 113 114
      const std::vector<framework::OpDesc *> &all_ops = block->AllOps();
      for (const framework::OpDesc *op : all_ops) {
        const framework::VariableNameMap &input_var_names = op->Inputs();
        const framework::VariableNameMap &output_var_names = op->Outputs();
        for (auto &ipt : input_var_names) {
          for (const std::string &var_name : ipt.second) {
            if (StrInVaraiableNameMap(var_name, output_var_names)) {
              no_copy_var_names.insert(var_name);
            }
          }
        }
      }
    }
Y
Yang Yang(Tony) 已提交
115 116 117

    auto step_scopes =
        scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
118 119 120 121 122 123 124 125 126 127 128

    if (step_scopes->size() > 0) {
      platform::DeviceContextPool::Instance().Get(dev_place)->Wait();
      for (auto &s : *step_scopes) {
        if (scope.HasKid(s)) {
          scope.DeleteScope(s);
        }
      }
      step_scopes->clear();
    }

129 130
    PADDLE_ENFORCE_EQ(step_scopes->size(),
                      0,
131 132
                      platform::errors::PreconditionNotMet(
                          "The Output(StepScope) of WhileOp should be empty."));
X
Xin Pan 已提交
133

134
    bool cond_data = GetCondData(cond);
S
sneaxiy 已提交
135
    auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
S
sneaxiy 已提交
136
    VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
S
fix bug  
sneaxiy 已提交
137

S
sneaxiy 已提交
138
    auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
139
    if (!is_test) {
140
      while (cond_data) {
141 142
        auto &current_scope = scope.NewScope();
        step_scopes->push_back(&current_scope);
143 144 145 146 147 148 149 150 151 152 153 154

        std::vector<std::string> rename_vars;
        for (const std::string &input_var_name : Inputs(kX)) {
          if (no_copy_var_names.find(input_var_name) ==
              no_copy_var_names.end()) {
            std::string input_var_rename = input_var_name + kSuffix;
            framework::Variable *input_var = scope.FindVar(input_var_name);
            if (input_var->IsType<framework::LoDTensor>()) {
              rename_vars.push_back(input_var_rename);
              auto input_var_tensor = input_var->Get<LoDTensor>();
              auto *rename_input_var_tensor =
                  current_scope.Var(input_var_rename)->GetMutable<LoDTensor>();
155 156
              framework::TensorCopy(
                  input_var_tensor, dev_place, rename_input_var_tensor);
157 158 159 160
              rename_input_var_tensor->set_lod(input_var_tensor.lod());
            }
          }
        }
161 162
        executor.RunPreparedContext(
            ctx.get(), &current_scope, false, true, true);
163 164 165 166 167 168

        for (auto &var_rename : rename_vars) {
          std::string input_var_name =
              var_rename.substr(0, var_rename.size() - strlen(kSuffix));
          current_scope.Rename(var_rename, input_var_name);
        }
169 170
        cond_data =
            GetCondData(scope.FindVar(Input(kCondition))->Get<LoDTensor>());
171 172
      }
    } else {
Y
Yang Yang(Tony) 已提交
173
      auto &current_scope = scope.NewScope();
174
      executor.CreateVariables(*program, &current_scope, block->ID());
175
      while (cond_data) {
176 177 178 179 180 181 182 183 184 185 186 187 188
        for (auto &name : current_scope.LocalVarNames()) {
          auto *var = current_scope.Var(name);
          if (var->IsType<framework::LoDTensor>()) {
            // Clear all lod information for all lod_tensors.
            auto *t = var->GetMutable<framework::LoDTensor>();
            framework::LoD empty_lod;
            t->set_lod(empty_lod);
          } else if (var->IsType<framework::LoDTensorArray>()) {
            // Clear elements of all tensor arrays.
            auto *t = var->GetMutable<framework::LoDTensorArray>();
            t->clear();
          }
        }
189 190
        executor.RunPreparedContext(
            ctx.get(), &current_scope, false, false, false);
191 192
        cond_data =
            GetCondData(scope.FindVar(Input(kCondition))->Get<LoDTensor>());
C
chengduo 已提交
193
      }
194
      scope.DeleteScope(&current_scope);
Y
Yang Yang(Tony) 已提交
195 196 197 198 199 200
    }
  }
};

class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
201
  void Make() override {
Y
Yang Yu 已提交
202
    AddInput(kX,
Y
Yang Yang(Tony) 已提交
203 204 205 206 207 208 209
             "A set of variables, which are required by operators inside the "
             "block of While Op.")
        .AsDuplicable();
    AddInput(
        kCondition,
        "(Bool) An scalar. When it's False, the While Op will be terminated.")
        .AsDuplicable();
Y
Yang Yang(Tony) 已提交
210
    AddOutput(kOutputs,
Y
Yang Yang(Tony) 已提交
211
              "A set of variables, which will be assigned with values "
Y
Yang Yang(Tony) 已提交
212
              "generated by the operators inside the block of While Op.")
Y
Yang Yang(Tony) 已提交
213 214 215 216 217
        .AsDuplicable();
    AddOutput(kStepScopes,
              "(StepScopeVar) A vector of local scope, which size equals the "
              "step number of While Op. The i'th scope storages temporary "
              "variables generated in the i'th step.");
Y
Yu Yang 已提交
218 219
    AddAttr<framework::BlockDesc *>(kStepBlock,
                                    "The step block inside WhileOp");
220 221 222 223
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
Y
Yang Yang(Tony) 已提交
224 225 226 227 228 229 230
    AddComment(R"DOC(
)DOC");
  }
};

class WhileGradOp : public framework::OperatorBase {
 public:
231 232
  WhileGradOp(const std::string &type,
              const framework::VariableNameMap &inputs,
Y
Yang Yang(Tony) 已提交
233 234 235 236
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

237 238 239
 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
240
    PADDLE_ENFORCE_EQ(
241 242
        Attr<bool>("is_test"),
        false,
243 244
        platform::errors::InvalidArgument(
            "WhileGradOp is only callable when is_test is false."));
245 246 247
    // get device context from pool
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(dev_place);
D
dzhwinter 已提交
248
    framework::Executor executor(dev_place);
Y
Yu Yang 已提交
249
    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
Y
Yang Yang(Tony) 已提交
250
    auto *program = block->Program();
S
sneaxiy 已提交
251 252

    auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
S
sneaxiy 已提交
253
    VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
S
sneaxiy 已提交
254
    auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
Y
Yang Yang(Tony) 已提交
255 256 257 258

    auto *step_scopes =
        scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();

Y
Yang Yang(Tony) 已提交
259 260 261 262
    auto outside_og_names = Inputs(framework::GradVarName(kOutputs));
    auto inside_og_names =
        Attr<std::vector<std::string>>("original_output_grad");

263 264
    PADDLE_ENFORCE_EQ(outside_og_names.size(),
                      inside_og_names.size(),
265 266 267 268 269 270
                      platform::errors::InvalidArgument(
                          "The number of original output gradient names "
                          "does not match the number of backward input "
                          "gradient names. The number of Backward input "
                          "names is %d and the numbers of original output "
                          "gradient names is %d.",
271 272
                          outside_og_names.size(),
                          inside_og_names.size()));
Y
Yang Yang(Tony) 已提交
273

Y
Yang Yang(Tony) 已提交
274
    for (auto cur_scope_iter = step_scopes->rbegin();
275 276
         cur_scope_iter != step_scopes->rend();
         ++cur_scope_iter) {
M
minqiyang 已提交
277 278
      VLOG(3) << "Start backward at time_step "
              << cur_scope_iter - step_scopes->rbegin();
Y
Yang Yang(Tony) 已提交
279 280 281 282 283
      framework::Scope &cur_scope = **cur_scope_iter;
      // Link OG from outside to inside
      for (size_t i = 0; i < outside_og_names.size(); ++i) {
        auto outside_og_name = outside_og_names[i];
        auto inside_og_name = inside_og_names[i];
M
minqiyang 已提交
284 285
        VLOG(8) << "Linking outside " << outside_og_name << " --> inside "
                << inside_og_name;
C
chengduo 已提交
286 287 288 289
        if (scope.FindVar(outside_og_name) == nullptr) {
          continue;
        }

290 291
        auto &og_outside = *scope.FindVar(outside_og_name);
        auto &og_inside = *cur_scope.Var(inside_og_name);
S
sneaxiy 已提交
292
        if (og_outside.IsType<framework::LoDTensor>()) {
Y
Yang Yang(Tony) 已提交
293
          auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
294
          auto &inside_tensor = *og_inside.GetMutable<framework::LoDTensor>();
Y
Yang Yang(Tony) 已提交
295 296
          inside_tensor.set_lod(outside_tensor.lod());
          inside_tensor.ShareDataWith(outside_tensor);
S
sneaxiy 已提交
297
        } else if (og_outside.IsType<framework::LoDTensorArray>()) {
298 299
          auto outside_array =
              og_outside.GetMutable<framework::LoDTensorArray>();
Y
Yang Yang(Tony) 已提交
300
          auto &inside_array =
301
              *og_inside.GetMutable<framework::LoDTensorArray>();
302 303 304
          inside_array.clear();
          inside_array.resize(outside_array->size());
          VLOG(8) << outside_og_name << " size = " << outside_array->size();
Y
Yang Yang(Tony) 已提交
305 306

          for (size_t j = 0; j < inside_array.size(); ++j) {
307 308 309 310 311 312 313
            if (!outside_array->at(j).IsInitialized()) {
              outside_array->at(j).Resize({0});
            }
            VLOG(8) << j << " " << outside_array->at(j).numel();
            if (outside_array->at(j).numel() != 0) {
              inside_array[j].set_lod(outside_array->at(j).lod());
              inside_array[j].ShareDataWith(outside_array->at(j));
Y
Yang Yang(Tony) 已提交
314
            } else {
315
              PADDLE_ENFORCE_EQ(
316 317
                  inside_array[j].numel(),
                  0,
318 319 320
                  platform::errors::InvalidArgument(
                      "The numel of %d-th element of var %s (LoDTensorArray) "
                      "in while block must be 0, but received its numel is %d.",
321 322 323
                      j,
                      inside_og_name,
                      inside_array[j].numel()));
Y
Yang Yang(Tony) 已提交
324 325
            }
          }
C
chengduo 已提交
326
        } else {
327 328 329
          PADDLE_THROW(platform::errors::Unimplemented(
              "Currently only support LoDTensor and LoDTensorArray in "
              "WhileGradOp."));
Y
Yang Yang(Tony) 已提交
330 331
        }
      }
332 333
      executor.RunPreparedContext(
          ctx.get(), *cur_scope_iter, false, true, true);
Y
Yang Yang(Tony) 已提交
334

C
chengduo 已提交
335 336 337
      // The Outputs(kXGRAD) contains the names of the gradient of parameters
      // and inputs.
      auto &pg_ig_names = Outputs(kXGRAD);
Y
Yang Yu 已提交
338
      auto &p_names = Inputs(kX);
339 340
      PADDLE_ENFORCE_EQ(pg_ig_names.size(),
                        p_names.size(),
341 342 343 344 345
                        platform::errors::PreconditionNotMet(
                            "The number of names in Outputs(X@GRAD) does not "
                            "match the number of names in Inputs(X). The "
                            "number of names in Outputs(X@GRAD) is %d and "
                            "the number of names in Inputs(X) is %d.",
346 347
                            pg_ig_names.size(),
                            p_names.size()));
C
chengduo 已提交
348 349
      for (size_t param_id = 0; param_id < pg_ig_names.size(); ++param_id) {
        if (pg_ig_names[param_id] == framework::kEmptyVarName) {
350
          continue;  // parameter doesn't have gradient
Y
Yang Yang(Tony) 已提交
351 352
        }
        auto inside_grad_name = framework::GradVarName(p_names[param_id]);
Y
Yang Yang(Tony) 已提交
353

C
chengduo 已提交
354 355 356 357
        // for some grad_op, their input doesn't have gradient,
        // for example lookup_table_grad_op, the input(Idx) doesn't have
        // gradient.
        auto pg_ig_var = cur_scope.FindVar(inside_grad_name);
358
        PADDLE_ENFORCE_NOT_NULL(
359 360 361
            pg_ig_var,
            platform::errors::NotFound("Variable %s is not found.",
                                       inside_grad_name));
C
chengduo 已提交
362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378
        if (pg_ig_var->IsType<framework::LoDTensorArray>()) {
          auto pg_ig_lod_t_arr =
              pg_ig_var->GetMutable<framework::LoDTensorArray>();
          bool empty = true;
          for (auto &each : *pg_ig_lod_t_arr) {
            if (each.numel() != 0) {
              empty = false;
              break;
            }
          }
          if (empty) {
            LOG(WARNING) << pg_ig_names[param_id]
                         << " is not found in cur_scope.";
            continue;
          }
        }

Y
Yang Yang(Tony) 已提交
379
        //  // TODO(tonyyang-svail): Not sure we need the following
Y
Yang Yang(Tony) 已提交
380 381 382 383 384 385 386 387
        //  // If does not compute gradient of that variable inside rnn,
        //  just
        //  // continue
        //  if (local_var_names.find(inside_grad_name) ==
        //  local_var_names.end()) {
        //    continue;
        //  }

388 389 390
        auto var_iter = std::find(outside_og_names.begin(),
                                  outside_og_names.end(),
                                  pg_ig_names[param_id]);
391

Y
Yang Yang(Tony) 已提交
392 393 394
        // zero gradient variable in step 0
        if (cur_scope_iter == step_scopes->rbegin()) {
          auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
395
          PADDLE_ENFORCE_NOT_NULL(
396 397 398
              var,
              platform::errors::NotFound("Variable %s is not found.",
                                         inside_grad_name));
399
          PADDLE_ENFORCE_EQ(
C
chengduoZH 已提交
400 401
              var->IsType<framework::LoDTensorArray>() ||
                  var->IsType<LoDTensor>(),
402 403 404 405
              true,
              platform::errors::InvalidArgument(
                  "Currently the type of var only can be LoDTensorArray, "
                  "or LoDTensor, but the received var[%s] is %s.",
406 407
                  inside_grad_name,
                  framework::ToTypeName(var->Type())));
C
chengduo 已提交
408

409 410
          if ((var_iter == outside_og_names.end()) &&
              var->IsType<LoDTensor>()) {
Y
Yang Yang(Tony) 已提交
411 412
            auto &inside_tensor = var->Get<framework::LoDTensor>();
            framework::AttributeMap attrs;
413 414
            attrs["dtype"] =
                framework::TransToProtoVarType(inside_tensor.dtype());
415
            attrs["shape"] = phi::vectorize<int>(inside_tensor.dims());
Y
Yang Yang(Tony) 已提交
416 417
            attrs["value"] = 0.0f;

C
chengduo 已提交
418
            auto var_name = pg_ig_names[param_id];
419 420 421 422 423
            auto zero_op =
                framework::OpRegistry::CreateOp("fill_constant",
                                                framework::VariableNameMap{},
                                                {{"Out", {var_name}}},
                                                attrs);
D
dzhwinter 已提交
424
            zero_op->Run(scope, dev_place);
425 426 427
            scope.FindVar(var_name)
                ->GetMutable<framework::LoDTensor>()
                ->set_lod(inside_tensor.lod());
Y
Yang Yang(Tony) 已提交
428 429
          }
        }
430 431 432 433 434 435
        auto var_outside = scope.FindVar(pg_ig_names[param_id]);
        if ((var_iter == outside_og_names.end()) ||
            ((var_iter != outside_og_names.end()) &&
             var_outside->IsType<framework::LoDTensorArray>())) {
          auto new_inside_name = cur_scope.Rename(inside_grad_name);
          auto sum_op = framework::OpRegistry::CreateOp(
436 437
              "sum",
              {{"X", {pg_ig_names[param_id], new_inside_name}}},
438 439 440 441 442
              {{"Out", {pg_ig_names[param_id]}}},
              framework::AttributeMap{{"use_mkldnn", {false}}});
          sum_op->Run(cur_scope, dev_place);
          cur_scope.Rename(new_inside_name, inside_grad_name);
        }
Y
Yang Yang(Tony) 已提交
443
      }
444 445
      dev_ctx.Wait();
      const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
Y
Yang Yang(Tony) 已提交
446
    }
447
    step_scopes->clear();
Y
Yang Yang(Tony) 已提交
448 449 450
  }
};

H
hong 已提交
451 452
template <typename T>
class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
Y
Yang Yang(Tony) 已提交
453
 public:
H
hong 已提交
454
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
Y
Yang Yang(Tony) 已提交
455 456

 protected:
457
  void Apply(GradOpPtr<T> while_grad) const override {
F
Update  
fengjiayi 已提交
458
    while_grad->SetType("while_grad");
H
hong 已提交
459 460 461
    while_grad->SetInput(kX, this->Input(kX));
    while_grad->SetInput(kOutputs, this->Output(kOutputs));
    while_grad->SetInput(kStepScopes, this->Output(kStepScopes));
F
Update  
fengjiayi 已提交
462 463

    auto *grad_block = this->grad_block_[0];
Y
Yu Yang 已提交
464 465
    auto *fwd_block = grad_block->ForwardBlock();
    auto *parent_block = grad_block->ParentBlock();
466 467 468

    // Not all of IGs will be generated by inner gradient operators of while op.
    // Ignore IGs that is not generated by the inside block.
F
Update  
fengjiayi 已提交
469 470 471 472
    std::unordered_set<std::string> inner_op_outputs;
    for (const auto *op : grad_block->AllOps()) {
      for (auto &oname : op->OutputArgumentNames()) {
        inner_op_outputs.insert(oname);
473 474
      }
    }
H
hong 已提交
475 476
    auto igs = this->InputGrad(kX, /*do not drop empty gradient*/ false);

477
    for (auto &each_ig : igs) {
F
Update  
fengjiayi 已提交
478
      if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
M
minqiyang 已提交
479
        VLOG(8) << "Ignore " << each_ig;
480 481 482
        each_ig = framework::kEmptyVarName;
      }
    }
F
Update  
fengjiayi 已提交
483
    while_grad->SetOutput(framework::GradVarName(kX), igs);
Y
Yang Yang(Tony) 已提交
484 485 486 487

    // OG should be re-calculated by step blocks, since many outputs of while op
    // do not need to calculate gradients.
    std::unordered_set<std::string> block_ins;
H
hong 已提交
488 489
    block_ins.reserve(this->Input(kX).size() + this->Output(kOutputs).size());
    for (auto &p : this->Input(kX)) {
F
fengjiayi 已提交
490 491
      block_ins.insert(p);
    }
H
hong 已提交
492
    for (auto &o : this->Output(kOutputs)) {
F
fengjiayi 已提交
493 494
      block_ins.insert(o);
    }
Y
Yu Yang 已提交
495
    std::unordered_set<std::string> output_grads;
F
Update  
fengjiayi 已提交
496 497 498 499
    for (const auto *op : grad_block->AllOps()) {
      for (auto &input_name : op->InputArgumentNames()) {
        // If the input of Op has been recorded or is generated by the forward
        // block, do not make it as input again.
Y
Yu Yang 已提交
500 501 502

        // The input is located in I/O or other op's outputs or the variable is
        // located in grad_block's parents
F
Update  
fengjiayi 已提交
503
        if (block_ins.find(input_name) != block_ins.end() ||
Y
Yu Yang 已提交
504 505
            (fwd_block->FindVarRecursive(input_name) != nullptr ||
             parent_block->FindVarRecursive(input_name) != nullptr)) {
Y
Yang Yang(Tony) 已提交
506 507
          continue;
        }
C
chengduo 已提交
508

Y
Yu Yang 已提交
509
        output_grads.insert(input_name);
Y
Yang Yang(Tony) 已提交
510
      }
F
Update  
fengjiayi 已提交
511
      for (auto &output_name : op->OutputArgumentNames()) {
Y
Yang Yang(Tony) 已提交
512
        block_ins.insert(output_name);
Y
Yang Yang(Tony) 已提交
513 514
      }
    }
Y
Yang Yang(Tony) 已提交
515

Y
Yu Yang 已提交
516 517
    std::vector<std::string> output_grads_list;
    output_grads_list.resize(output_grads.size());
518 519
    std::copy(
        output_grads.begin(), output_grads.end(), output_grads_list.begin());
Y
Yu Yang 已提交
520
    while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);
F
Update  
fengjiayi 已提交
521 522

    while_grad->SetAttrMap(this->Attrs());
A
Abhinav Arora 已提交
523
    while_grad->SetBlockAttr(kStepBlock, grad_block);
Y
Yang Yang(Tony) 已提交
524 525
    // record the original output gradient names, since the gradient name of
    // while operator could be renamed.
Y
Yu Yang 已提交
526
    while_grad->SetAttr("original_output_grad", output_grads_list);
Y
Yang Yang(Tony) 已提交
527

S
sneaxiy 已提交
528
    while_grad->SetAttr(kSkipEagerDeletionVars, std::vector<std::string>());
Y
Yang Yang(Tony) 已提交
529 530 531
  }
};

532 533
class WhileGradOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
Y
Yang Yang(Tony) 已提交
534
 public:
M
minqiyang 已提交
535
  void operator()(framework::InferVarTypeContext *ctx) const override {
536 537
    auto p_names = Input(ctx, kX);
    auto pg_ig_names = Output(ctx, framework::GradVarName(kX));
Y
Yang Yang(Tony) 已提交
538 539

    for (size_t i = 0; i < p_names.size(); ++i) {
540
      if (HasVar(ctx, pg_ig_names[i])) {
M
minqiyang 已提交
541
        VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
542 543 544
                << " type: " << GetType(ctx, p_names[i]);
        SetType(ctx, pg_ig_names[i], GetType(ctx, p_names[i]));
        SetDataType(ctx, pg_ig_names[i], GetDataType(ctx, p_names[i]));
Y
Yang Yang(Tony) 已提交
545 546 547 548 549 550 551 552
      }
    }
  }
};

class WhileGradOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
Y
Yang Yu 已提交
553 554
    ctx->HasInputs(kX);
    ctx->HasOutputs(framework::GradVarName(kX));
Y
Yang Yang(Tony) 已提交
555 556
    ctx->HasInputs(kOutputs);
    ctx->HasInputs(framework::GradVarName(kOutputs));
C
chengduo 已提交
557
    auto pg_ig_names = ctx->Outputs(kXGRAD);
558 559
    auto in_var_ptrs = ctx->GetInputVarPtrs(kX);
    auto out_var_ptrs = ctx->GetOutputVarPtrs(kXGRAD);
560 561
    PADDLE_ENFORCE_EQ(in_var_ptrs.size(),
                      out_var_ptrs.size(),
562 563 564
                      platform::errors::InvalidArgument(
                          "The size of Inputs(X) must be the same as "
                          "the size of Outputs(X@GRAD)."));
X
Xin Pan 已提交
565 566

    for (size_t i = 0; i < in_var_ptrs.size(); ++i) {
C
chengduo 已提交
567
      if (pg_ig_names[i] == framework::kEmptyVarName) {
Y
Yang Yang(Tony) 已提交
568 569
        continue;
      }
570
      framework::VarDesc *in_var =
R
Ruibiao Chen 已提交
571 572
          PADDLE_GET(framework::VarDesc *, in_var_ptrs[i]);
      PADDLE_GET(framework::VarDesc *, out_var_ptrs[i])
573
          ->SetShape(in_var->GetShape());
Y
Yang Yang(Tony) 已提交
574 575 576 577
    }
  }
};

Y
Yang Yang(Tony) 已提交
578 579 580
}  // namespace operators
}  // namespace paddle

H
hong 已提交
581
REGISTER_OPERATOR(
582 583 584
    while,
    paddle::operators::WhileOp,
    paddle::operators::WhileOpMaker,
H
hong 已提交
585
    paddle::operators::WhileGradOpMaker<paddle::framework::OpDesc>);
586 587
REGISTER_OPERATOR(while_grad,
                  paddle::operators::WhileGradOp,
Y
Yang Yang(Tony) 已提交
588 589
                  paddle::operators::WhileGradOpShapeInference,
                  paddle::operators::WhileGradOpVarTypeInference);