// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <list>
#include <map>
#include <memory>
#include <mutex>  // NOLINT
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/flags.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace imperative {

class OpBase;

class ThreadSafeNameSet {
 public:
  void Insert(const std::string& name);

  void Remove(const std::string& name);

  std::vector<std::string> Names() const;

 private:
  std::multiset<std::string> set_;
  mutable std::mutex mtx_;
};
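
// Example (illustrative sketch; the variable name below is hypothetical and
// only the interface declared above is used):
//
//   ThreadSafeNameSet alive_names;
//   alive_names.Insert("conv2d_0.w_0");
//   std::vector<std::string> names = alive_names.Names();
//   alive_names.Remove("conv2d_0.w_0");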

class VarBase {
  DISABLE_COPY_AND_ASSIGN(VarBase);

 public:
  static std::vector<std::string> AliveVarNames();
  explicit VarBase(bool has_grad, const std::string& name)
      : var_(std::make_shared<VariableWrapper>(name)),
        grad_var_(has_grad ? new VarBase(false, GradVarName()) : nullptr) {
    if (IsDebugEnabled()) {
      VLOG(10) << "Construct VarBase: " << Name();
      name_set_.Insert(Name());
    }
  }

  explicit VarBase(const std::string& name) : VarBase(true, name) {}

  ~VarBase() {
    VLOG(10) << "Destruct VarBase: " << Name();
    if (IsDebugEnabled()) {
      name_set_.Remove(Name());
    }
  }

  const std::shared_ptr<VariableWrapper>& SharedVar() const { return var_; }

  const framework::Variable& Var() const { return var_->Var(); }

  framework::Variable* MutableVar() { return var_->MutableVar(); }

  bool HasGradVar() const { return grad_var_ != nullptr; }

  const std::shared_ptr<VarBase>& GradVarBase() const { return grad_var_; }

  void ClearGradVarBase() { grad_var_ = nullptr; }

  const std::shared_ptr<VarBase>& MutableGradVarBase() {
    if (grad_var_ == nullptr) {
      grad_var_ = std::make_shared<VarBase>(false, GradVarName());
      // NOTE(zhiqiu): keep grad_var_'s stop_gradient property the same as
      // that of the forward VarBase
      grad_var_->SetOverridedStopGradient(var_->InnerOverridedStopGradient());
    }
    return grad_var_;
  }

  const framework::Variable& GradVar() const {
    PADDLE_ENFORCE_NOT_NULL(
        grad_var_,
        platform::errors::NotFound("Gradient of %s does not exist", Name()));
    return grad_var_->Var();
  }

  framework::Variable* MutableGradVar() {
    PADDLE_ENFORCE_NOT_NULL(
        grad_var_,
        platform::errors::NotFound("Gradient of %s does not exist", Name()));
    return grad_var_->MutableVar();
  }

  void SetOverridedStopGradient(bool stop_gradient) {
    var_->SetOverridedStopGradient(stop_gradient);
    if (grad_var_) {
      grad_var_->SetOverridedStopGradient(stop_gradient);
    }
  }

  bool OverridedStopGradient() const { return var_->OverridedStopGradient(); }

  void InnerSetOverridedStopGradient(bool stop_gradient) {
    if (var_->InnerOverridedStopGradient() == -1) {
      var_->InnerSetOverridedStopGradient(stop_gradient);
      if (grad_var_) {
        grad_var_->InnerSetOverridedStopGradient(stop_gradient);
      }
    }
  }

  void SetPersistable(bool persistable) { var_->SetPersistable(persistable); }

  bool Persistable() const { return var_->Persistable(); }

  // Only the grad var is allowed to call these two methods
  void AddGradOp(const std::shared_ptr<OpBase>& op) {
    if (op &&
        std::find(grad_ops_.begin(), grad_ops_.end(), op) == grad_ops_.end()) {
      grad_ops_.emplace_back(op);
    }
  }

  const std::vector<std::shared_ptr<OpBase>>& GradOps() const {
    return grad_ops_;
  }

  void ClearGradOps() { grad_ops_.clear(); }

  const std::string& Name() const { return var_->Name(); }

  void SetName(const std::string& name) {
    var_->SetName(name);
    if (grad_var_) {
      grad_var_->SetName(GradVarName());
    }
  }

  std::string GradVarName() { return framework::GradVarName(Name()); }

  void SetType(framework::proto::VarType::Type type) { var_->SetType(type); }

  framework::proto::VarType::Type Type() const { return var_->Type(); }

  void SetDataType(framework::proto::VarType::Type data_type) {
    var_->SetDataType(data_type);
    if (grad_var_) {
      grad_var_->SetDataType(data_type);
    }
  }

  framework::proto::VarType::Type DataType() const { return var_->DataType(); }

  void ClearGradient();

  std::shared_ptr<VarBase> NewVarBase(const platform::Place& dst_place,
                                      const bool blocking) const;

 private:
  /**
   * NOTE(zengjinle): never remove the const qualifier of `var_` if you are
   * not very familiar with the autograd idea (including higher order
   * derivatives).
   */
  const std::shared_ptr<VariableWrapper> var_;

  std::shared_ptr<VarBase> grad_var_;
  std::vector<std::shared_ptr<OpBase>> grad_ops_;

  mutable size_t copied_counter_ = 0;

  static ThreadSafeNameSet name_set_;
};
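
// Example (illustrative sketch; `x` is a hypothetical variable and only the
// methods declared above are used):
//
//   auto x = std::make_shared<VarBase>("x");  // also creates the grad VarBase
//   x->SetOverridedStopGradient(false);       // propagated to x's grad var
//   framework::Variable* raw_var = x->MutableVar();
//   if (x->HasGradVar()) {
//     const std::shared_ptr<VarBase>& gx = x->GradVarBase();  // gradient of x
//   }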

using VariableWrapperList = std::vector<std::shared_ptr<VariableWrapper>>;

class Layer {
 public:
  virtual ~Layer() {}

  virtual std::vector<std::shared_ptr<VarBase>> Forward(
      const std::vector<std::shared_ptr<VarBase>>& inputs) {
    return {};
  }
};
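
// Example (illustrative sketch of a user-defined layer; `IdentityLayer` is a
// hypothetical subclass, not part of this header):
//
//   class IdentityLayer : public Layer {
//    public:
//     std::vector<std::shared_ptr<VarBase>> Forward(
//         const std::vector<std::shared_ptr<VarBase>>& inputs) override {
//       return inputs;  // pass the inputs through unchanged
//     }
//   };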

template <typename VarType>
class DygraphExecutionContext : public framework::ExecutionContext {
  using Variable = framework::Variable;

 public:
  DygraphExecutionContext(const framework::OperatorBase& op,
                          const framework::Scope& scope,
                          const platform::DeviceContext& device_context,
                          const framework::RuntimeContext& ctx,
                          std::vector<framework::KernelConfig>* configs,
                          const NameVarMap<VarType>& var_base_map_in,
                          const NameVarMap<VarType>& var_base_map_out,
                          const framework::AttributeMap& attrs)
      : ExecutionContext(op, scope, device_context, ctx, configs),
        var_base_map_in_(var_base_map_in),
        var_base_map_out_(var_base_map_out),
        attrs_(attrs) {}

  std::string InputName(const std::string& name) const override {
    auto it = var_base_map_in_.find(name);
    PADDLE_ENFORCE_NE(it, var_base_map_in_.end(),
                      platform::errors::PreconditionNotMet(
                          "Can not find [%s] in Input", name));
    return it->second[0]->Name();
  }

  std::vector<std::string> InputNames(const std::string& name) const override {
    auto it = var_base_map_in_.find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_in_.end(),
        platform::errors::NotFound("Can not find [%s] in Input", name));
    std::vector<std::string> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.push_back(it->second[i]->Name());
    }
    return vec_res;
  }

  std::string OutputName(const std::string& name) const override {
    auto it = var_base_map_out_.find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_out_.end(),
        platform::errors::NotFound("Can not find [%s] in Output", name));
    return it->second[0]->Name();
  }

  std::vector<std::string> OutputNames(const std::string& name) const override {
    auto it = var_base_map_out_.find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_out_.end(),
        platform::errors::NotFound("Can not find [%s] in Output", name));
    std::vector<std::string> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.push_back(it->second[i]->Name());
    }
    return vec_res;
  }

  bool HasAttr(const std::string& name) const override {
    return attrs_.count(name) != 0;
  }

  const framework::AttributeMap& Attrs() const override { return attrs_; }

  const framework::Attribute& GetAttr(const std::string& name) const override {
    auto it = attrs_.find(name);

    PADDLE_ENFORCE_NE(
        it, attrs_.end(),
        platform::errors::NotFound("can not find [%s] in attrs", name));

    return it->second;
  }

  std::vector<std::string> InNameList() const override {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(var_base_map_in_.size());

    for (auto& v : var_base_map_in_) {
      vec_temp.push_back(v.first);
    }

    return vec_temp;
  }

  bool HasInput(const std::string& name) const override {
    auto it = var_base_map_in_.find(name);
    return (it != var_base_map_in_.end() && it->second.size() > 0);
  }

  bool HasOutput(const std::string& name) const override {
    auto it = var_base_map_out_.find(name);
    return (it != var_base_map_out_.end() && it->second.size() > 0);
  }

  size_t InputSize(const std::string& name) const override {
    return InputNames(name).size();
  }

  size_t OutputSize(const std::string& name) const override {
    return OutputNames(name).size();
  }

  const Variable* InputVar(const std::string& name) const override {
    auto it = var_base_map_in_.find(name);
    if (it == var_base_map_in_.end()) {
      return nullptr;
    }

    return it->second.empty() ? nullptr : it->second[0]->MutableVar();
  }

  Variable* OutputVar(const std::string& name) const override {
    auto it = var_base_map_out_.find(name);
    if (it == var_base_map_out_.end()) {
      return nullptr;
    }

    return it->second.empty() ? nullptr : it->second[0]->MutableVar();
  }

  const std::vector<Variable*> MultiInputVar(
      const std::string& name) const override {
    auto it = var_base_map_in_.find(name);
    if (it == var_base_map_in_.end()) {
      return {};
    }
    std::vector<Variable*> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.push_back(it->second[i]->MutableVar());
    }

    return vec_res;
  }

  std::vector<Variable*> MultiOutputVar(
      const std::string& name) const override {
    auto it = var_base_map_out_.find(name);
    if (it == var_base_map_out_.end()) {
      return {};
    }
    std::vector<Variable*> vec_res;
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.push_back(it->second[i]->MutableVar());
    }

    return vec_res;
  }

 private:
  const NameVarMap<VarType>& var_base_map_in_;
  const NameVarMap<VarType>& var_base_map_out_;
  const framework::AttributeMap& attrs_;
};
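
// Example (illustrative sketch; `op`, `scope`, `dev_ctx`, `runtime_ctx`,
// `kernel_configs`, `ins`, `outs` and `attrs` are assumed to be prepared by
// the caller, typically the dygraph tracer):
//
//   DygraphExecutionContext<VarBase> exec_ctx(op, scope, dev_ctx, runtime_ctx,
//                                             &kernel_configs, ins, outs,
//                                             attrs);
//   if (exec_ctx.HasInput("X")) {
//     auto x_names = exec_ctx.InputNames("X");  // names of the "X" inputs
//   }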

// infer var type context for imperative mode
template <typename VarType>
class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
 public:
  RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
                             const NameVarMap<VarType>* outputs,
                             const framework::AttributeMap& attrs_map)
      : InferVarTypeContext(nullptr, nullptr),
        inputs_(inputs),
        outputs_(outputs),
        attrs_(attrs_map),
        input_names_(),
        output_names_(),
        var_set_() {
    input_names_.reserve(inputs_.size());
    for (auto& it : inputs_) {
      for (auto& var : it.second) {
        input_names_[it.first].emplace_back(var->Name());
        var_set_[var->Name()] = var.get();
      }
    }

    output_names_.reserve(outputs_->size());
    for (auto& it : *outputs_) {
      for (auto& var : it.second) {
        output_names_[it.first].emplace_back(var->Name());
        var_set_[var->Name()] = var.get();
      }
    }
  }

  virtual ~RuntimeInferVarTypeContext() {}

  framework::Attribute GetAttr(const std::string& name) const override {
    auto iter = attrs_.find(name);
    PADDLE_ENFORCE_EQ(iter != attrs_.end(), true, "Cannot find attribute %s",
                      name);
    return iter->second;
  }

  bool HasVar(const std::string& name) const override {
    return var_set_.count(name) > 0;
  }

  bool HasInput(const std::string& name) const override {
    auto it = inputs_.find(name);
    return (it != inputs_.end() && it->second.size() > 0);
  }

  bool HasOutput(const std::string& name) const override {
    PADDLE_ENFORCE_NOT_NULL(outputs_);
    auto it = outputs_->find(name);
    return (it != outputs_->end() && it->second.size() > 0);
  }

  const std::vector<std::string>& Input(
      const std::string& name) const override {
    auto iter = input_names_.find(name);
    PADDLE_ENFORCE_EQ(iter != input_names_.end(), true, "Cannot find input %s",
                      name);
    return iter->second;
  }

  const std::vector<std::string>& Output(
      const std::string& name) const override {
    auto iter = output_names_.find(name);

    PADDLE_ENFORCE_EQ(iter != output_names_.end(), true,
                      "Cannot find output %s", name);
    return iter->second;
  }

  framework::proto::VarType::Type GetType(
      const std::string& name) const override {
    auto iter = var_set_.find(name);

    PADDLE_ENFORCE_EQ(iter != var_set_.end(), true,
                      "Cannot find var %s in GetType", name);
    return iter->second->Type();
  }

  void SetType(const std::string& name,
               framework::proto::VarType::Type type) override {
    if (name == "kLookupTablePath") {
      VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
    } else {
      var_set_[name]->SetType(type);
      if ((var_set_[name]->MutableVar()->IsInitialized() == true) &&
          (var_set_[name]->MutableVar()->Type() != type)) {
        var_set_[name]->MutableVar()->Clear();
      }
    }
  }

  framework::proto::VarType::Type GetDataType(
      const std::string& name) const override {
    auto iter = var_set_.find(name);

    PADDLE_ENFORCE_EQ(iter != var_set_.end(), true,
                      "Cannot find var %s in GetDataType", name);
    return iter->second->DataType();
  }

  void SetDataType(const std::string& name,
                   framework::proto::VarType::Type type) override {
    var_set_[name]->SetDataType(type);
  }

  std::vector<framework::proto::VarType::Type> GetDataTypes(
      const std::string& name) const override {
    PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType");
  }

  void SetDataTypes(const std::string& name,
                    const std::vector<framework::proto::VarType::Type>&
                        multiple_data_type) override {
    PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType");
  }

  std::vector<int64_t> GetShape(const std::string& name) const override {
    PADDLE_THROW("Do not handle Shape in runtime InferVarType");
  }

  void SetShape(const std::string& name,
                const std::vector<int64_t>& dims) override {
    PADDLE_THROW("Do not handle Shape in runtime InferVarType");
  }

  int32_t GetLoDLevel(const std::string& name) const override {
    PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
  }

  void SetLoDLevel(const std::string& name, int32_t lod_level) override {
    PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
  }

 private:
  const NameVarMap<VarType>& inputs_;
  const NameVarMap<VarType>* outputs_;
  const framework::AttributeMap& attrs_;
  std::unordered_map<std::string, std::vector<std::string>> input_names_;
  std::unordered_map<std::string, std::vector<std::string>> output_names_;
  std::unordered_map<std::string, VarType*> var_set_;
};
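
// Example (illustrative sketch; `ins`, `outs` and `attrs` are assumed to be
// NameVarMap/AttributeMap instances prepared by the caller):
//
//   RuntimeInferVarTypeContext<VarBase> infer_ctx(ins, &outs, attrs);
//   if (infer_ctx.HasInput("X") && infer_ctx.HasOutput("Out")) {
//     auto in_type = infer_ctx.GetType(infer_ctx.Input("X")[0]);
//     infer_ctx.SetType(infer_ctx.Output("Out")[0], in_type);
//   }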

// TODO(zjl): to support py_func layer
class OpBase {
  DISABLE_COPY_AND_ASSIGN(OpBase);

 public:
  OpBase() = default;

  ~OpBase() { VLOG(3) << "Destruct Op: " << Type(); }

  size_t id() const { return id_; }

  const std::string& Type() const { return op_->Type(); }

  const framework::AttributeMap& Attrs() const { return attrs_; }

  const framework::OpInfo& Info() const { return op_->Info(); }

  const framework::OperatorBase& InnerOp() const { return *op_; }

  void ClearBackwardTrace();

  const std::vector<std::shared_ptr<OpBase>>& GradPendingOps() const {
    return grad_pending_ops_;
  }

  void SetGradPendingOps(std::vector<std::shared_ptr<OpBase>> pending_ops) {
    grad_pending_ops_ = std::move(pending_ops);
  }

  NameVarMap<VariableWrapper>* GetMutableOutsMap() { return &outs_; }

  NameVarMap<VariableWrapper>* GetMutableInsMap() { return &ins_; }

  const NameVarMap<VariableWrapper>& GetInsMap() { return ins_; }

  const NameVarMap<VariableWrapper>& GetOutsMap() { return outs_; }

  const platform::Place& place() const { return place_; }

  // TODO(jiabin): prepare for backward hook
  void RegisterBackwardHooks(const std::function<void()>& func) {
    backward_hooks_.emplace_back(func);
  }

  void InvokeBackwardHooks() {
    for (const auto& func : backward_hooks_) {
      func();
      VLOG(5) << "Invoke Backward Hook for: " << Type() << std::endl;
    }
  }

  void SetType(const std::string& type);

  void CheckAttrs() {
    auto& info = op_->Info();
    if (info.Checker() != nullptr) {
      info.Checker()->Check(&attrs_, true);
    }
  }

  void SetInput(const std::string& name, VariableWrapperList vars) {
    ins_[name] = std::move(vars);
  }

  void SetOutput(const std::string& name, VariableWrapperList vars) {
    outs_[name] = std::move(vars);
  }

  void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; }

  void SetAttr(const std::string& name, const framework::Attribute& v) {
    attrs_[name] = v;
  }

  void SetBlockAttr(const std::string& name, framework::BlockDesc* block) {
    PADDLE_THROW("SetBlockAttr is not supported in dygraph OpBase");
  }

  const framework::AttributeMap& Attrs() { return attrs_; }

  void SetId(size_t id) { id_ = id; }

  void SetPlace(const platform::Place& place) { place_ = place; }

  bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; }

  const framework::Attribute& GetAttr(const std::string& name) const {
    auto it = attrs_.find(name);
    PADDLE_ENFORCE(it != attrs_.end(), "can not find attribute [%s]", name);

    return it->second;
  }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return boost::get<T>(GetAttr(name));
  }

  void AddAllowedEmptyVar(const VariableWrapper* var) {
    allow_empty_vars_.emplace(var);
  }

  bool IsAllowedEmptyVar(const VariableWrapper* var) {
    return allow_empty_vars_.count(var) > 0;
  }

  static void Run(const framework::OperatorBase& op,
                  const NameVarMap<VarBase>& ins,
                  const NameVarMap<VarBase>& outs,
                  const framework::AttributeMap& attrs,
                  const platform::Place& place);

  static void Run(const framework::OperatorBase& op,
                  const NameVarMap<VariableWrapper>& ins,
                  const NameVarMap<VariableWrapper>& outs,
                  const framework::AttributeMap& attrs,
                  const platform::Place& place);

 private:
  NameVarMap<VariableWrapper> ins_;
  NameVarMap<VariableWrapper> outs_;
  framework::AttributeMap attrs_;
  std::unique_ptr<framework::OperatorBase> op_;

  std::vector<std::shared_ptr<OpBase>> grad_pending_ops_;
  platform::Place place_;

  std::unordered_set<const VariableWrapper*> allow_empty_vars_;

  size_t id_{-1UL};

  std::vector<std::function<void()>> backward_hooks_;
};
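
// Example (illustrative sketch; `x_vars`, `out_vars` and `place` are
// hypothetical objects prepared by the caller, using only the setters
// declared above):
//
//   auto op = std::make_shared<OpBase>();
//   op->SetType("relu");
//   op->SetInput("X", x_vars);       // VariableWrapperList
//   op->SetOutput("Out", out_vars);  // VariableWrapperList
//   op->SetAttrMap({});
//   op->SetPlace(place);
//   op->CheckAttrs();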

template <typename VarType>
class DygraphInferShapeContext : public framework::InferShapeContext {
  using DDim = framework::DDim;

 public:
  DygraphInferShapeContext(const NameVarMap<VarType>* in,
                           const NameVarMap<VarType>* out,
                           const framework::AttributeMap* attr)
      : var_base_map_in_(in), var_base_map_out_(out), attrs_(attr) {}

  bool HasInput(const std::string& name) const override {
    // has only one input
    auto it = var_base_map_in_->find(name);

    if (it == var_base_map_in_->end()) {
      return false;
    }
    const auto& in = it->second;
    if (in.size() == 0) return false;
    PADDLE_ENFORCE_EQ(
        in.size(), 1UL,
        platform::errors::PreconditionNotMet(
            "Input %s should not have more than one inputs", name));
    return in[0] != nullptr;
  }

  bool HasOutput(const std::string& name) const override {
    // has only one output
    auto it = var_base_map_out_->find(name);
    if (it == var_base_map_out_->end()) {
      return false;
    }
    const auto& out = it->second;
    if (out.size() == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(
        out.size(), 1UL,
        platform::errors::PreconditionNotMet(
            "Output %s should not have more than one outputs", name));
    return out[0] != nullptr;
  }

  bool HasInputs(const std::string& name) const override {
    auto it = var_base_map_in_->find(name);
    if (it == var_base_map_in_->end() || it->second.empty()) {
      return false;
    }
    for (auto& input : it->second) {
      if (input == nullptr) {
        return false;
      }
    }
    return true;
  }

  bool HasOutputs(const std::string& name) const override {
    auto it = var_base_map_out_->find(name);
    if (it == var_base_map_out_->end() || it->second.empty()) {
      return false;
    }
    for (auto& output : it->second) {
      if (output == nullptr) {
        return false;
      }
    }
    return true;
  }

  framework::AttrReader Attrs() const override {
    return framework::AttrReader(*attrs_);
  }

  std::vector<std::string> Inputs(const std::string& name) const override {
    // return op_.Inputs(name);
    std::vector<std::string> vec_res;
    auto it = var_base_map_in_->find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_in_->end(),
        platform::errors::NotFound("can not find [%s] in input", name));

    vec_res.reserve(it->second.size());
    for (auto& var : it->second) {
      vec_res.push_back(var->Name());
    }

    return vec_res;
  }

  std::vector<std::string> Outputs(const std::string& name) const override {
    std::vector<std::string> vec_res;
    auto it = var_base_map_out_->find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_out_->end(),
        platform::errors::NotFound("can not find [%s] in output", name));

    vec_res.reserve(it->second.size());
    for (auto& var : it->second) {
      vec_res.push_back(var->Name());
    }

    return vec_res;
  }

  void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) override {
    auto in_it = var_base_map_in_->find(in);
    auto out_it = var_base_map_out_->find(out);
    PADDLE_ENFORCE_NE(
        in_it, var_base_map_in_->end(),
        platform::errors::NotFound("can not found [%s] in input", in));
    PADDLE_ENFORCE_GT(in_it->second.size(), i,
                      platform::errors::PreconditionNotMet(
                          "Inputs %s should have %llu argument", in, i));
    PADDLE_ENFORCE_NE(
        out_it, var_base_map_out_->end(),
        platform::errors::NotFound("can not found [%s] in input", in));
    PADDLE_ENFORCE_GT(out_it->second.size(), j,
                      platform::errors::PreconditionNotMet(
                          "Outputs %s should have %llu argument", out, j));

    framework::Variable* in_var = in_it->second[i]->MutableVar();
    framework::Variable* out_var = out_it->second[j]->MutableVar();

    PADDLE_ENFORCE_EQ(in_var->Type(), out_var->Type(),
                      platform::errors::PreconditionNotMet(
                          "The type of %s and %s is not the same.", in, out));

    if (in_var->IsType<framework::LoDTensor>()) {
      auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
      auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
      out_lod_tensor->Resize(in_lod_tensor.dims());
    } else {
      auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
      auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
      out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
      out_sele_rows->set_rows(in_sele_rows.rows());
      out_sele_rows->set_height(in_sele_rows.height());
    }
  }

  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override {
    // do nothing
  }
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const override {
    // do nothing
  }

  bool IsRuntime() const override { return true; }

  // TODO(paddle-dev): Can this be template?
  std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
      const std::string& name) override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "GetInputVarPtrs not support in dygraph runtime context"));
  }

  std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs(
      const std::string& name) override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "GetOutputVarPtrs not support in dygraph runtime context"));
  }

  DDim GetInputDim(const std::string& name) const override {
    auto it = var_base_map_in_->find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_in_->end(),
        platform::errors::NotFound("can not find [%s] in input", name));
    PADDLE_ENFORCE_EQ(
        it->second.size(), 1UL,
        platform::errors::PreconditionNotMet(
            "Input(%s) should hold one element, but now it holds %d", name,
            it->second.size()));
    return this->GetDim(it->second[0]->MutableVar());
  }

  std::vector<DDim> GetInputsDim(const std::string& name) const override {
    // const std::vector<Variable*>& vars = InputVars(name);
    std::vector<DDim> vec_res;
    auto it = var_base_map_in_->find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_in_->end(),
        platform::errors::NotFound("can not find [%s] in output", name));
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.emplace_back(GetDim(it->second[i]->MutableVar()));
    }

    return vec_res;
  }

  std::vector<framework::proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override {
    std::vector<framework::proto::VarType::Type> vec_res;
    auto it = var_base_map_in_->find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_in_->end(),
        platform::errors::NotFound("can not find [%s] in input", name));
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.emplace_back(
          framework::ToVarType(it->second[i]->MutableVar()->Type()));
    }
    return vec_res;
  }

  std::vector<framework::proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override {
    std::vector<framework::proto::VarType::Type> vec_res;
    auto it = var_base_map_out_->find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_out_->end(),
        platform::errors::NotFound("can not find [%s] in output", name));
    vec_res.reserve(it->second.size());
    for (size_t i = 0; i < it->second.size(); ++i) {
      vec_res.emplace_back(
          framework::ToVarType(it->second[i]->MutableVar()->Type()));
    }
    return vec_res;
  }

  void SetOutputDim(const std::string& name, const DDim& dim) override {
    auto it = var_base_map_out_->find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_out_->end(),
        platform::errors::NotFound("can not find [%s] in output", name));

    SetDim(it->second[0]->MutableVar(), dim);
  }

  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override {
    auto it = var_base_map_out_->find(name);
    PADDLE_ENFORCE_NE(
        it, var_base_map_out_->end(),
        platform::errors::NotFound("can not find [%s] in output", name));

    PADDLE_ENFORCE_EQ(it->second.size(), dims.size(),
                      platform::errors::PreconditionNotMet(
                          "dim size [%d] is not match output var number [%d]",
                          dims.size(), it->second.size()));

    for (size_t i = 0; i < dims.size(); ++i) {
      SetDim(it->second[i]->MutableVar(), dims[i]);
    }
  }

  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "GetLoDLevel function not support in dygraph mode"));
  }

  void SetLoDLevel(const std::string& out, int32_t lod_level,
                   size_t j = 0) const override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "SetLoDLevel function not support in dygraph mode"));
  }

 protected:
  DDim GetDim(framework::Variable* var) const {
    PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
                                     "Input variable should not be null"));
    if (var->IsType<framework::LoDTensor>()) {
      return var->Get<framework::LoDTensor>().dims();
    } else if (var->IsType<framework::SelectedRows>()) {
      return var->Get<framework::SelectedRows>().GetCompleteDims();
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Only LoDTensor/SelectedRows support 'GetDim', but Variables "
          "type_id is xx."));
    }
  }

  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "GetRepeatedDims not support in dygraph runtime"));
  }

  void SetDim(framework::Variable* var, const DDim& dim) {
    if (var->IsType<framework::LoDTensor>()) {
      var->GetMutable<framework::LoDTensor>()->Resize(dim);
    } else if (var->IsType<framework::SelectedRows>()) {
      var->GetMutable<framework::SelectedRows>()->set_height(dim[0]);
    } else {
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Variable type_id %s, expect LoDTensor/SelectedRows."));
    }
  }

  void SetDims(const std::vector<framework::Variable*>& vars,
               const std::vector<DDim>& dims) {
    size_t length = vars.size();
    PADDLE_ENFORCE_EQ(
        length, dims.size(),
        platform::errors::PreconditionNotMet(
            "Vars number [%d] should be equal with dims number [%d]", length,
            dims.size()));
    for (size_t i = 0; i < length; ++i) {
      if (vars[i] == nullptr) {
        continue;
      }
      SetDim(vars[i], dims[i]);
    }
  }

  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "SetRepeatedDims not support in dygraph runtime"));
  }

 private:
  const NameVarMap<VarType>* var_base_map_in_;
  const NameVarMap<VarType>* var_base_map_out_;
  const framework::AttributeMap* attrs_;
};
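
// Example (illustrative sketch; `ins`, `outs` and `attrs` are assumed to be
// prepared by the caller before running an operator's shape inference):
//
//   DygraphInferShapeContext<VarBase> infer_shape_ctx(&ins, &outs, &attrs);
//   if (infer_shape_ctx.HasInput("X")) {
//     auto x_dims = infer_shape_ctx.GetInputDim("X");
//     infer_shape_ctx.SetOutputDim("Out", x_dims);
//   }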

}  // namespace imperative
}  // namespace paddle