/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/imperative/dygraph_grad_maker.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"

namespace paddle {
namespace framework {

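// The details namespace below maps each graph mode to the handle type a
// grad-op maker writes through: OpDesc* for the static graph and
// imperative::TracedGradOp* for dygraph tracing; GradOpPtr<T> selects the
// right one.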
namespace details {

template <typename T>
struct GradOpPtrTrait {};

template <>
struct GradOpPtrTrait<OpDesc> {
  using Type = OpDesc*;
};

template <>
struct GradOpPtrTrait<imperative::OpBase> {
  using Type = imperative::TracedGradOp*;
};

}  // namespace details

template <typename T>
using GradOpPtr = typename details::GradOpPtrTrait<T>::Type;

/*
  This functor class is responsible for creating the gradient ops for the
  given forward operator fwd_op. After it is called (through operator()),
  the pairs of (gradient variable, corresponding input variable of fwd_op)
  will be added to grad_to_var. If an input variable of fwd_op is contained
  in no_grad_set, its gradient variable will either be dropped or set to
  kEmptyVarName, depending on the template argument DropEmptyIG in the
  derived classes.
 */
class GradOpDescMakerBase {
 public:
  explicit GradOpDescMakerBase(
      const OpDesc& fwd_op, const std::unordered_set<std::string>& no_grad_set,
      std::unordered_map<std::string, std::string>* grad_to_var,
      const std::vector<BlockDesc*>& grad_block = std::vector<BlockDesc*>())
      : fwd_op_(fwd_op),
        no_grad_set_(no_grad_set),
        grad_to_var_(grad_to_var),
        grad_block_(grad_block) {}

  static std::unique_ptr<OpDesc> CreateOp() {
    return std::unique_ptr<OpDesc>(new OpDesc());
  }

  virtual ~GradOpDescMakerBase() = default;
  virtual std::vector<std::unique_ptr<OpDesc>> operator()() const = 0;

 protected:
  std::vector<std::string> InputGrad(const std::string& name,
                                     bool drop_empty_grad = true) const {
    std::vector<std::string> ret_val;
    auto var_names = this->Input(name);
    ret_val.reserve(var_names.size());
    std::transform(var_names.begin(), var_names.end(),
                   std::back_inserter(ret_val),
                   [this](const std::string& fwd_var_name) -> std::string {
                     auto g_name = GradVarName(fwd_var_name);
                     if (no_grad_set_.empty() || !no_grad_set_.count(g_name)) {
                       (*this->grad_to_var_)[g_name] = fwd_var_name;
                       return g_name;
                     } else {
                       return kEmptyVarName;
                     }
                   });
    if (!drop_empty_grad) {
      return ret_val;
    }
    PADDLE_ENFORCE_LE(var_names.size(), 1UL,
                      "BUG from operator developer:"
                      " for input argument with a list of variables, "
                      " drop_empty_grad is not allowed because it makes"
                      " the correspondence bewteen a variable and its gradient"
                      " ambiguous."
                      " Op type %s",
                      fwd_op_.Type());

    std::vector<std::string> dropped_ret_val;
    dropped_ret_val.reserve(ret_val.size());
    std::copy_if(ret_val.begin(), ret_val.end(),
                 std::back_inserter(dropped_ret_val),
                 [](const std::string& str) { return str != kEmptyVarName; });
    return dropped_ret_val;
  }
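
  // Worked example (illustrative): for a forward input slot "X" holding the
  // single variable "x", InputGrad("X") normally returns {"x@GRAD"} and
  // records (*grad_to_var_)["x@GRAD"] = "x". If "x@GRAD" appears in
  // no_grad_set_, the entry becomes kEmptyVarName instead, and is dropped
  // entirely when drop_empty_grad is true.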

  std::vector<std::string> OutputGrad(const std::string& name) const {
    std::vector<std::string> ret_val;
    auto onames = this->Output(name);
    ret_val.reserve(onames.size());
    std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val),
                   [this](const std::string& fwd_var_name) -> std::string {
                     auto g_name = GradVarName(fwd_var_name);
                     (*this->grad_to_var_)[g_name] = fwd_var_name;
                     return g_name;
                   });
    return ret_val;
  }

  static std::vector<std::string> EmptyInput() { return {}; }

  static std::vector<std::string> EmptyOutput() { return {}; }

  static std::vector<std::string> EmptyInputGrad() { return {}; }

  static std::vector<std::string> EmptyOutputGrad() { return {}; }

  std::vector<std::string> InputNames() const {
    return this->fwd_op_.InputNames();
  }

  std::vector<std::string> OutputNames() const {
    return this->fwd_op_.OutputNames();
  }

  std::vector<std::string> Input(const std::string& name) const {
    return fwd_op_.Input(name);
  }

  std::vector<std::string> Output(const std::string& name) const {
    return fwd_op_.Output(name);
  }

  const std::unordered_map<std::string, Attribute>& Attrs() const {
    return fwd_op_.GetAttrMap();
  }

  const Attribute& GetAttr(const std::string& name) const {
    auto& map = fwd_op_.GetAttrMap();
    auto it = map.find(name);
    PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name);
    return it->second;
  }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return boost::get<T>(GetAttr(name));
  }

  std::string ForwardOpType() const { return this->fwd_op_.Type(); }

 protected:
  bool HasInput(const std::string& name) const {
    return (fwd_op_.Inputs().count(name) > 0);
  }

  bool HasOutput(const std::string& name) const {
    return (fwd_op_.Outputs().count(name) > 0);
  }

 private:
  const OpDesc& fwd_op_;
  const std::unordered_set<std::string>& no_grad_set_;
  std::unordered_map<std::string, std::string>* grad_to_var_;

 protected:
  std::vector<BlockDesc*> grad_block_;
};

template <typename T>
class SingleGradOpMaker {};

template <>
class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
 public:
  using GradOpDescMakerBase::GradOpDescMakerBase;

  std::vector<std::unique_ptr<OpDesc>> operator()() const final {
    std::vector<std::unique_ptr<OpDesc>> retv;
    retv.emplace_back(new OpDesc());
    this->Apply(retv.front().get());
    return retv;
  }

 protected:
  virtual void Apply(GradOpPtr<OpDesc> op) const = 0;
};

template <>
class SingleGradOpMaker<imperative::OpBase>
    : public imperative::GradOpBaseMakerBase {
 public:
  using GradOpBaseMakerBase::GradOpBaseMakerBase;

  std::shared_ptr<imperative::GradOpNode> operator()() const final {
    auto node = this->NewGradNode();
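    // The TracedGradOp is scoped so that it is destroyed before
    // node->empty() is checked below.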
    {
      imperative::TracedGradOp traced_grad_op(node);
      this->Apply(&traced_grad_op);
    }
    return node->empty() ? nullptr : node;
  }

 protected:
  virtual void Apply(GradOpPtr<imperative::OpBase> op) const = 0;
};
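
// Usage sketch (illustrative only; "my_op" and MyOpGradMaker are
// hypothetical names, not part of this header). A concrete operator
// usually derives from SingleGradOpMaker<T> and overrides Apply() once,
// so the same maker serves both the OpDesc and imperative::OpBase paths:
//
//   template <typename T>
//   class MyOpGradMaker : public SingleGradOpMaker<T> {
//    public:
//     using SingleGradOpMaker<T>::SingleGradOpMaker;
//
//    protected:
//     void Apply(GradOpPtr<T> grad) const override {
//       grad->SetType("my_op_grad");
//       grad->SetInput("X", this->Input("X"));
//       grad->SetInput(GradVarName("Out"), this->OutputGrad("Out"));
//       grad->SetOutput(GradVarName("X"), this->InputGrad("X"));
//       grad->SetAttrMap(this->Attrs());
//     }
//   };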

template <typename T, bool DropEmptyIG = true>
class DefaultGradOpMaker final : public SingleGradOpMaker<T> {
 public:
  using SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad) const final {
    grad->SetType(this->ForwardOpType() + "_grad");

    for (auto& input_param : this->InputNames()) {
      grad->SetInput(input_param, this->Input(input_param));
      grad->SetOutput(GradVarName(input_param),
                      this->InputGrad(input_param, DropEmptyIG));
    }

    for (auto& output_param : this->OutputNames()) {
      grad->SetInput(output_param, this->Output(output_param));
      grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
    }

    grad->SetAttrMap(this->Attrs());
  }
};
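
// Illustrative result (assuming a forward op of type "foo" with input slot
// "X" and output slot "Out"): DefaultGradOpMaker emits a single "foo_grad"
// op that takes X, Out, and Out@GRAD as inputs, produces X@GRAD as output,
// and copies the forward op's attribute map verbatim.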

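// EmptyGradOpMaker marks an operator as having no gradient: the OpDesc
// specialization returns an empty op list and the imperative specialization
// returns a null grad node.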
template <typename T>
class EmptyGradOpMaker {};

template <>
class EmptyGradOpMaker<OpDesc> final : public GradOpDescMakerBase {
 public:
  using GradOpDescMakerBase::GradOpDescMakerBase;
  std::vector<std::unique_ptr<OpDesc>> operator()() const final { return {}; }
};

template <>
class EmptyGradOpMaker<imperative::OpBase> final
    : public imperative::GradOpBaseMakerBase {
 public:
  using GradOpBaseMakerBase::GradOpBaseMakerBase;

  std::shared_ptr<imperative::GradOpNode> operator()() const final {
    return nullptr;
  }
};

}  // namespace framework

namespace operators {

template <typename T>
using GradOpPtr = framework::GradOpPtr<T>;

}  // namespace operators

}  // namespace paddle