grad_op_desc_maker.h 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
16 17
#include <string>
#include <unordered_set>
Y
Yu Yang 已提交
18
#include <vector>
19 20 21 22 23 24 25 26
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

class GradOpDescMakerBase {
 public:
27 28
  explicit GradOpDescMakerBase(
      const OpDescBind& fwd_op,
29
      const std::unordered_set<std::string>& no_grad_set,
Y
Yu Yang 已提交
30 31 32 33 34 35 36
      std::unordered_map<std::string, std::string>* grad_to_var,
      const std::vector<BlockDescBind*>& grad_block =
          std::vector<BlockDescBind*>())
      : fwd_op_(fwd_op),
        no_grad_set_(no_grad_set),
        grad_to_var_(grad_to_var),
        grad_block_(grad_block) {}
37 38

  virtual ~GradOpDescMakerBase() = default;
Y
Yu Yang 已提交
39
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
40 41

 protected:
42 43
  std::vector<std::string> InputGrad(const std::string& name,
                                     bool drop_empty_grad = true) const {
44
    std::vector<std::string> ret_val;
45
    auto var_names = this->Input(name);
46
    ret_val.reserve(var_names.size());
47 48 49 50 51 52 53 54 55 56 57
    std::transform(var_names.begin(), var_names.end(),
                   std::back_inserter(ret_val),
                   [this](const std::string& fwd_var_name) -> std::string {
                     auto g_name = GradVarName(fwd_var_name);
                     if (no_grad_set_.count(g_name)) {
                       return kEmptyVarName;
                     } else {
                       (*this->grad_to_var_)[g_name] = fwd_var_name;
                       return g_name;
                     }
                   });
58 59 60 61 62 63 64 65 66
    if (!drop_empty_grad) {
      return ret_val;
    }
    std::vector<std::string> dropped_ret_val;
    dropped_ret_val.reserve(ret_val.size());
    std::copy_if(ret_val.begin(), ret_val.end(),
                 std::back_inserter(dropped_ret_val),
                 [](const std::string& str) { return str != kEmptyVarName; });
    return dropped_ret_val;
67 68 69
  }

  std::vector<std::string> OutputGrad(const std::string& name) const {
70 71 72 73 74 75
    std::vector<std::string> ret_val;
    auto onames = this->Output(name);
    ret_val.reserve(onames.size());
    std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val),
                   GradVarName);
    return ret_val;
76 77
  }

Y
Yu Yang 已提交
78 79
  std::vector<std::string> InputNames() const {
    return this->fwd_op_.InputNames();
80 81
  }

Y
Yu Yang 已提交
82 83
  std::vector<std::string> OutputNames() const {
    return this->fwd_op_.OutputNames();
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
  }

  std::vector<std::string> Input(const std::string& name) const {
    return fwd_op_.Input(name);
  }

  std::vector<std::string> Output(const std::string& name) const {
    return fwd_op_.Output(name);
  }

  const std::unordered_map<std::string, Attribute>& Attrs() const {
    return fwd_op_.GetAttrMap();
  }

  const Attribute& GetAttr(const std::string& name) const {
    auto& map = fwd_op_.GetAttrMap();
    auto it = map.find(name);
    PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name);
    return it->second;
  }

  std::string ForwardOpType() const { return this->fwd_op_.Type(); }

 private:
  const OpDescBind& fwd_op_;
109
  const std::unordered_set<std::string>& no_grad_set_;
110
  std::unordered_map<std::string, std::string>* grad_to_var_;
Y
Yu Yang 已提交
111 112 113

 protected:
  std::vector<BlockDescBind*> grad_block_;
114 115 116 117
};

class SingleGradOpDescMaker : public GradOpDescMakerBase {
 public:
Y
Yu Yang 已提交
118 119
  using GradOpDescMakerBase::GradOpDescMakerBase;

Y
Yu Yang 已提交
120 121 122 123 124
  std::vector<std::unique_ptr<OpDescBind>> operator()() const {
    std::vector<std::unique_ptr<OpDescBind>> retv;
    retv.emplace_back(this->Apply());
    return retv;
  }
125 126

 protected:
Y
Yu Yang 已提交
127
  virtual std::unique_ptr<OpDescBind> Apply() const = 0;
128 129
};

130
template <bool DropEmptyIG = true>
131
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
Y
Yu Yang 已提交
132 133 134
 public:
  using SingleGradOpDescMaker::SingleGradOpDescMaker;

135
 protected:
Y
Yu Yang 已提交
136 137 138
  virtual std::unique_ptr<OpDescBind> Apply() const {
    auto* grad = new OpDescBind();
    grad->SetType(this->GradOpType());
139

Y
Yu Yang 已提交
140
    for (auto& input_param : this->InputNames()) {
Y
Yu Yang 已提交
141
      grad->SetInput(input_param, this->Input(input_param));
142 143
      grad->SetOutput(GradVarName(input_param),
                      this->InputGrad(input_param, DropEmptyIG));
144 145
    }

Y
Yu Yang 已提交
146
    for (auto& output_param : this->OutputNames()) {
Y
Yu Yang 已提交
147 148
      grad->SetInput(output_param, this->Output(output_param));
      grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
149 150
    }

Y
Yu Yang 已提交
151
    grad->SetAttrMap(this->Attrs());
152

Y
Yu Yang 已提交
153
    return std::unique_ptr<OpDescBind>(grad);
154 155 156 157 158 159 160
  }

  virtual std::string GradOpType() const {
    return this->ForwardOpType() + "_grad";
  }
};

Y
Yu Yang 已提交
161 162 163 164 165 166 167 168
// Maker that produces no gradient ops at all: operator() always returns an
// empty vector.
class EmptyGradOpMaker : public GradOpDescMakerBase {
 public:
  using GradOpDescMakerBase::GradOpDescMakerBase;

  // Returns an empty list of gradient op descriptions.
  std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
    std::vector<std::unique_ptr<OpDescBind>> no_grad_ops;
    return no_grad_ops;
  }
};

169 170
}  // namespace framework
}  // namespace paddle