// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/layer.h"

#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/infer_var_type_context.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/prepared_operator.h"
#include "paddle/fluid/imperative/var_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

DECLARE_bool(use_mkldnn);
namespace paddle {
namespace imperative {

using framework::Variable;
void ThreadSafeNameSet::Insert(const std::string& name) {
  std::lock_guard<std::mutex> guard(mtx_);
  set_.insert(name);
}

void ThreadSafeNameSet::Remove(const std::string& name) {
  std::lock_guard<std::mutex> guard(mtx_);
  auto iter = set_.find(name);
  PADDLE_ENFORCE_EQ(
      iter != set_.end(),
      true,
      platform::errors::NotFound("Variable name %s does not exist", name));
  set_.erase(iter);
}

std::vector<std::string> ThreadSafeNameSet::Names() const {
  std::lock_guard<std::mutex> guard(mtx_);
  return std::vector<std::string>(set_.begin(), set_.end());
}

ThreadSafeNameSet VarBase::name_set_;

std::vector<std::string> VarBase::AliveVarNames() { return name_set_.Names(); }

static framework::RuntimeContext PrepareRuntimeContext(
    const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
  framework::VariableValueMap inputs, outputs;
  for (auto& in_pair : ins) {
    auto& in_ctx = inputs[in_pair.first];
    in_ctx.reserve(in_pair.second.size());
    for (auto& in_var : in_pair.second) {
      in_ctx.emplace_back(in_var->MutableVar());
    }
  }

  for (auto& out_pair : outs) {
    auto& out_ctx = outputs[out_pair.first];
    out_ctx.reserve(out_pair.second.size());
    for (auto& out_var : out_pair.second) {
      out_ctx.emplace_back(out_var->MutableVar());
    }
  }
  return framework::RuntimeContext(std::move(inputs), std::move(outputs));
}

template <typename VarType>
static std::string DebugString(
    const std::string& name,
    const std::vector<std::shared_ptr<VarType>>& vars) {
  std::stringstream ss;
  ss << name << "{";

  for (size_t i = 0; i < vars.size(); ++i) {
    if (i > 0) ss << ", ";

    if (vars[i] == nullptr) {
      ss << "NULL";
      continue;
    }
    ss << GetNameFromVar(vars[i]) << "[";
    const framework::Variable& var = vars[i]->Var();
    if (!var.IsInitialized()) {
      ss << "NOT_INITED_VAR";
    } else if (var.IsType<framework::LoDTensor>()) {
      auto& tensor = var.Get<framework::LoDTensor>();
      ss << "LoDTensor<";
      if (tensor.IsInitialized()) {
        ss << framework::DataTypeToString(
                  framework::TransToProtoVarType(tensor.dtype()))
           << ", ";
        ss << tensor.place() << ", ";
        ss << "(" << tensor.dims() << ")";
      } else {
        ss << "NOT_INITED";
      }
      ss << ">";
    } else if (var.IsType<phi::SelectedRows>()) {
      ss << "SelectedRows<";
      auto& selected_rows = var.Get<phi::SelectedRows>();
      auto& tensor = selected_rows.value();
      auto& rows = selected_rows.rows();
      if (tensor.IsInitialized()) {
        ss << framework::DataTypeToString(
                  framework::TransToProtoVarType(tensor.dtype()))
           << ", ";
        ss << tensor.place() << ", ";
        ss << "height(" << selected_rows.height() << "), rows(";
        std::for_each(rows.cbegin(), rows.cend(), [&ss](const int64_t r) {
          ss << r << " ";
        });
        ss << "), dims(" << tensor.dims() << ")";
      } else {
        ss << "NOT_INITED";
      }
      ss << ">";
    } else {
      ss << "UNRESOLVED_TYPE";
    }
    ss << "]";
  }

  ss << "}";
  return ss.str();
}

template <typename VarType>
static std::string LayerDebugStringImpl(const std::string& op_type,
                                        const NameVarMap<VarType>& ins,
                                        const NameVarMap<VarType>& outs) {
  std::stringstream ss;
  ss << "Op(" << op_type << "): ";

  ss << "Inputs: ";

  size_t i = 0;
  for (auto& pair : ins) {
    if (i > 0) ss << ", ";
    ss << DebugString<VarType>(pair.first, pair.second);
    ++i;
  }

  ss << ",   Outputs: ";
  i = 0;
  for (auto& pair : outs) {
    if (i > 0) ss << ", ";
    ss << DebugString<VarType>(pair.first, pair.second);
    ++i;
  }
  return ss.str();
}

std::string LayerDebugString(const std::string& op_type,
                             const NameVarMap<VarBase>& ins,
                             const NameVarMap<VarBase>& outs) {
  return LayerDebugStringImpl<VarBase>(op_type, ins, outs);
}

std::string LayerDebugString(const std::string& op_type,
                             const NameVarMap<VariableWrapper>& ins,
                             const NameVarMap<VariableWrapper>& outs) {
  return LayerDebugStringImpl<VariableWrapper>(op_type, ins, outs);
}

std::string LayerDebugString(const std::string& op_type,
                             const NameVarMap<egr::EagerVariable>& ins,
                             const NameVarMap<egr::EagerVariable>& outs) {
  return LayerDebugStringImpl<egr::EagerVariable>(op_type, ins, outs);
}

template <typename VarType>
static void SetForwardDataTypeOfGradVars(const NameVarMap<VarType>& outs) {
  for (auto& var_pair : outs) {
    for (auto& var : var_pair.second) {
      // NOTE(zhiqu): The output may be NULL because of pruning.
      if (var) {
        SetForwardDataTypeOfGradVar(var);
      }
    }
  }
}
template <>
void SetForwardDataTypeOfGradVars<egr::EagerVariable>(
    const NameVarMap<egr::EagerVariable>& outs) {
  // In eager mode we don't need this.
}

void TestSetForwardDataTypeOfGradVarsEager(
    const NameVarMap<egr::EagerVariable>& outs) {
  SetForwardDataTypeOfGradVars<egr::EagerVariable>(outs);
}

VarBase::VarBase(const std::shared_ptr<VariableWrapper>& var)
    : var_(var), grad_node_(var->GetGradNode()) {
  if (auto grad_var = var_->GetGradVar()) {
    grad_var_ = std::make_shared<VarBase>(grad_var);
  }

  if (IsDebugEnabled()) {
    VLOG(10) << "Construct VarBase: " << Name();
    name_set_.Insert(Name());
  }
}

size_t VarBase::GradOpNum() const {
  return grad_node_ ? grad_node_->size() : 0;
}

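// Reset the gradient held by grad_var_: a dense (LoDTensor) gradient is
// zero-filled when set_to_zero is true and released otherwise, while a
// SelectedRows gradient is always cleared; the shared grad var is then
// marked empty.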
void VarBase::ClearGradient(bool set_to_zero) {
  VLOG(4) << "ClearGradient " << Name();
  if (grad_var_) {
    if (grad_var_->Var().IsType<phi::SelectedRows>()) {
      auto* grad_t = grad_var_->MutableVar()->GetMutable<phi::SelectedRows>();
      if (grad_t->mutable_value()->IsInitialized()) {
#ifdef PADDLE_WITH_MKLDNN
        if (FLAGS_use_mkldnn) platform::ClearMKLDNNCache(grad_t->place());
#endif
        grad_t->mutable_rows()->clear();
        grad_t->mutable_value()->clear();
      }
    } else {
      platform::RecordEvent record_event(
          "ClearGradient", platform::TracerEventType::UserDefined, 2);
      auto* grad_t =
          grad_var_->MutableVar()->GetMutable<framework::LoDTensor>();
      if (grad_t->IsInitialized()) {
        if (set_to_zero) {
          auto* dev_ctx =
              platform::DeviceContextPool::Instance().Get(grad_t->place());
          phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
        } else {
          grad_t->clear();
        }
#ifdef PADDLE_WITH_MKLDNN
        if (FLAGS_use_mkldnn) platform::ClearMKLDNNCache(grad_t->place());
#endif
      }
    }
    // TODO(zhouwei): It would be better to free the gradient's memory with
    // grad_t->clear(), but doing so currently triggers a bug in the yolov3
    // model on macOS CPU (cause unknown). Once that is fixed, SetIsEmpty()
    // will no longer be needed.
    grad_var_->SharedVar()->SetIsEmpty(true);
  }
}

void VarBase::_GradientSetEmpty(bool is_empty) {
  VLOG(4) << "Set gradient " << Name() << " is_empty:" << is_empty;
  if (grad_var_) {
    auto share_var = grad_var_->SharedVar();
    if (share_var) {
      share_var->SetIsEmpty(is_empty);
    }
  }
}

bool VarBase::_IsGradientSetEmpty() {
  bool res = true;
  if (grad_var_) {
    auto share_var = grad_var_->SharedVar();
    if (share_var) {
      res = share_var->is_empty_;
      VLOG(4) << "Check gradient " << Name() << " is empty:" << res;
    }
  }
  return res;
}

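// Make a deep copy of this variable's LoDTensor or SelectedRows on dst_place
// and wrap it in a new VarBase; when blocking is true, wait on the involved
// device contexts so the copy has completed before returning.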
std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
  PADDLE_ENFORCE_EQ(
      Var().IsInitialized() && (Var().IsType<framework::LoDTensor>() ||
                                Var().IsType<phi::SelectedRows>()),
      true,
      platform::errors::InvalidArgument(
          "Variable is not initialized or Variable's type is not "
          "LoDTensor or SelectedRows when getting numpy tensor"));

  if (Var().IsType<framework::LoDTensor>()) {
    auto& src_tensor = Var().Get<framework::LoDTensor>();
    // TODO(Jiabin): change this after moving the unique_name generator to CXX
    auto new_var = std::make_shared<VarBase>(
        true, Name() + std::to_string(copied_counter_++));

    auto* dst_tensor =
        new_var->MutableVar()->GetMutable<framework::LoDTensor>();
    dst_tensor->set_lod(src_tensor.lod());
    new_var->SetPersistable(Persistable());
    new_var->SetDataType(DataType());
    new_var->SetType(Type());
    framework::TensorCopy(src_tensor, dst_place, dst_tensor);
    if (blocking) {
      platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
      auto src_place = src_tensor.place();
      if (!(src_place == dst_place)) {
        platform::DeviceContextPool::Instance().Get(src_place)->Wait();
      }
    }
    VLOG(4) << "copy tensor " << Name() << " from " << Place() << " to "
            << dst_place;
    return new_var;
  } else {
    auto& src_selected_rows = Var().Get<phi::SelectedRows>();
    auto new_var = std::make_shared<VarBase>(
        false, "Itmp" + std::to_string(copied_counter_++));
    new_var->SetType(framework::proto::VarType::SELECTED_ROWS);
    auto* dst_selected_rows =
        new_var->MutableVar()->GetMutable<phi::SelectedRows>();

    framework::TensorCopy(src_selected_rows.value(),
                          dst_place,
                          dst_selected_rows->mutable_value());
    if (blocking) {
      platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
      auto src_place = src_selected_rows.place();
      if (!(src_place == dst_place)) {
        platform::DeviceContextPool::Instance().Get(src_place)->Wait();
      }
    }
    dst_selected_rows->set_height(src_selected_rows.height());
    dst_selected_rows->set_rows(src_selected_rows.rows());
    VLOG(4) << "copy tensor " << Name() << " from " << Place() << " to "
            << dst_place;
    return new_var;
  }
}

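// Deep-copy the contents of src into this VarBase. If this variable is
// already initialized, its data type and variable type must match src; an
// already-initialized destination tensor must also match src's dims (and
// lod). Otherwise the meta information is adopted from src.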
void VarBase::CopyFrom(const VarBase& src, const bool blocking) {
  if (src.SharedVar()->IsEmpty()) {
    return;
  }

  VLOG(3) << "Deep copy Tensor from " << src.Name() << " to " << Name();
  if (Var().IsInitialized()) {
    PADDLE_ENFORCE_EQ(DataType(),
                      src.DataType(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different data type with Tensor %s, "
                          "Tensor Copy cannot be performed!",
                          Name(),
                          src.Name()));
    PADDLE_ENFORCE_EQ(Type(),
                      src.Type(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different type with Tensor %s, Tensor "
                          "Copy cannot be performed!",
                          Name(),
                          src.Name()));
  } else {
    SetDataType(src.DataType());
    SetType(src.Type());
    SetPersistable(src.Persistable());
    InnerSetOverridedStopGradient(src.OverridedStopGradient());
  }

  platform::Place place = src.Place();
  if (src.Var().IsType<framework::LoDTensor>()) {
    auto& src_tensor = src.Var().Get<framework::LoDTensor>();
    auto* dst_tensor = MutableVar()->GetMutable<framework::LoDTensor>();
    if (dst_tensor && dst_tensor->IsInitialized()) {
      PADDLE_ENFORCE_EQ(dst_tensor->dims(),
                        src_tensor.dims(),
                        platform::errors::PreconditionNotMet(
                            "Tensor %s has different dims with Tensor %s, "
                            "Tensor Copy cannot be performed!",
                            Name(),
                            src.Name()));
      PADDLE_ENFORCE_EQ(dst_tensor->lod(),
                        src_tensor.lod(),
                        platform::errors::PreconditionNotMet(
                            "Tensor %s has different lod with Tensor %s, "
                            "Tensor Copy cannot be performed!",
                            Name(),
                            src.Name()));
      place = Place();
    } else {
      dst_tensor->set_lod(src_tensor.lod());
      dst_tensor->Resize(src_tensor.dims());
    }
    framework::TensorCopy(src_tensor, place, dst_tensor);
  } else if (src.Var().IsType<phi::SelectedRows>()) {
    auto& src_selected_rows = src.Var().Get<phi::SelectedRows>();
    auto* dst_selected_rows = MutableVar()->GetMutable<phi::SelectedRows>();
    dst_selected_rows->set_height(src_selected_rows.height());
    dst_selected_rows->set_rows(src_selected_rows.rows());

    auto& src_tensor = src_selected_rows.value();
    auto* dst_tensor = dst_selected_rows->mutable_value();
    if (dst_tensor && dst_tensor->IsInitialized()) {
      PADDLE_ENFORCE_EQ(dst_tensor->dims(),
                        src_tensor.dims(),
                        platform::errors::PreconditionNotMet(
                            "Tensor %s has different dims with Tensor %s, "
                            "Tensor Copy cannot be performed!",
                            Name(),
                            src.Name()));
      place = Place();
    } else {
      dst_tensor->Resize(src_tensor.dims());
    }
    framework::TensorCopy(src_tensor, place, dst_tensor);
  }
  if (blocking) {
    platform::DeviceContextPool::Instance().Get(place)->Wait();
  }
}

void VarBase::BumpInplaceVersion() {
  PADDLE_ENFORCE_EQ(
      Var().IsInitialized(),
      true,
      platform::errors::InvalidArgument(
          "Tensor %s has not been initialized, please check if it has no data.",
          Name()));
  MutableVar()->BumpInplaceVersion();
}

// NOTE(weilong wu):
// This function tries to copy the data from the target varbase
// and fill it into the grad_var_ of the current varbase.
void VarBase::_CopyGradientFrom(const VarBase& src) {
  if (Var().IsInitialized()) {
    PADDLE_ENFORCE_EQ(DataType(),
                      src.DataType(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different data type with Tensor %s",
                          Name(),
                          src.Name()));
    PADDLE_ENFORCE_EQ(Type(),
                      src.Type(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different type with Tensor %s, Tensor "
                          "ShareGradientDataWith cannot be performed!",
                          Name(),
                          src.Name()));
  }
  VLOG(4) << " VarBase copy gradient with " << src.Name();
  if (grad_var_) {
    auto& src_tensor = src.Var().Get<framework::LoDTensor>();
    PADDLE_ENFORCE_EQ(src_tensor.IsInitialized(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.Name()));
    auto* grad_t = grad_var_->MutableVar()->GetMutable<framework::LoDTensor>();
    auto* var_ = MutableVar()->GetMutable<framework::LoDTensor>();
    grad_t->ShareDataWith(src_tensor);
    grad_t->Resize(var_->dims());
  }
}

void OpBase::SetType(const std::string& type) {
  op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
}

void OpBase::ClearBackwardTrace() {
  ins_.clear();
  outs_.clear();
}

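// Shared implementation of OpBase::Run: run var-type inference, initialize
// the output variables, prepare the kernel, execute it on the (possibly
// transformed) inputs, and record the forward data type on the outputs'
// grad vars.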
template <typename VarType>
static void OpBaseRunImpl(const framework::OperatorBase& op,
                          const NameVarMap<VarType>& ins,
                          const NameVarMap<VarType>& outs,
                          const framework::AttributeMap& attrs,
                          const framework::AttributeMap& default_attrs,
                          const platform::Place& place) {
  auto* op_kernel = static_cast<const framework::OperatorWithKernel*>(&op);
  PADDLE_ENFORCE_NOT_NULL(
      op_kernel,
      platform::errors::PermissionDenied(
          "Only support operator with kernel in Dygraph mode."));
  auto& info = op.Info();
  if (info.infer_var_type_) {
    RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(
        ins, outs, attrs, default_attrs);
    info.infer_var_type_(&infer_var_type_ctx);
  }

  // Initialize output var type
  for (auto& var_pair : outs) {
    for (auto& var : var_pair.second) {
      if (var) {
        InitializeVariable(var->MutableVar(), GetType(var));
      }
    }
  }

  VLOG(5) << LayerDebugString(op.Type(), ins, outs);

  /**
   * [ Why are temporary inputs needed here? ]
   *
   * PrepareData should not change the original input tensor in place.
   * Suppose the user defines an int tensor and runs an op on it, and
   * that op overrides GetExpectedKernelForVar and converts the tensor
   * to float during execution. After the dynamic graph has executed,
   * the user-defined variable would be lost: the user could no longer
   * get the originally defined int tensor because it had been converted
   * to float, which should be regarded as a bug in certain usage
   * scenarios.
   *
   * In static graph mode, when an op is executed, a temporary scope
   * `transfer_scope` is created before PrepareData; the transformed
   * data is stored in that temporary scope and discarded after the op
   * finishes. In the previous dynamic graph implementation, however,
   * the original input was overwritten directly.
   */
  auto prepared_op =
      PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
  auto tmp_ins_ptr =
      PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type());
  if (tmp_ins_ptr == nullptr) {
    prepared_op.Run(ins, outs, attrs, default_attrs);
  } else {
    prepared_op.Run(*tmp_ins_ptr, outs, attrs, default_attrs);
  }

  VLOG(4) << LayerDebugString(op.Type(), ins, outs);

  // Record the forward data type on the outputs' grad vars.
  SetForwardDataTypeOfGradVars<VarType>(outs);
}

void OpBase::Run(const framework::OperatorBase& op,
                 const NameVarMap<VarBase>& ins,
                 const NameVarMap<VarBase>& outs,
                 const framework::AttributeMap& attrs,
                 const framework::AttributeMap& default_attrs,
                 const platform::Place& place) {
  OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place);
}

void OpBase::Run(const framework::OperatorBase& op,
                 const NameVarMap<VariableWrapper>& ins,
                 const NameVarMap<VariableWrapper>& outs,
                 const framework::AttributeMap& attrs,
                 const framework::AttributeMap& default_attrs,
                 const platform::Place& place) {
  OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place);
}

void OpBase::Run(const framework::OperatorBase& op,
                 const NameVarMap<egr::EagerVariable>& ins,
                 const NameVarMap<egr::EagerVariable>& outs,
                 const framework::AttributeMap& attrs,
                 const framework::AttributeMap& default_attrs,
                 const platform::Place& place) {
  OpBaseRunImpl<egr::EagerVariable>(op, ins, outs, attrs, default_attrs, place);
}

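// For input slots whose data buffers are not needed by the grad op (as
// reported by NoNeedBufferVarsInferer), replace each variable with a copy
// that keeps only dims/lod/dtype/layout so the original buffer can be freed.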
void ClearNoNeedBufferInputs(OpBase* op) {
  auto& inferer = op->Info().NoNeedBufferVarsInferer();
  if (!inferer) return;
  auto* ins = op->GetMutableInsMap();
  const auto& no_need_buffer_slots =
      inferer(*ins, op->GetOutsMap(), op->Attrs());
  if (no_need_buffer_slots.empty()) return;

  for (auto& slot : no_need_buffer_slots) {
    auto iter = ins->find(slot);
    if (iter == ins->end()) continue;
    VLOG(2) << "Clear data buffer of " << slot << " in " << op->Type();

    PADDLE_ENFORCE_EQ(
        iter->second.IsGrad(),
        false,
        platform::errors::InvalidArgument(
            "Only forward variable buffers can be cleared, this may be a bug"));

    for (auto& each_var : *(iter->second.MutableVarList())) {
      if (!each_var) continue;

      auto& var = each_var->Var();
      PADDLE_ENFORCE_EQ(var.IsType<framework::LoDTensor>(),
                        true,
                        platform::errors::PermissionDenied(
                            "NoNeedBufferVars only support LoDTensor"));
      auto new_var = new VariableWrapper(each_var->Name());
      auto* new_tensor =
          new_var->MutableVar()->GetMutable<framework::LoDTensor>();
      auto& old_tensor = var.Get<framework::LoDTensor>();
      new_tensor->Resize(old_tensor.dims());
      new_tensor->set_lod(old_tensor.lod());
      new_tensor->set_type(old_tensor.dtype());
      new_tensor->set_layout(old_tensor.layout());
      each_var.reset(new_var);
    }
  }
}

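// Build the grad op node for `op` via its registered dygraph grad op maker.
// Returns nullptr when the op has no grad op maker or the produced node is
// empty; otherwise each grad op gets a unique id, the execution place, and
// its no-need-buffer inputs cleared.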
std::shared_ptr<GradOpNode> CreateGradOpNode(
    const framework::OperatorBase& op,
    const NameVarBaseMap& ins,
    const NameVarBaseMap& outs,
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs,
    const platform::Place& place,
    const std::map<std::string, std::string>& inplace_map) {
  const auto& info = op.Info();
  if (!info.dygraph_grad_op_maker_) {
    return nullptr;
  }

  auto grad_node = info.dygraph_grad_op_maker_(
      op.Type(), ins, outs, attrs, default_attrs, inplace_map);
  if (grad_node && !grad_node->empty()) {
    for (auto& grad_op : *grad_node) {
      grad_op.SetId(OpBase::GenerateUniqueId());
      grad_op.SetPlace(place);
      ClearNoNeedBufferInputs(&grad_op);
    }
    return grad_node;
  } else {
    return nullptr;
  }
}

std::shared_ptr<GradOpNode> CreateGradOpNode(
    const framework::OperatorBase& op,
    const NameTensorMap& ins,
    const NameTensorMap& outs,
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs,
    const platform::Place& place,
    const std::map<std::string, std::string>& inplace_map) {
  // Do Nothing in Eager Mode.
  return nullptr;
}

}  // namespace imperative
}  // namespace paddle