// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/controlflow/control_flow_op_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/platform/flags.h"

PADDLE_DEFINE_EXPORTED_bool(
    cache_inference_while_scope,
    false,
    "Cache the scope of the while op to avoid repeated creation of the scope "
    "for each iteration and improve inference performance.");

namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class VarDesc;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace operators {

using StepScopeVar = std::vector<framework::Scope *>;

namespace {  // NOLINT
static std::string GetSkipEagerDeletionVarsDebugString(
    const std::vector<std::string> &vars) {
  std::string str = "Skip " + std::to_string(vars.size()) +
                    " var(s) in eager deletion mode: ";
  for (auto &var : vars) {
    str.append(var);
    str.push_back(' ');
  }
  return str;
}

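// Copies the DenseTensor held by `var_name` in `scope` to `dst_place` and
// rebinds the variable to the copied memory. Variables that are not
// LOD_TENSORs, or tensors already on `dst_place`, are left untouched.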
static void TransferVariablePlace(const framework::Scope *scope,
                                  const std::string &var_name,
                                  const phi::Place &dst_place,
                                  const platform::DeviceContext &dev_ctx) {
  framework::Variable *var = scope->FindVar(var_name);
  if (var == nullptr) {
    VLOG(4) << "[TransferVariablePlace] cannot find variable: " << var_name;
    return;
  }
  if (var->Type() != framework::proto::VarType::LOD_TENSOR) {
    VLOG(10) << "[TransferVariablePlace] skip " << var_name
             << ", type is not LOD_TENSOR: "
             << framework::TransToPhiDataType(
                    framework::ToVarType(var->Type()));
    return;
  }
  phi::DenseTensor *t = var->GetMutable<phi::DenseTensor>();
  if (t->place() == dst_place) {
    VLOG(10) << "[TransferVariablePlace]"
             << "no need transfer: " << var_name;
    return;
  }

  // Copy the tensor to dst_place, then rebind the original tensor to the
  // copied allocation. The temporary tensor can live on the stack because the
  // copied memory is kept alive by the shared holder.
  phi::DenseTensor new_t;
  framework::TensorCopy(*t, dst_place, &new_t);
  dev_ctx.Wait();

  t->set_meta(new_t.meta());
  t->ResetHolder(new_t.Holder());

  VLOG(4) << "[TransferVariablePlace]" << var_name
          << " place: " << new_t.place();
}

}  // namespace

class WhileOp : public framework::OperatorBase {
 public:
  WhileOp(const std::string &type,
          const framework::VariableNameMap &inputs,
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)),
                            platform::errors::NotFound(
                                "Input(Condition) of WhileOp is not found."));

    auto &cond = scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>();
    PADDLE_ENFORCE_EQ(
        cond.numel(),
        1,
        platform::errors::InvalidArgument(
            "The numel of Input(Condition) of WhileOp must be 1. But now "
            "the Condition's numel is ",
            cond.numel(),
            ".\n"));

#ifdef PADDLE_WITH_MKLDNN
    // Executor on being destroyed clears oneDNN cache and resets
    // registered model data layout. This is unwanted for nested
    // Executors (executors declared inside control ops)
    platform::DontClearMKLDNNCache(dev_place);
#endif
    auto *block = Attr<framework::BlockDesc *>(kStepBlock);

    // get device context from pool
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(dev_place);

    bool is_test = Attr<bool>("is_test");

    std::set<std::string> no_copy_var_names;
    if (!is_test) {
      // set all persistable parameters into no_copy_var_names.
      auto *global_block = block;

      while (global_block->ID() != 0)
        global_block = global_block->ParentBlock();
      auto all_vars = global_block->AllVars();
      std::for_each(all_vars.begin(),
                    all_vars.end(),
                    [&no_copy_var_names](framework::VarDesc *var) {
                      if (var->IsParameter())
                        no_copy_var_names.insert(var->Name());
                    });

      const std::vector<framework::OpDesc *> &all_ops = block->AllOps();
      for (const framework::OpDesc *op : all_ops) {
        const framework::VariableNameMap &input_var_names = op->Inputs();
        const framework::VariableNameMap &output_var_names = op->Outputs();
        for (auto &ipt : input_var_names) {
          for (const std::string &var_name : ipt.second) {
            if (StrInVaraiableNameMap(var_name, output_var_names)) {
              no_copy_var_names.insert(var_name);
            }
          }
        }
      }
    }

    auto step_scopes =
        scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();

    if (step_scopes->size() > 0) {
      platform::DeviceContextPool::Instance().Get(dev_place)->Wait();
      for (auto &s : *step_scopes) {
        if (scope.HasKid(s)) {
          scope.DeleteScope(s);
        }
      }
      step_scopes->clear();
    }

    PADDLE_ENFORCE_EQ(step_scopes->size(),
                      0,
                      platform::errors::PreconditionNotMet(
                          "The Output(StepScope) of WhileOp should be empty."));

    bool cond_data = GetCondData(cond);
    auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
    VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);

    // note(lvyongkang): The assign op in a while loop may change the place of
    // a variable. However, InterpreterCore fixes the kernel of every op during
    // its first run, so a CPU tensor may become a GPU tensor after the first
    // run, which leads to a segmentation fault when it is later used by a CPU
    // kernel. Here we record the place of every input and restore the places
    // after InterpreterCore::Run().
    std::map<std::string, phi::Place> input_var_original_places;
    for (const auto &in_name : Inputs(kX)) {
      framework::Variable *var = scope.FindVar(in_name);
      if (var == nullptr) {
        VLOG(4) << "[while op] input not found: " << in_name;
        continue;
      }

      if (var->Type() == framework::proto::VarType::LOD_TENSOR) {
        input_var_original_places[in_name] =
            (var->Get<phi::DenseTensor>()).place();
      } else {
        VLOG(10) << "[while op]"
                 << "skip backup input " << in_name << " type:"
                 << framework::TransToPhiDataType(
                        framework::ToVarType(var->Type()));
      }
    }

    LOG_FIRST_N(INFO, 1) << "[ControlFlow][WhileOp] New Executor is Running.";
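    // The InterpreterCore is cached in the op and rebuilt only when the
    // target place changes between runs.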
    if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) {
      framework::Scope placeholder;  // Don't care if it's valid; it is only
                                     // used to initialize InterpreterCore
      framework::interpreter::ExecutionConfig execution_config;
      execution_config.create_local_scope = false;
      execution_config.used_for_control_flow_op = true;
      execution_config.skip_gc_vars =
          std::set<std::string>(skip_vars.begin(), skip_vars.end());

      core_.reset(new framework::InterpreterCore(
          dev_place, *block, &placeholder, execution_config));
    }

    core_->SetOutputHooks(hookfuncs_);

    if (!is_test) {
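      // Training mode: every iteration runs in a fresh step scope, and the
      // scopes are kept in Output(kStepScopes) so WhileGradOp can replay the
      // steps in reverse order.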
      while (cond_data) {
        auto &current_scope = scope.NewScope();
        step_scopes->push_back(&current_scope);

        std::vector<std::string> rename_vars;
        for (const std::string &input_var_name : Inputs(kX)) {
          if (no_copy_var_names.find(input_var_name) ==
              no_copy_var_names.end()) {
            std::string input_var_rename = input_var_name + kSuffix;
            framework::Variable *input_var = scope.FindVar(input_var_name);
            if (input_var->IsType<phi::DenseTensor>()) {
              rename_vars.push_back(input_var_rename);
              auto input_var_tensor = input_var->Get<phi::DenseTensor>();
              auto *rename_input_var_tensor =
                  current_scope.Var(input_var_rename)
                      ->GetMutable<phi::DenseTensor>();
              framework::TensorCopy(
                  input_var_tensor, dev_place, rename_input_var_tensor);
              rename_input_var_tensor->set_lod(input_var_tensor.lod());
            }
          }
        }

        BuildScopeForControlFlowOp(*core_, *block, &current_scope);
        core_->reset_scope(&current_scope);
        core_->Run({}, false);

        // restore inputs place
        for (const auto &n : input_var_original_places) {
          const std::string &in_name = n.first;
          const phi::Place &original_place = n.second;
          // input vars exist in `scope` not `current_scope`
          TransferVariablePlace(&scope, in_name, original_place, dev_ctx);
        }

        for (auto &var_rename : rename_vars) {
          std::string input_var_name =
              var_rename.substr(0, var_rename.size() - strlen(kSuffix));
          current_scope.Rename(var_rename, input_var_name);
        }
        cond_data = GetCondData(
            scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
      }
    } else {
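      // Inference mode: a single scope is reused across iterations; with
      // FLAGS_cache_inference_while_scope the scope is created once and kept
      // for later runs of this op.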
      framework::Scope *current_scope = nullptr;
      if (!FLAGS_cache_inference_while_scope) {
        current_scope = &(scope.NewScope());
        BuildScopeForControlFlowOp(*core_, *block, current_scope);
        core_->reset_scope(current_scope);
      } else {
        if (cached_inference_scope_ == nullptr) {
          cached_inference_scope_ = &(scope.NewScope());
          BuildScopeForControlFlowOp(*core_, *block, cached_inference_scope_);
          core_->reset_scope(cached_inference_scope_);
        }
        current_scope = cached_inference_scope_;
      }

      while (cond_data) {
        for (auto &name : current_scope->LocalVarNames()) {
          auto *var = current_scope->Var(name);
          if (var->IsType<phi::DenseTensor>()) {
            // Clear all lod information for all lod_tensors.
            auto *t = var->GetMutable<phi::DenseTensor>();
            framework::LoD empty_lod;
            t->set_lod(empty_lod);
          } else if (var->IsType<framework::LoDTensorArray>()) {
            // Clear elements of all tensor arrays.
            auto *t = var->GetMutable<framework::LoDTensorArray>();
            t->clear();
          }
        }

        core_->Run({}, false);

        cond_data = GetCondData(
            scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
      }

      if (!FLAGS_cache_inference_while_scope) {
        scope.DeleteScope(current_scope);
      }
    }
  }

 private:
  mutable std::shared_ptr<framework::Executor> executor_{nullptr};
  mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
  mutable std::shared_ptr<framework::InterpreterCore> core_{nullptr};
  mutable framework::Scope *cached_inference_scope_{nullptr};
};

class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(kX,
             "A set of variables, which are required by operators inside the "
             "block of While Op.")
        .AsDuplicable();
    AddInput(
        kCondition,
        "(Bool) A scalar. When it's False, the While Op will be terminated.")
        .AsDuplicable();
    AddOutput(kOutputs,
              "A set of variables, which will be assigned with values "
              "generated by the operators inside the block of While Op.")
        .AsDuplicable();
    AddOutput(kStepScopes,
              "(StepScopeVar) A vector of local scope, which size equals the "
              "step number of While Op. The i'th scope storages temporary "
              "variables generated in the i'th step.");
    AddAttr<framework::BlockDesc *>(kStepBlock,
                                    "The step block inside WhileOp");
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
)DOC");
  }
};

class WhileGradOp : public framework::OperatorBase {
 public:
  WhileGradOp(const std::string &type,
              const framework::VariableNameMap &inputs,
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    PADDLE_ENFORCE_EQ(
        Attr<bool>("is_test"),
        false,
        platform::errors::InvalidArgument(
            "WhileGradOp is only callable when is_test is false."));
    // get device context from pool
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(dev_place);

    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
    auto *parent_block = block->ParentBlock();

    auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
    VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);

    auto *step_scopes =
        scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();

    auto outside_og_names = Inputs(framework::GradVarName(kOutputs));
    auto inside_og_names =
        Attr<std::vector<std::string>>("original_output_grad");

    PADDLE_ENFORCE_EQ(outside_og_names.size(),
                      inside_og_names.size(),
                      platform::errors::InvalidArgument(
                          "The number of original output gradient names "
                          "does not match the number of backward input "
                          "gradient names. The number of backward input "
                          "names is %d and the number of original output "
                          "gradient names is %d.",
                          outside_og_names.size(),
                          inside_og_names.size()));

    LOG_FIRST_N(INFO, 1)
        << "[ControlFlow][WhileGradOp] New Executor is Running.";
    if (!core_ || !platform::is_same_place(core_->GetPlace(), dev_place)) {
      std::set<std::string> skip_gc_vars(skip_vars.begin(), skip_vars.end());
      framework::Scope placeholder;  // Don't care if it's valid; it is only
                                     // used to initialize InterpreterCore
      framework::interpreter::ExecutionConfig execution_config;
      execution_config.create_local_scope = false;
      execution_config.used_for_control_flow_op = true;
      execution_config.skip_gc_vars =
          std::set<std::string>(skip_vars.begin(), skip_vars.end());

      core_.reset(new framework::InterpreterCore(
          dev_place, *block, &placeholder, execution_config));
    }

    for (auto cur_scope_iter = step_scopes->rbegin();
         cur_scope_iter != step_scopes->rend();
         ++cur_scope_iter) {
      VLOG(3) << "Start backward at time_step "
              << cur_scope_iter - step_scopes->rbegin();
      framework::Scope &cur_scope = **cur_scope_iter;
      // Link OG from outside to inside
      for (size_t i = 0; i < outside_og_names.size(); ++i) {
        auto outside_og_name = outside_og_names[i];
        auto inside_og_name = inside_og_names[i];
        VLOG(8) << "Linking outside " << outside_og_name << " --> inside "
                << inside_og_name;
        if (scope.FindVar(outside_og_name) == nullptr) {
          continue;
        }

        if (cur_scope_iter == step_scopes->rbegin()) {
          auto &og_outside = *scope.FindVar(outside_og_name);
          if (og_outside.IsType<phi::DenseTensor>() &&
              !og_outside.GetMutable<phi::DenseTensor>()->IsInitialized()) {
            auto *var_desc = parent_block->FindVarRecursive(outside_og_name);
            PADDLE_ENFORCE_NOT_NULL(var_desc,
                                    platform::errors::PreconditionNotMet(
                                        "Var `%s` is not found in parent "
                                        "block, can't fill constant.",
                                        outside_og_name));
            auto shape = var_desc->GetShape();
            VLOG(8) << "Found uninitialized tensor " << outside_og_name
                    << " in step 0, fill it with 0.0f. dims="
                    << phi::make_ddim(shape);
            framework::AttributeMap attrs;
            attrs["dtype"] = var_desc->GetDataType();
            attrs["shape"] = phi::vectorize<int>(phi::make_ddim(shape));
            attrs["value"] = 0.0f;

            auto var_name = outside_og_name;
            auto zero_op =
                framework::OpRegistry::CreateOp("fill_constant",
                                                framework::VariableNameMap{},
                                                {{"Out", {var_name}}},
                                                attrs);
            zero_op->Run(scope, dev_place);
          }
        }

        auto &og_outside = *scope.FindVar(outside_og_name);
        auto &og_inside = *cur_scope.Var(inside_og_name);
        if (og_outside.IsType<phi::DenseTensor>()) {
          auto &outside_tensor = og_outside.Get<phi::DenseTensor>();
          auto &inside_tensor = *og_inside.GetMutable<phi::DenseTensor>();
          inside_tensor.set_lod(outside_tensor.lod());
          inside_tensor.ShareDataWith(outside_tensor);
        } else if (og_outside.IsType<framework::LoDTensorArray>()) {
          auto outside_array =
              og_outside.GetMutable<framework::LoDTensorArray>();
          auto &inside_array =
              *og_inside.GetMutable<framework::LoDTensorArray>();
          inside_array.clear();
          inside_array.resize(outside_array->size());
          VLOG(8) << outside_og_name << " size = " << outside_array->size();

          for (size_t j = 0; j < inside_array.size(); ++j) {
            if (!outside_array->at(j).IsInitialized()) {
              outside_array->at(j).Resize({0});
            }
            VLOG(8) << j << " " << outside_array->at(j).numel();
            if (outside_array->at(j).numel() != 0) {
              inside_array[j].set_lod(outside_array->at(j).lod());
              inside_array[j].ShareDataWith(outside_array->at(j));
            } else {
              PADDLE_ENFORCE_EQ(
                  inside_array[j].numel(),
                  0,
                  platform::errors::InvalidArgument(
                      "The numel of %d-th element of var %s (LoDTensorArray) "
                      "in while block must be 0, but received its numel is %d.",
                      j,
                      inside_og_name,
                      inside_array[j].numel()));
            }
          }
        } else {
          PADDLE_THROW(platform::errors::Unimplemented(
              "Currently only support phi::DenseTensor and "
              "phi::DenseTensorArray in "
              "WhileGradOp."));
        }
      }

      BuildScopeForControlFlowOp(*core_, *block, *cur_scope_iter);
      core_->reset_scope(*cur_scope_iter);
      core_->Run({}, false);

      // The Outputs(kXGRAD) contains the names of the gradient of parameters
      // and inputs.
      auto &pg_ig_names = Outputs(kXGRAD);
      auto &p_names = Inputs(kX);
      PADDLE_ENFORCE_EQ(pg_ig_names.size(),
                        p_names.size(),
                        platform::errors::PreconditionNotMet(
                            "The number of names in Outputs(X@GRAD) does not "
                            "match the number of names in Inputs(X). The "
                            "number of names in Outputs(X@GRAD) is %d and "
                            "the number of names in Inputs(X) is %d.",
                            pg_ig_names.size(),
                            p_names.size()));
      for (size_t param_id = 0; param_id < pg_ig_names.size(); ++param_id) {
        if (pg_ig_names[param_id] == framework::kEmptyVarName) {
          continue;  // parameter doesn't have gradient
        }
        auto inside_grad_name = framework::GradVarName(p_names[param_id]);

        // Some grad ops have inputs without gradients; for example, the
        // Input(Idx) of lookup_table_grad_op does not have a gradient.
        auto pg_ig_var = cur_scope.FindVar(inside_grad_name);
        PADDLE_ENFORCE_NOT_NULL(
            pg_ig_var,
            platform::errors::NotFound("Variable %s is not found.",
                                       inside_grad_name));
        if (pg_ig_var->IsType<framework::LoDTensorArray>()) {
          auto pg_ig_lod_t_arr =
              pg_ig_var->GetMutable<framework::LoDTensorArray>();
          bool empty = true;
          for (auto &each : *pg_ig_lod_t_arr) {
            if (each.numel() != 0) {
              empty = false;
              break;
            }
          }
          if (empty) {
            LOG(WARNING) << pg_ig_names[param_id]
                         << " is not found in cur_scope.";
            continue;
          }
        }

        //  // TODO(tonyyang-svail): Not sure we need the following
        //  // If does not compute gradient of that variable inside rnn,
        //  just
        //  // continue
        //  if (local_var_names.find(inside_grad_name) ==
        //  local_var_names.end()) {
        //    continue;
        //  }

        auto is_var_input_and_output =
            std::find(outside_og_names.begin(),
                      outside_og_names.end(),
                      pg_ig_names[param_id]) != outside_og_names.end();

        // zero gradient variable in step 0
        if (cur_scope_iter == step_scopes->rbegin()) {
          auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
          PADDLE_ENFORCE_NOT_NULL(
              var,
              platform::errors::NotFound("Variable %s is not found.",
                                         inside_grad_name));
          PADDLE_ENFORCE_EQ(
              var->IsType<framework::LoDTensorArray>() ||
                  var->IsType<phi::DenseTensor>(),
              true,
              platform::errors::InvalidArgument(
                  "Currently the type of var can only be LoDTensorArray, "
                  "or phi::DenseTensor, but the received var[%s] is %s.",
                  inside_grad_name,
                  framework::ToTypeName(var->Type())));

          if (!is_var_input_and_output && var->IsType<phi::DenseTensor>()) {
            auto &inside_tensor = var->Get<phi::DenseTensor>();
            framework::AttributeMap attrs;
            attrs["dtype"] =
                framework::TransToProtoVarType(inside_tensor.dtype());
            attrs["shape"] = phi::vectorize<int>(inside_tensor.dims());
            attrs["value"] = 0.0f;

            auto var_name = pg_ig_names[param_id];
            auto zero_op =
                framework::OpRegistry::CreateOp("fill_constant",
                                                framework::VariableNameMap{},
                                                {{"Out", {var_name}}},
                                                attrs);
            zero_op->Run(scope, dev_place);
            scope.FindVar(var_name)->GetMutable<phi::DenseTensor>()->set_lod(
                inside_tensor.lod());
          }
        }
        if (!is_var_input_and_output) {
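          // Accumulate the gradient of this step: rename the gradient
          // produced inside the step scope, sum it with the gradient
          // accumulated so far, and then restore the original name.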
          auto new_inside_name = cur_scope.Rename(inside_grad_name);
          auto sum_op = framework::OpRegistry::CreateOp(
              "sum",
              {{"X", {pg_ig_names[param_id], new_inside_name}}},
              {{"Out", {pg_ig_names[param_id]}}},
              framework::AttributeMap{{"use_mkldnn", {false}}});
          sum_op->Run(cur_scope, dev_place);
          cur_scope.Rename(new_inside_name, inside_grad_name);
        } else {
          ShareVariable(cur_scope, scope, pg_ig_names[param_id]);
        }
      }
      dev_ctx.Wait();
      const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
    }
    step_scopes->clear();
  }

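  // Shares the underlying data of variable `name` from `source` into `dest`
  // (a DenseTensor, or each initialized element of a LoDTensorArray).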
  void ShareVariable(const framework::Scope &source,
                     const framework::Scope &dest,
                     std::string name) const {
    auto from_var = source.FindVar(name);
    auto to_var = dest.FindVar(name);
    if (from_var->IsType<phi::DenseTensor>()) {
      if (from_var->Get<phi::DenseTensor>().IsInitialized()) {
        to_var->GetMutable<phi::DenseTensor>()->ShareDataWith(
            from_var->Get<phi::DenseTensor>());
      }
    } else if (from_var->IsType<framework::LoDTensorArray>()) {
      auto from_arr = from_var->GetMutable<framework::LoDTensorArray>();
      auto to_arr = to_var->GetMutable<framework::LoDTensorArray>();
      to_arr->clear();
      to_arr->resize(from_arr->size());
      for (size_t i = 0; i < to_arr->size(); ++i) {
        if (from_arr->at(i).IsInitialized()) {
          to_arr->at(i).ShareDataWith(from_arr->at(i));
        }
      }
    }
  }

 private:
  mutable std::shared_ptr<framework::Executor> executor_{nullptr};
  mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
  mutable std::shared_ptr<framework::InterpreterCore> core_{nullptr};
};

template <typename T>
class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> while_grad) const override {
    while_grad->SetType("while_grad");
    while_grad->SetInput(kX, this->Input(kX));
    while_grad->SetInput(kOutputs, this->Output(kOutputs));
    while_grad->SetInput(kStepScopes, this->Output(kStepScopes));

    auto *grad_block = this->grad_block_[0];
    auto *fwd_block = grad_block->ForwardBlock();
    auto *parent_block = grad_block->ParentBlock();

    // Not all of IGs will be generated by inner gradient operators of while op.
    // Ignore IGs that are not generated by the inside block.
    std::unordered_set<std::string> inner_op_outputs;
    for (const auto *op : grad_block->AllOps()) {
      for (auto &oname : op->OutputArgumentNames()) {
        inner_op_outputs.insert(oname);
      }
    }
    auto igs = this->InputGrad(kX, /*do not drop empty gradient*/ false);

    for (auto &each_ig : igs) {
      if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
        VLOG(8) << "Ignore " << each_ig;
        each_ig = framework::kEmptyVarName;
      }
    }
    while_grad->SetOutput(framework::GradVarName(kX), igs);

    // OG should be re-calculated by step blocks, since many outputs of while op
    // do not need to calculate gradients.
    std::unordered_set<std::string> block_ins;
    block_ins.reserve(this->Input(kX).size() + this->Output(kOutputs).size());
    for (auto &p : this->Input(kX)) {
      block_ins.insert(p);
    }
    for (auto &o : this->Output(kOutputs)) {
      block_ins.insert(o);
    }
    std::unordered_set<std::string> output_grads;

    for (const auto *op : grad_block->AllOps()) {
      for (auto &input_name : op->InputArgumentNames()) {
        // If the input of Op has been recorded or is generated by the forward
        // block, do not make it as input again.

        // The input is located in I/O or other op's outputs or the variable is
        // located in grad_block's parents
        if (block_ins.find(input_name) != block_ins.end() ||
            (fwd_block->FindVarRecursive(input_name) != nullptr ||
             parent_block->FindVarRecursive(input_name) != nullptr)) {
          continue;
        }
        output_grads.insert(input_name);
      }
      for (auto &output_name : op->OutputArgumentNames()) {
        block_ins.insert(output_name);
      }
    }

    std::vector<std::string> output_grads_list;
    output_grads_list.resize(output_grads.size());
    std::copy(
        output_grads.begin(), output_grads.end(), output_grads_list.begin());
    while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);

    while_grad->SetAttrMap(this->Attrs());
    while_grad->SetBlockAttr(kStepBlock, grad_block);
    // record the original output gradient names, since the gradient name of
    // while operator could be renamed.
    while_grad->SetAttr("original_output_grad", output_grads_list);

    while_grad->SetAttr(kSkipEagerDeletionVars, std::vector<std::string>());
  }
};

class WhileGradOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto p_names = Input(ctx, kX);
    auto pg_ig_names = Output(ctx, framework::GradVarName(kX));

    for (size_t i = 0; i < p_names.size(); ++i) {
      if (HasVar(ctx, pg_ig_names[i])) {
        VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
                << " type: " << GetType(ctx, p_names[i]);
        SetType(ctx, pg_ig_names[i], GetType(ctx, p_names[i]));
        SetDataType(ctx, pg_ig_names[i], GetDataType(ctx, p_names[i]));
      }
    }
  }
};

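// Propagates the shape of each Inputs(X) variable to the matching
// Outputs(X@GRAD) variable.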
class WhileGradOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    ctx->HasInputs(kX);
    ctx->HasOutputs(framework::GradVarName(kX));
    ctx->HasInputs(kOutputs);
    ctx->HasInputs(framework::GradVarName(kOutputs));
    auto pg_ig_names = ctx->Outputs(kXGRAD);
    auto in_var_ptrs = ctx->GetInputVarPtrs(kX);
    auto out_var_ptrs = ctx->GetOutputVarPtrs(kXGRAD);
    PADDLE_ENFORCE_EQ(in_var_ptrs.size(),
                      out_var_ptrs.size(),
                      platform::errors::InvalidArgument(
                          "The size of Inputs(X) must be the same as "
                          "the size of Outputs(X@GRAD)."));

    for (size_t i = 0; i < in_var_ptrs.size(); ++i) {
      if (pg_ig_names[i] == framework::kEmptyVarName) {
        continue;
      }
      framework::VarDesc *in_var =
          PADDLE_GET(framework::VarDesc *, in_var_ptrs[i]);
      PADDLE_GET(framework::VarDesc *, out_var_ptrs[i])
          ->SetShape(in_var->GetShape());
    }
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(
    while,
    paddle::operators::WhileOp,
    paddle::operators::WhileOpMaker,
    paddle::operators::WhileGradOpMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(while_grad,
                  paddle::operators::WhileGradOp,
                  paddle::operators::WhileGradOpShapeInference,
                  paddle::operators::WhileGradOpVarTypeInference);