// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace imperative {

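// Tags whether a traced variable list holds forward variables or their
// gradient counterparts. TracedVarList carries the role as a compile-time
// tag so that forward and backward lists cannot be mixed up accidentally.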
enum TracedVarRole { kForward = 0, kBackward = 1 };

template <typename T, TracedVarRole kRole>
class TracedVarList : public std::vector<std::shared_ptr<T>> {
 private:
  using BaseClass = std::vector<std::shared_ptr<T>>;

 public:
  using BaseClass::BaseClass;
};

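// Base class of dygraph grad-op makers: given a forward op's type, inputs,
// outputs, attrs and inplace map, a concrete maker builds the GradOpNode
// that computes the backward pass.
//
// A minimal usage sketch (hypothetical "my_relu" op; real makers are
// generated from the operator's GradOpMaker registration, so the op name
// and slot names here are illustrative only):
//
//   class MyReluGradOpMaker : public GradOpBaseMakerBase {
//    public:
//     using GradOpBaseMakerBase::GradOpBaseMakerBase;
//
//     std::shared_ptr<GradOpNode> operator()() const override {
//       auto node = NewGradNode();
//       TracedGradOp traced_op(node);  // defined later in this header
//       traced_op.SetType("my_relu_grad");
//       traced_op.SetInput("Out", Output("Out"));
//       traced_op.SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
//       traced_op.SetOutput(framework::GradVarName("X"), InputGrad("X"));
//       traced_op.SetAttrMap(Attrs());
//       return node;
//     }
//   };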
class GradOpBaseMakerBase {
 public:
  explicit GradOpBaseMakerBase(
      const std::string& type, const NameVarBaseMap& var_base_map_in,
      const NameVarBaseMap& var_base_map_out,
      const framework::AttributeMap& attrs,
      const std::map<std::string, std::string>& inplace_map)
      : type_(type),
        var_base_map_in_(var_base_map_in),
        var_base_map_out_(var_base_map_out),
        attrs_(attrs),
        inplace_map_(inplace_map) {}

  virtual ~GradOpBaseMakerBase() = default;

  virtual std::shared_ptr<GradOpNode> operator()() const = 0;

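  // NOTE: `drop_empty_grad` is currently ignored by this implementation.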
  TracedVarList<VarBase, TracedVarRole::kBackward> InputGrad(
      const std::string& name, bool drop_empty_grad = true) const {
    return GetVarBaseList<TracedVarRole::kBackward>(name, /*is_input=*/true);
  }

  TracedVarList<VarBase, TracedVarRole::kBackward> OutputGrad(
      const std::string& name) const {
    return GetVarBaseList<TracedVarRole::kBackward>(name, /*is_input=*/false);
  }

  TracedVarList<VarBase, TracedVarRole::kForward> Input(
      const std::string& name) const {
    return GetVarBaseList<TracedVarRole::kForward>(name, /*is_input=*/true);
  }

  TracedVarList<VarBase, TracedVarRole::kForward> Output(
      const std::string& name) const {
    return GetVarBaseList<TracedVarRole::kForward>(name, /*is_input=*/false);
  }

  static TracedVarList<VarBase, TracedVarRole::kForward> EmptyInput() {
    return {};
  }

  static TracedVarList<VarBase, TracedVarRole::kForward> EmptyOutput() {
    return {};
  }

  static TracedVarList<VarBase, TracedVarRole::kBackward> EmptyOutputGrad() {
    return {};
  }

  static TracedVarList<VarBase, TracedVarRole::kBackward> EmptyInputGrad() {
    return {};
  }

  std::vector<std::string> InputNames() const {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(var_base_map_in_.size());
    for (auto& it : var_base_map_in_) {
      vec_temp.emplace_back(it.first);
    }
    return vec_temp;
  }

  std::vector<std::string> OutputNames() const {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(var_base_map_out_.size());
    for (auto& it : var_base_map_out_) {
      vec_temp.emplace_back(it.first);
    }
    return vec_temp;
  }

  // Only for dygraph. Note: only a pointer to `default_attrs` is stored, so
  // the caller must keep the referenced map alive for this maker's lifetime.
  void SetDygraphDefaultAttrsMap(const framework::AttributeMap& default_attrs) {
    default_attrs_ = &default_attrs;
  }

  const framework::AttributeMap& DefaultAttrsMap() const {
    return *default_attrs_;
  }

  const framework::AttributeMap& Attrs() const { return attrs_; }

  virtual const framework::Attribute& GetAttr(const std::string& name) const {
    auto it = attrs_.find(name);
    PADDLE_ENFORCE_EQ(
        it != attrs_.end(), true,
        platform::errors::NotFound(
            "Cannot find attribute [%s] in operator [%s]", name, type_));
    return it->second;
  }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return BOOST_GET_CONST(T, GetAttr(name));
  }

  const std::string& ForwardOpType() const { return type_; }

 protected:
  bool HasInput(const std::string& name) const {
    return var_base_map_in_.count(name) > 0;
  }

  bool HasOutput(const std::string& name) const {
    return var_base_map_out_.count(name) > 0;
  }

  static std::shared_ptr<GradOpNode> NewGradNode() {
    return std::make_shared<GradOpNode>();
  }

  const std::map<std::string, std::string>& GetInplaceMap() const {
    return inplace_map_;
  }

 private:
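  // Collects the VarBase list bound to `name` from the forward op's inputs
  // or outputs. For the backward role, the grad VarBase of each variable is
  // returned instead (created on demand if missing), and the grad tensor of
  // a forward output is resized to the forward tensor's dims. If every slot
  // turns out to be null, an empty list is returned.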
  template <TracedVarRole kRole>
  TracedVarList<VarBase, kRole> GetVarBaseList(const std::string& name,
                                               bool is_input) const {
    const auto& data_map = is_input ? var_base_map_in_ : var_base_map_out_;
    auto iterator = data_map.find(name);
    TracedVarList<VarBase, kRole> vec_temp;
    if (iterator != data_map.end()) {
      vec_temp.reserve(iterator->second.size());

      bool is_valid = false;
      for (auto& var_base_temp : iterator->second) {
        if (!var_base_temp) {
          vec_temp.emplace_back();
          continue;
        }

        if (kRole == TracedVarRole::kBackward) {
          if (!var_base_temp->HasGradVar()) {
            VLOG(6) << "GradVarBase of var " << var_base_temp->Name()
                    << " in OP " << type_ << " is null";
            var_base_temp->MutableGradVarBase();
          }
          auto grad_var_base_tmp = var_base_temp->GradVarBase();

          if (!is_input) {
            auto* tensor = grad_var_base_tmp->MutableVar()
                               ->GetMutable<framework::LoDTensor>();
            tensor->Resize(
                var_base_temp->Var().Get<framework::LoDTensor>().dims());
          }
          vec_temp.emplace_back(grad_var_base_tmp);
        } else {
          vec_temp.emplace_back(var_base_temp);
        }
        is_valid = true;
      }

      if (!is_valid) {
        vec_temp.clear();
      }
    }

    return vec_temp;
  }

 private:
  const std::string& type_;
  const NameVarBaseMap& var_base_map_in_;
  const NameVarBaseMap& var_base_map_out_;
  const framework::AttributeMap& attrs_;
  const framework::AttributeMap* default_attrs_ = nullptr;
  const std::map<std::string, std::string>& inplace_map_;
};

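// RAII helper that appends one OpBase to a shared GradOpNode and fills in
// its type, inputs, outputs and attrs. On destruction, the op is popped
// from the node again if no outputs were set; otherwise its attrs are
// checked.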
class TracedGradOp {
  DISABLE_COPY_AND_ASSIGN(TracedGradOp);

 public:
  explicit TracedGradOp(const std::shared_ptr<GradOpNode>& node)
      : node_(node), op_(&(node->emplace_back())) {}

  ~TracedGradOp() {
    if (UNLIKELY(op_->GetOutsMap().empty())) {
      node_->pop_back();
    } else {
      op_->CheckAttrs();
    }
  }

  template <TracedVarRole kRole>
  void SetInput(const std::string& name,
                const TracedVarList<VarBase, kRole>& vars) {
    if (vars.empty()) {
      return;
    }

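    // For backward inputs, re-point each grad-requiring var's grad node to
    // this node, stashing its previous grad node (if any) in
    // map_dirty_grad_node_ so that SetOutput can still register it as a
    // pending node (see the note at the member declaration below).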
    if (kRole == TracedVarRole::kBackward) {
      for (auto& var : vars) {
        if (var && !var->OverridedStopGradient()) {
          var->SetGraphIsFreed(false);
          auto dirty_grad_node = var->GradNode();
          if (dirty_grad_node) {
            map_dirty_grad_node_[var] = dirty_grad_node;
          }
          var->SetGradNode(node_);
        }
      }
    }

    auto var_wrappers = ToVarWrapperList<kRole>(vars);

    if (!var_wrappers.empty()) {
      op_->SetInput(name, std::move(var_wrappers),
                    kRole == TracedVarRole::kBackward);
    }
  }

  template <TracedVarRole kRole>
  void SetOutput(const std::string& name,
                 const TracedVarList<VarBase, kRole>& vars) {
    if (vars.empty()) {
      return;
    }

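    // For backward outputs, register each participating var's original grad
    // node as a pending node of this one, preferring the node stashed by
    // SetInput when the var was dirtied by an inplace op.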
    if (kRole == TracedVarRole::kBackward) {
      if (vars.size() == 1 && vars.front()->OverridedStopGradient()) {
        return;
      } else {
        for (auto& var : vars) {
          if (var && !var->OverridedStopGradient() && var->GradNode()) {
            if (map_dirty_grad_node_.find(var) != map_dirty_grad_node_.end()) {
              node_->InsertGradPendingNode(map_dirty_grad_node_[var]);
            } else {
              node_->InsertGradPendingNode(var->GradNode());
            }
          }
        }
      }
    }

    auto var_wrappers = ToVarWrapperList<kRole>(vars);
    if (!var_wrappers.empty()) {
      op_->SetOutput(name, std::move(var_wrappers),
                     kRole == TracedVarRole::kBackward);
    }
  }

  std::string Type() const { return op_->Type(); }

  void SetType(const std::string& type) { op_->SetType(type); }

  const framework::OperatorBase& InnerOp() const { return op_->InnerOp(); }

  void SetAttrMap(const framework::AttributeMap& attrs) {
    return op_->SetAttrMap(attrs);
  }

  void SetDefaultAttrsMap(const framework::AttributeMap& attrs) {
    return op_->SetDefaultAttrsMap(attrs);
  }

  void SetAttr(const std::string& name, const framework::Attribute& v) {
    op_->SetAttr(name, v);
  }

  bool HasAttr(const std::string& name) const { return op_->HasAttr(name); }

  const framework::Attribute& GetAttr(const std::string& name) const {
    return op_->GetAttr(name);
  }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return op_->Attr<T>(name);
  }

 private:
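  // Converts VarBase instances into snapshotted VariableWrapper instances.
  // Null slots (and, for the backward role, stop-gradient vars) become empty
  // entries; if no slot is valid at all, an empty list is returned.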
  template <TracedVarRole kRole>
  static std::vector<std::shared_ptr<VariableWrapper>> ToVarWrapperList(
      const std::vector<std::shared_ptr<VarBase>>& vars) {
    std::vector<std::shared_ptr<VariableWrapper>> result;
    result.reserve(vars.size());
    bool has_valid = false;
    for (auto& var : vars) {
      if (UNLIKELY(!var || (kRole == TracedVarRole::kBackward &&
                            var->OverridedStopGradient()))) {
        result.emplace_back();
      } else {
        auto var_wrapper = SnapshotVarWrapper(var->SharedVar());
        result.emplace_back(var_wrapper);
        has_valid = true;
      }
    }

    if (!has_valid) {
      result.clear();
    }
    return result;
  }

  // Get a snapshot of VariableWrapper at a certain inplace version.
  // The inplace version number of VariableWrapper is used for inplace
  // detection in gradient computation.
  static const std::shared_ptr<VariableWrapper> SnapshotVarWrapper(
      const std::shared_ptr<VariableWrapper>& var_wrapper) {
    // NOTE(liym27):
    //  Use original var_wrapper if its inplace_version is not
    //  changed. Otherwise, it will affect the accuracy of the model
    //  results and affect double grad.
    if (!var_wrapper->MutableVar()->IsInitialized() ||
        var_wrapper->InplaceVersionSnapshot() ==
            var_wrapper->MutableVar()->CurrentInplaceVersion()) {
      return var_wrapper;
    } else {
      VariableWrapper new_var_wrapper = *var_wrapper.get();
      new_var_wrapper.ResetInplaceVersion();
      return std::make_shared<VariableWrapper>(new_var_wrapper);
    }
  }

 private:
  const std::shared_ptr<GradOpNode>& node_;
  OpBase* op_;
  // Inplace ops cause a recursion problem when computing gradients: because
  // an inplace op's input and output are the same variable, the grad node of
  // the inplace var would be overwritten. This map temporarily stores the
  // original grad node of each inplace var.
  std::unordered_map<std::shared_ptr<VarBase>, std::shared_ptr<GradOpNode>>
      map_dirty_grad_node_;
};

}  // namespace imperative
}  // namespace paddle