layer.h 6.8 KB
Newer Older
J
Jiabin Yang 已提交
1
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14 15
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
J
Jiabin Yang 已提交
16
#include <algorithm>
#include <cstdint>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/flags.h"
#include "paddle/fluid/imperative/saved_variable_wrapper_list.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
M
minqiyang 已提交
38

W
wanghuancoder 已提交
39 40 41 42 43 44
namespace paddle {
namespace framework {
// Forward declaration only. NOTE(review): the complete type is presumably
// provided by the paddle/fluid/framework/variable.h include — confirm.
class Variable;
}  // namespace framework
}  // namespace paddle

45 46 47 48
namespace paddle {
namespace imperative {

// Forward declarations; these classes are only named (not defined) here.
class GradOpNode;
class OpBase;
class VariableWrapper;
51

Z
Zeng Jinle 已提交
52 53 54 55 56 57 58 59 60 61 62 63 64
/**
 * A multiset of strings paired with a mutex; VarBase keeps one as a static
 * member to record variable names.
 *
 * Insert/Remove/Names are defined out of line. Per the class name they are
 * expected to lock `mtx_` (NOTE(review): confirm in the definitions).
 * Because storage is a std::multiset, the same name may be held more than
 * once.
 */
class ThreadSafeNameSet {
 public:
  // Adds `name` to the set (defined out of line).
  void Insert(const std::string& name);

  // Removes `name` from the set (defined out of line). NOTE(review):
  // presumably removes a single occurrence — confirm.
  void Remove(const std::string& name);

  // Returns the stored names by value, i.e. a copy the caller owns.
  std::vector<std::string> Names() const;

 private:
  std::multiset<std::string> set_;
  // `mutable` so that the const Names() can still lock it.
  mutable std::mutex mtx_;
};

65
class VarBase {
J
Jiabin Yang 已提交
66 67
  DISABLE_COPY_AND_ASSIGN(VarBase);

68
 public:
Z
Zeng Jinle 已提交
69
  static std::vector<std::string> AliveVarNames();
70 71

 public:
J
Jiabin Yang 已提交
72
  explicit VarBase(bool has_grad, const std::string& name)
73
      : var_(std::make_shared<VariableWrapper>(name)),
J
Jiabin Yang 已提交
74
        grad_var_(has_grad ? new VarBase(false, GradVarName()) : nullptr) {
75 76 77 78
    if (has_grad) {
      var_->SetGradVar(grad_var_->var_);
    }

Z
Zeng Jinle 已提交
79
    if (IsDebugEnabled()) {
80 81
      VLOG(10) << "Construct VarBase: " << Name();
      name_set_.Insert(Name());
Z
Zeng Jinle 已提交
82
    }
83
  }
84

J
Jiabin Yang 已提交
85 86
  explicit VarBase(const std::string& name) : VarBase(true, name) {}

87
  // NOTE(zengjinle): be careful when you use this constructor!!!
88 89
  // Unpack VarBase from VariableWrapper.
  explicit VarBase(const std::shared_ptr<VariableWrapper>& var);
90

J
Jiabin Yang 已提交
91
  ~VarBase() {
92
    VLOG(10) << "Destruct VarBase: " << Name();
Z
Zeng Jinle 已提交
93
    if (IsDebugEnabled()) {
94
      name_set_.Remove(Name());
Z
Zeng Jinle 已提交
95
    }
M
minqiyang 已提交
96
  }
97

98
  const std::shared_ptr<VariableWrapper>& SharedVar() const { return var_; }
99

100 101 102
  const framework::Variable& Var() const { return var_->Var(); }

  framework::Variable* MutableVar() { return var_->MutableVar(); }
M
minqiyang 已提交
103

J
Jiabin Yang 已提交
104 105 106 107
  bool HasGradVar() const { return grad_var_ != nullptr; }

  const std::shared_ptr<VarBase>& GradVarBase() const { return grad_var_; }

108 109 110 111
  void ClearGradVarBase() { grad_var_ = nullptr; }

  const std::shared_ptr<VarBase>& MutableGradVarBase() {
    if (grad_var_ == nullptr) {
112
      if (auto grad_var_wrapper = var_->GetGradVar()) {
113
        grad_var_ = std::make_shared<VarBase>(grad_var_wrapper);
114 115 116 117 118 119 120
      } else {
        grad_var_ = std::make_shared<VarBase>(false, GradVarName());
        var_->SetGradVar(grad_var_->var_);
        grad_var_->var_->SetGradNode(grad_var_->grad_node_);
      }
      // NOTE(zhiqiu): we should keep grad_var_'s stop_gradient property
      // same as fwd varbase
121
      grad_var_->SetOverridedStopGradient(var_->InnerOverridedStopGradient());
122 123 124 125
    }
    return grad_var_;
  }

J
Jiabin Yang 已提交
126
  const framework::Variable& GradVar() const {
127 128 129 130
    PADDLE_ENFORCE_NOT_NULL(
        grad_var_,
        platform::errors::NotFound("Gradient of %s does not exist", Name()));
    return grad_var_->Var();
M
minqiyang 已提交
131
  }
M
minqiyang 已提交
132

J
Jiabin Yang 已提交
133
  framework::Variable* MutableGradVar() {
134 135 136 137
    PADDLE_ENFORCE_NOT_NULL(
        grad_var_,
        platform::errors::NotFound("Gradient of %s does not exist", Name()));
    return grad_var_->MutableVar();
J
Jiabin Yang 已提交
138
  }
X
Xin Pan 已提交
139

140
  void SetOverridedStopGradient(bool stop_gradient) {
141
    var_->SetOverridedStopGradient(stop_gradient);
J
Jiabin Yang 已提交
142
    if (grad_var_) {
143 144 145 146
      grad_var_->SetOverridedStopGradient(stop_gradient);
    }
  }

147
  bool OverridedStopGradient() const { return var_->OverridedStopGradient(); }
148 149

  void InnerSetOverridedStopGradient(bool stop_gradient) {
150 151
    if (var_->InnerOverridedStopGradient() == -1) {
      var_->InnerSetOverridedStopGradient(stop_gradient);
152 153 154 155 156
      if (grad_var_) {
        grad_var_->InnerSetOverridedStopGradient(stop_gradient);
      }
    }
  }
157

158
  void SetPersistable(bool persistable) { var_->SetPersistable(persistable); }
159

160
  bool Persistable() const { return var_->Persistable(); }
X
Xin Pan 已提交
161

162
  // Only grad var is allowed to call these 2 methods
163 164 165
  void SetGradNode(const std::shared_ptr<GradOpNode>& node) {
    grad_node_ = node;
    var_->SetGradNode(node);
166 167
  }

168
  size_t GradOpNum() const;
169

170 171 172
  const std::shared_ptr<GradOpNode>& GradNode() const { return grad_node_; }

  void ClearGradNode() { SetGradNode(nullptr); }
X
Xin Pan 已提交
173

174
  const std::string& Name() const { return var_->Name(); }
M
minqiyang 已提交
175

J
Jiabin Yang 已提交
176
  void SetName(const std::string& name) {
177
    var_->SetName(name);
J
Jiabin Yang 已提交
178 179 180
    if (grad_var_) {
      grad_var_->SetName(GradVarName());
    }
M
minqiyang 已提交
181 182
  }

183
  std::string GradVarName() { return framework::GradVarName(Name()); }
184

185
  void SetType(framework::proto::VarType::Type type) { var_->SetType(type); }
186

187
  framework::proto::VarType::Type Type() const { return var_->Type(); }
188

J
Jiabin Yang 已提交
189
  void SetDataType(framework::proto::VarType::Type data_type) {
190
    var_->SetDataType(data_type);
J
Jiabin Yang 已提交
191
    if (grad_var_) {
192
      grad_var_->SetDataType(data_type);
193 194 195
    }
  }

196
  framework::proto::VarType::Type DataType() const { return var_->DataType(); }
X
polish  
Xin Pan 已提交
197

198 199
  const platform::Place Place() const { return var_->Place(); }

J
Jiabin Yang 已提交
200
  void ClearGradient();
X
Xin Pan 已提交
201

J
Jiabin Yang 已提交
202 203
  std::shared_ptr<VarBase> NewVarBase(const platform::Place& dst_place,
                                      const bool blocking) const;
M
minqiyang 已提交
204

J
Jiabin Yang 已提交
205
 private:
206 207 208 209 210 211 212
  /**
   * NOTE(zengjinle): never remove the const qualifier of `var_` if you are
   * not very familiar with the autograd idea (including the higher order
   * derivative).
   */
  const std::shared_ptr<VariableWrapper> var_;

J
Jiabin Yang 已提交
213
  std::shared_ptr<VarBase> grad_var_;
214 215 216 217 218 219

  /**
   * NOTE(zengjinle): should consider whether to implement an inlined vector
   * or other things like that.
   */
  std::shared_ptr<GradOpNode> grad_node_;
H
hong 已提交
220

J
Jiabin Yang 已提交
221
  mutable size_t copied_counter_ = 0;
222

J
Jiabin Yang 已提交
223
  static ThreadSafeNameSet name_set_;
224 225 226 227 228 229
};

class Layer {
 public:
  virtual ~Layer() {}

230 231
  virtual std::vector<std::shared_ptr<VarBase>> Forward(
      const std::vector<std::shared_ptr<VarBase>>& inputs) {
J
Jiabin Yang 已提交
232
    return {};
233
  }
X
Xin Pan 已提交
234
};
235

236 237 238 239
/**
 * Declared here and defined out of line.
 *
 * NOTE(review): presumably builds the GradOpNode for the forward `op` from
 * its inputs `ins`, outputs `outs`, attributes `attrs` and `place` — confirm
 * against the definition (including whether it may return nullptr).
 */
std::shared_ptr<GradOpNode> CreateGradOpNode(
    const framework::OperatorBase& op,
    const NameVarBaseMap& ins,
    const NameVarBaseMap& outs,
    const framework::AttributeMap& attrs,
    const platform::Place& place);
H
hong 已提交
240

241 242
}  // namespace imperative
}  // namespace paddle