// dygraph_grad_maker.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
20
#include <unordered_set>
21
#include <utility>
H
hong 已提交
22 23 24
#include <vector>

#include "paddle/fluid/imperative/layer.h"
25
#include "paddle/fluid/imperative/op_base.h"
H
hong 已提交
26 27 28 29 30 31 32
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace imperative {

33 34 35 36 37 38 39 40 41 42 43
// Compile-time tag carried by TracedVarList. Declared as a plain (unscoped)
// enum, so both `kForward` and `TracedVarRole::kForward` spellings are valid.
enum TracedVarRole {
  kForward = 0,
  kBackward = 1,
};

// A std::vector of std::shared_ptr<T> whose type additionally carries a
// TracedVarRole. kRole is only a template tag: it adds no data members and no
// behavior, it just makes lists with different roles distinct types.
// Every std::vector constructor is inherited unchanged.
template <typename T, TracedVarRole kRole>
class TracedVarList : public std::vector<std::shared_ptr<T>> {
  // Class members are private by default; this alias is an internal shorthand.
  using Storage = std::vector<std::shared_ptr<T>>;

 public:
  using Storage::Storage;
};

H
hong 已提交
44 45
class GradOpBaseMakerBase {
 public:
46
  explicit GradOpBaseMakerBase(const std::string& type,
H
hong 已提交
47
                               const NameVarBaseMap& var_base_map_in,
48 49 50
                               const NameVarBaseMap& var_base_map_out,
                               const framework::AttributeMap& attrs)
      : type_(type),
H
hong 已提交
51
        var_base_map_in_(var_base_map_in),
52 53
        var_base_map_out_(var_base_map_out),
        attrs_(attrs) {}
H
hong 已提交
54 55

  virtual ~GradOpBaseMakerBase() = default;
56

57
  virtual std::shared_ptr<GradOpNode> operator()() const = 0;
H
hong 已提交
58

59
  TracedVarList<VarBase, TracedVarRole::kBackward> InputGrad(
H
hong 已提交
60
      const std::string& name, bool drop_empty_grad = true) const {
61
    return GetVarBaseList<TracedVarRole::kBackward>(name, /*is_input=*/true);
H
hong 已提交
62 63
  }

64
  TracedVarList<VarBase, TracedVarRole::kBackward> OutputGrad(
H
hong 已提交
65
      const std::string& name) const {
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
    return GetVarBaseList<TracedVarRole::kBackward>(name, /*is_input=*/false);
  }

  TracedVarList<VarBase, TracedVarRole::kForward> Input(
      const std::string& name) const {
    return GetVarBaseList<TracedVarRole::kForward>(name, /*is_input=*/true);
  }

  TracedVarList<VarBase, TracedVarRole::kForward> Output(
      const std::string& name) const {
    return GetVarBaseList<TracedVarRole::kForward>(name, /*is_input=*/false);
  }

  static TracedVarList<VarBase, TracedVarRole::kForward> EmptyInput() {
    return {};
H
hong 已提交
81 82
  }

83 84
  static TracedVarList<VarBase, TracedVarRole::kForward> EmptyOutput() {
    return {};
H
hong 已提交
85 86
  }

87 88
  static TracedVarList<VarBase, TracedVarRole::kBackward> EmptyOutputGrad() {
    return {};
H
hong 已提交
89 90
  }

91 92 93
  static TracedVarList<VarBase, TracedVarRole::kBackward> EmptyInputGrad() {
    return {};
  }
H
hong 已提交
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112

  std::vector<std::string> InputNames() const {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(var_base_map_in_.size());
    for (auto& it : var_base_map_in_) {
      vec_temp.emplace_back(it.first);
    }
    return vec_temp;
  }

  std::vector<std::string> OutputNames() const {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(var_base_map_out_.size());
    for (auto& it : var_base_map_out_) {
      vec_temp.emplace_back(it.first);
    }
    return vec_temp;
  }

113
  const framework::AttributeMap& Attrs() const { return attrs_; }
H
hong 已提交
114 115

  const framework::Attribute& GetAttr(const std::string& name) const {
116 117 118 119 120
    auto it = attrs_.find(name);
    PADDLE_ENFORCE_EQ(
        it != attrs_.end(), true,
        platform::errors::NotFound(
            "Cannot find attribute [%s] in operator [%s]", name, type_));
H
hong 已提交
121 122 123 124 125
    return it->second;
  }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
126
    return BOOST_GET_CONST(T, GetAttr(name));
H
hong 已提交
127 128
  }

129
  const std::string& ForwardOpType() const { return type_; }
H
hong 已提交
130 131 132

 protected:
  bool HasInput(const std::string& name) const {
133
    return var_base_map_in_.count(name) > 0;
H
hong 已提交
134 135
  }

136 137
  bool HasOutput(const std::string& name) const {
    return var_base_map_out_.count(name) > 0;
138 139
  }

140 141 142 143
  static std::shared_ptr<GradOpNode> NewGradNode() {
    return std::make_shared<GradOpNode>();
  }

H
hong 已提交
144
 private:
145 146 147 148
  template <TracedVarRole kRole>
  TracedVarList<VarBase, kRole> GetVarBaseList(const std::string& name,
                                               bool is_input) const {
    const auto& data_map = is_input ? var_base_map_in_ : var_base_map_out_;
H
hong 已提交
149 150
    auto iterator = data_map.find(name);

151
    TracedVarList<VarBase, kRole> vec_temp;
H
hong 已提交
152 153 154
    if (iterator != data_map.end()) {
      vec_temp.reserve(iterator->second.size());

155
      bool is_valid = false;
H
hong 已提交
156
      for (auto& var_base_temp : iterator->second) {
157 158 159 160 161
        if (!var_base_temp) {
          vec_temp.emplace_back();
          continue;
        }

162
        if (kRole == TracedVarRole::kBackward) {
163 164
          if (!var_base_temp->HasGradVar()) {
            VLOG(6) << "GradVarBase of var " << var_base_temp->Name()
165
                    << " in OP " << type_ << " is null";
166 167
            var_base_temp->MutableGradVarBase();
          }
H
hong 已提交
168
          auto grad_var_base_tmp = var_base_temp->GradVarBase();
169

170 171 172 173 174 175
          if (!is_input) {
            auto* tensor = grad_var_base_tmp->MutableVar()
                               ->GetMutable<framework::LoDTensor>();
            tensor->Resize(
                var_base_temp->Var().Get<framework::LoDTensor>().dims());
          }
H
hong 已提交
176 177 178 179
          vec_temp.emplace_back(grad_var_base_tmp);
        } else {
          vec_temp.emplace_back(var_base_temp);
        }
180 181 182 183 184
        is_valid = true;
      }

      if (!is_valid) {
        vec_temp.clear();
H
hong 已提交
185 186 187 188 189 190 191
      }
    }

    return vec_temp;
  }

 private:
192
  const std::string& type_;
H
hong 已提交
193 194
  const NameVarBaseMap& var_base_map_in_;
  const NameVarBaseMap& var_base_map_out_;
195 196
  const framework::AttributeMap& attrs_;
};
H
hong 已提交
197

198 199 200 201
class TracedGradOp {
  DISABLE_COPY_AND_ASSIGN(TracedGradOp);

 public:
202 203
  explicit TracedGradOp(const std::shared_ptr<GradOpNode>& node)
      : node_(node), op_(&(node->emplace_back())) {}
204 205

  ~TracedGradOp() {
206 207 208 209 210
    if (UNLIKELY(op_->GetOutsMap().empty())) {
      node_->pop_back();
    } else {
      op_->CheckAttrs();
    }
211 212 213 214 215
  }

  template <TracedVarRole kRole>
  void SetInput(const std::string& name,
                const TracedVarList<VarBase, kRole>& vars) {
216 217 218 219
    if (vars.empty()) {
      return;
    }

220 221
    if (kRole == TracedVarRole::kBackward) {
      for (auto& var : vars) {
222 223 224
        if (var && !var->OverridedStopGradient()) {
          var->SetGradNode(node_);
        }
225 226
      }
    }
227 228 229 230 231 232

    auto var_wrappers = ToVarWrapperList<kRole>(vars);
    if (!var_wrappers.empty()) {
      op_->SetInput(name, std::move(var_wrappers),
                    kRole == TracedVarRole::kBackward);
    }
233 234 235 236 237
  }

  template <TracedVarRole kRole>
  void SetOutput(const std::string& name,
                 const TracedVarList<VarBase, kRole>& vars) {
238 239 240 241
    if (vars.empty()) {
      return;
    }

242 243 244 245 246
    if (kRole == TracedVarRole::kBackward) {
      if (vars.size() == 1 && vars.front()->OverridedStopGradient()) {
        return;
      } else {
        for (auto& var : vars) {
247 248
          if (var && !var->OverridedStopGradient() && var->GradNode()) {
            node_->InsertGradPendingNode(var->GradNode());
249 250 251 252 253
          }
        }
      }
    }

254 255 256 257 258
    auto var_wrappers = ToVarWrapperList<kRole>(vars);
    if (!var_wrappers.empty()) {
      op_->SetOutput(name, std::move(var_wrappers),
                     kRole == TracedVarRole::kBackward);
    }
259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282
  }

  void SetType(const std::string& type) { op_->SetType(type); }

  void SetAttrMap(const framework::AttributeMap& attrs) {
    return op_->SetAttrMap(attrs);
  }

  void SetAttr(const std::string& name, const framework::Attribute& v) {
    op_->SetAttr(name, v);
  }

  bool HasAttr(const std::string& name) const { return op_->HasAttr(name); }

  const framework::Attribute& GetAttr(const std::string& name) const {
    return op_->GetAttr(name);
  }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return op_->Attr<T>(name);
  }

 private:
283
  template <TracedVarRole kRole>
284 285 286 287
  static std::vector<std::shared_ptr<VariableWrapper>> ToVarWrapperList(
      const std::vector<std::shared_ptr<VarBase>>& vars) {
    std::vector<std::shared_ptr<VariableWrapper>> result;
    result.reserve(vars.size());
288
    bool has_valid = false;
289
    for (auto& var : vars) {
290 291 292 293 294 295 296 297 298 299 300
      if (UNLIKELY(!var || (kRole == TracedVarRole::kBackward &&
                            var->OverridedStopGradient()))) {
        result.emplace_back();
      } else {
        result.emplace_back(var->SharedVar());
        has_valid = true;
      }
    }

    if (!has_valid) {
      result.clear();
301 302 303 304 305
    }
    return result;
  }

 private:
306 307
  const std::shared_ptr<GradOpNode>& node_;
  OpBase* op_;
H
hong 已提交
308 309 310 311
};

}  // namespace imperative
}  // namespace paddle