// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace imperative {

// Discriminates whether a traced variable list carries forward variables
// (kForward) or their gradient counterparts (kBackward).
enum TracedVarRole { kForward = 0, kBackward = 1 };

// A vector of shared VarBase pointers whose static type additionally
// encodes the role of the variables it holds, so forward and backward
// lists cannot be mixed up at compile time.
template <typename T, TracedVarRole kRole>
class TracedVarList : public std::vector<std::shared_ptr<T>> {
 private:
  using VectorBase = std::vector<std::shared_ptr<T>>;

 public:
  // Expose every constructor of the underlying vector.
  using VectorBase::VectorBase;
};

// Base class for per-operator gradient-op makers in dygraph (imperative)
// mode. It keeps *references* to the forward op's type string, input/output
// variable maps, attribute map and inplace map, and offers helpers that map
// a forward slot name to either the forward vars or their grad vars.
// NOTE: only references are stored (see the members at the bottom), so a
// maker must not outlive the objects passed to its constructor.
class GradOpBaseMakerBase {
 public:
  explicit GradOpBaseMakerBase(
      const std::string& type, const NameVarBaseMap& var_base_map_in,
      const NameVarBaseMap& var_base_map_out,
      const framework::AttributeMap& attrs,
      const std::map<std::string, std::string>& inplace_map)
      : type_(type),
        var_base_map_in_(var_base_map_in),
        var_base_map_out_(var_base_map_out),
        attrs_(attrs),
        inplace_map_(inplace_map) {}

  virtual ~GradOpBaseMakerBase() = default;

  // Builds the grad node for this forward op; implemented per operator.
  virtual std::shared_ptr<GradOpNode> operator()() const = 0;

  // Grad vars of the forward op's input slot `name`.
  // NOTE(review): `drop_empty_grad` is ignored by this implementation —
  // confirm whether subclasses rely on it elsewhere.
  TracedVarList<VarBase, TracedVarRole::kBackward> InputGrad(
      const std::string& name, bool drop_empty_grad = true) const {
    return GetVarBaseList<TracedVarRole::kBackward>(name, /*is_input=*/true);
  }

  // Grad vars of the forward op's output slot `name`.
  TracedVarList<VarBase, TracedVarRole::kBackward> OutputGrad(
      const std::string& name) const {
    return GetVarBaseList<TracedVarRole::kBackward>(name, /*is_input=*/false);
  }

  // Forward vars of the input slot `name`.
  TracedVarList<VarBase, TracedVarRole::kForward> Input(
      const std::string& name) const {
    return GetVarBaseList<TracedVarRole::kForward>(name, /*is_input=*/true);
  }

  // Forward vars of the output slot `name`.
  TracedVarList<VarBase, TracedVarRole::kForward> Output(
      const std::string& name) const {
    return GetVarBaseList<TracedVarRole::kForward>(name, /*is_input=*/false);
  }

  // The Empty*() helpers below return empty lists with the proper
  // role-tagged type, for slots a grad op does not use.
  static TracedVarList<VarBase, TracedVarRole::kForward> EmptyInput() {
    return {};
  }

  static TracedVarList<VarBase, TracedVarRole::kForward> EmptyOutput() {
    return {};
  }

  static TracedVarList<VarBase, TracedVarRole::kBackward> EmptyOutputGrad() {
    return {};
  }

  static TracedVarList<VarBase, TracedVarRole::kBackward> EmptyInputGrad() {
    return {};
  }

  // Names of all input slots of the forward op.
  std::vector<std::string> InputNames() const {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(var_base_map_in_.size());
    for (auto& it : var_base_map_in_) {
      vec_temp.emplace_back(it.first);
    }
    return vec_temp;
  }

  // Names of all output slots of the forward op.
  std::vector<std::string> OutputNames() const {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(var_base_map_out_.size());
    for (auto& it : var_base_map_out_) {
      vec_temp.emplace_back(it.first);
    }
    return vec_temp;
  }

  const framework::AttributeMap& Attrs() const { return attrs_; }

  // Looks up attribute `name`; enforces (throws) if it does not exist.
  const framework::Attribute& GetAttr(const std::string& name) const {
    auto it = attrs_.find(name);
    PADDLE_ENFORCE_EQ(
        it != attrs_.end(), true,
        platform::errors::NotFound(
            "Cannot find attribute [%s] in operator [%s]", name, type_));
    return it->second;
  }

  // Typed attribute access; fails (via BOOST_GET_CONST) on a type mismatch.
  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return BOOST_GET_CONST(T, GetAttr(name));
  }

  const std::string& ForwardOpType() const { return type_; }

 protected:
  bool HasInput(const std::string& name) const {
    return var_base_map_in_.count(name) > 0;
  }

  bool HasOutput(const std::string& name) const {
    return var_base_map_out_.count(name) > 0;
  }

  static std::shared_ptr<GradOpNode> NewGradNode() {
    return std::make_shared<GradOpNode>();
  }

  const std::map<std::string, std::string>& GetInplaceMap() const {
    return inplace_map_;
  }

 private:
  // Collects the vars of slot `name` from the input or output map.
  // For kBackward: lazily creates a grad var when one is missing and, for
  // output slots, resizes the grad tensor to the forward tensor's dims.
  // Null entries are preserved as empty slots; if *every* entry is null
  // the whole list is cleared to an empty list.
  template <TracedVarRole kRole>
  TracedVarList<VarBase, kRole> GetVarBaseList(const std::string& name,
                                               bool is_input) const {
    const auto& data_map = is_input ? var_base_map_in_ : var_base_map_out_;
    auto iterator = data_map.find(name);
    TracedVarList<VarBase, kRole> vec_temp;
    if (iterator != data_map.end()) {
      vec_temp.reserve(iterator->second.size());

      bool is_valid = false;
      for (auto& var_base_temp : iterator->second) {
        if (!var_base_temp) {
          // Keep slot positions aligned with the forward op's slot.
          vec_temp.emplace_back();
          continue;
        }

        if (kRole == TracedVarRole::kBackward) {
          if (!var_base_temp->HasGradVar()) {
            VLOG(6) << "GradVarBase of var " << var_base_temp->Name()
                    << " in OP " << type_ << " is null";
            // Create the grad var on demand.
            var_base_temp->MutableGradVarBase();
          }
          auto grad_var_base_tmp = var_base_temp->GradVarBase();

          if (!is_input) {
            // Output grads take the shape of the forward output tensor.
            auto* tensor = grad_var_base_tmp->MutableVar()
                               ->GetMutable<framework::LoDTensor>();
            tensor->Resize(
                var_base_temp->Var().Get<framework::LoDTensor>().dims());
          }
          vec_temp.emplace_back(grad_var_base_tmp);
        } else {
          vec_temp.emplace_back(var_base_temp);
        }
        is_valid = true;
      }

      if (!is_valid) {
        vec_temp.clear();
      }
    }

    return vec_temp;
  }

 private:
  // References only — lifetimes are owned by the caller of the ctor.
  const std::string& type_;
  const NameVarBaseMap& var_base_map_in_;
  const NameVarBaseMap& var_base_map_out_;
  const framework::AttributeMap& attrs_;
  const std::map<std::string, std::string>& inplace_map_;
};
// RAII helper used while tracing one grad op: the constructor appends a
// fresh OpBase to the given GradOpNode, the Set*/attr methods fill it in,
// and the destructor either removes it again (if no outputs were set) or
// validates its attributes.
class TracedGradOp {
  DISABLE_COPY_AND_ASSIGN(TracedGradOp);

 public:
  explicit TracedGradOp(const std::shared_ptr<GradOpNode>& node)
      : node_(node), op_(&(node->emplace_back())) {}

  ~TracedGradOp() {
    if (UNLIKELY(op_->GetOutsMap().empty())) {
      // Nothing was recorded for this op — drop the OpBase we appended.
      node_->pop_back();
    } else {
      op_->CheckAttrs();
    }
  }

  // Registers `vars` as input slot `name` of the traced grad op.
  // For kBackward vars, each var's grad node is re-pointed to node_; the
  // previous ("dirty") grad node, if any, is stashed so SetOutput can
  // still link to it (inplace-op handling, see map_dirty_grad_node_).
  template <TracedVarRole kRole>
  void SetInput(const std::string& name,
                const TracedVarList<VarBase, kRole>& vars) {
    if (vars.empty()) {
      return;
    }

    if (kRole == TracedVarRole::kBackward) {
      for (auto& var : vars) {
        if (var && !var->OverridedStopGradient()) {
          var->SetGraphIsFreed(false);
          auto dirty_grad_node = var->GradNode();
          if (dirty_grad_node) {
            map_dirty_grad_node_[var] = dirty_grad_node;
          }
          var->SetGradNode(node_);
        }
      }
    }

    auto var_wrappers = ToVarWrapperList<kRole>(vars);

    if (!var_wrappers.empty()) {
      op_->SetInput(name, std::move(var_wrappers),
                    kRole == TracedVarRole::kBackward);
    }
  }

  // Registers `vars` as output slot `name` of the traced grad op.
  // For kBackward vars, the grad node of each var (or its stashed dirty
  // node for inplace vars) is inserted as a pending node of node_.
  template <TracedVarRole kRole>
  void SetOutput(const std::string& name,
                 const TracedVarList<VarBase, kRole>& vars) {
    if (vars.empty()) {
      return;
    }

    if (kRole == TracedVarRole::kBackward) {
      // NOTE(review): vars.front() is dereferenced without a null check
      // here, although list elements may be null elsewhere (see
      // ToVarWrapperList) — confirm callers never pass a single null var.
      if (vars.size() == 1 && vars.front()->OverridedStopGradient()) {
        return;
      } else {
        for (auto& var : vars) {
          if (var && !var->OverridedStopGradient() && var->GradNode()) {
            if (map_dirty_grad_node_.find(var) != map_dirty_grad_node_.end()) {
              node_->InsertGradPendingNode(map_dirty_grad_node_[var]);
            } else {
              node_->InsertGradPendingNode(var->GradNode());
            }
          }
        }
      }
    }

    auto var_wrappers = ToVarWrapperList<kRole>(vars);
    if (!var_wrappers.empty()) {
      op_->SetOutput(name, std::move(var_wrappers),
                     kRole == TracedVarRole::kBackward);
    }
  }

  std::string Type() const { return op_->Type(); }

  void SetType(const std::string& type) { op_->SetType(type); }

  void SetAttrMap(const framework::AttributeMap& attrs) {
    return op_->SetAttrMap(attrs);
  }

  void SetAttr(const std::string& name, const framework::Attribute& v) {
    op_->SetAttr(name, v);
  }

  bool HasAttr(const std::string& name) const { return op_->HasAttr(name); }

  const framework::Attribute& GetAttr(const std::string& name) const {
    return op_->GetAttr(name);
  }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return op_->Attr<T>(name);
  }

 private:
  // Converts VarBase pointers to snapshotted VariableWrapper pointers.
  // Null vars (and, for kBackward, stop-gradient vars) become empty slots;
  // if no element is valid the whole list is cleared.
  template <TracedVarRole kRole>
  static std::vector<std::shared_ptr<VariableWrapper>> ToVarWrapperList(
      const std::vector<std::shared_ptr<VarBase>>& vars) {
    std::vector<std::shared_ptr<VariableWrapper>> result;
    result.reserve(vars.size());
    bool has_valid = false;
    for (auto& var : vars) {
      if (UNLIKELY(!var || (kRole == TracedVarRole::kBackward &&
                            var->OverridedStopGradient()))) {
        result.emplace_back();
      } else {
        auto var_wrapper = SnapshotVarWrapper(var->SharedVar());
        result.emplace_back(var_wrapper);
        has_valid = true;
      }
    }

    if (!has_valid) {
      result.clear();
    }
    return result;
  }

  // Get a snapshot of VariableWrapper at a certain inplace version.
  // The inplace version number of VariableWrapper is used for inplace
  // detection in gradient computation.
  static const std::shared_ptr<VariableWrapper> SnapshotVarWrapper(
      const std::shared_ptr<VariableWrapper>& var_wrapper) {
    // NOTE(liym27):
    //  Use original var_wrapper if its inplace_version is not
    //  changed. Otherwise, it will affect the accuracy of the model
    //  results and affect double grad.
    if (!var_wrapper->MutableVar()->IsInitialized() ||
        var_wrapper->InplaceVersionSnapshot() ==
            var_wrapper->MutableVar()->CurrentInplaceVersion()) {
      return var_wrapper;
    } else {
      // Version changed: copy the wrapper and reset its inplace version.
      VariableWrapper new_var_wrapper = *var_wrapper.get();
      new_var_wrapper.ResetInplaceVersion();
      return std::make_shared<VariableWrapper>(new_var_wrapper);
    }
  }

 private:
  const std::shared_ptr<GradOpNode>& node_;
  // Non-owning: points at the OpBase emplaced into node_ by the ctor.
  OpBase* op_;
  // Inplace op has recursion problems when performing grad calculation.
  // Because the input and output of inplace op are the same, the grad
  // node of inplace var will be overwritten.
  // This map is used to store the grad node of inplace var in temporary.
  std::unordered_map<std::shared_ptr<VarBase>, std::shared_ptr<GradOpNode>>
      map_dirty_grad_node_;
};

}  // namespace imperative
}  // namespace paddle