// grad_op_desc_maker.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

class GradOpDescMakerBase {
 public:
26 27
  explicit GradOpDescMakerBase(
      const OpDescBind& fwd_op,
28 29 30
      const std::unordered_set<std::string>& no_grad_set,
      std::unordered_map<std::string, std::string>* grad_to_var)
      : fwd_op_(fwd_op), no_grad_set_(no_grad_set), grad_to_var_(grad_to_var) {}
31 32

  virtual ~GradOpDescMakerBase() = default;
Y
Yu Yang 已提交
33
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
34 35

 protected:
36 37
  std::vector<std::string> InputGrad(const std::string& name,
                                     bool drop_empty_grad = true) const {
38
    std::vector<std::string> ret_val;
39
    auto var_names = this->Input(name);
40
    ret_val.reserve(var_names.size());
41 42 43 44 45 46 47 48 49 50 51
    std::transform(var_names.begin(), var_names.end(),
                   std::back_inserter(ret_val),
                   [this](const std::string& fwd_var_name) -> std::string {
                     auto g_name = GradVarName(fwd_var_name);
                     if (no_grad_set_.count(g_name)) {
                       return kEmptyVarName;
                     } else {
                       (*this->grad_to_var_)[g_name] = fwd_var_name;
                       return g_name;
                     }
                   });
52 53 54 55 56 57 58 59 60
    if (!drop_empty_grad) {
      return ret_val;
    }
    std::vector<std::string> dropped_ret_val;
    dropped_ret_val.reserve(ret_val.size());
    std::copy_if(ret_val.begin(), ret_val.end(),
                 std::back_inserter(dropped_ret_val),
                 [](const std::string& str) { return str != kEmptyVarName; });
    return dropped_ret_val;
61 62 63
  }

  std::vector<std::string> OutputGrad(const std::string& name) const {
64 65 66 67 68 69
    std::vector<std::string> ret_val;
    auto onames = this->Output(name);
    ret_val.reserve(onames.size());
    std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val),
                   GradVarName);
    return ret_val;
70 71
  }

Y
Yu Yang 已提交
72 73
  std::vector<std::string> InputNames() const {
    return this->fwd_op_.InputNames();
74 75
  }

Y
Yu Yang 已提交
76 77
  std::vector<std::string> OutputNames() const {
    return this->fwd_op_.OutputNames();
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
  }

  std::vector<std::string> Input(const std::string& name) const {
    return fwd_op_.Input(name);
  }

  std::vector<std::string> Output(const std::string& name) const {
    return fwd_op_.Output(name);
  }

  const std::unordered_map<std::string, Attribute>& Attrs() const {
    return fwd_op_.GetAttrMap();
  }

  const Attribute& GetAttr(const std::string& name) const {
    auto& map = fwd_op_.GetAttrMap();
    auto it = map.find(name);
    PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name);
    return it->second;
  }

  std::string ForwardOpType() const { return this->fwd_op_.Type(); }

 private:
  const OpDescBind& fwd_op_;
103
  const std::unordered_set<std::string>& no_grad_set_;
104
  std::unordered_map<std::string, std::string>* grad_to_var_;
105 106 107 108
};

class SingleGradOpDescMaker : public GradOpDescMakerBase {
 public:
Y
Yu Yang 已提交
109 110
  using GradOpDescMakerBase::GradOpDescMakerBase;

Y
Yu Yang 已提交
111 112 113 114 115
  std::vector<std::unique_ptr<OpDescBind>> operator()() const {
    std::vector<std::unique_ptr<OpDescBind>> retv;
    retv.emplace_back(this->Apply());
    return retv;
  }
116 117

 protected:
Y
Yu Yang 已提交
118
  virtual std::unique_ptr<OpDescBind> Apply() const = 0;
119 120
};

121
template <bool DropEmptyIG = true>
122
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
Y
Yu Yang 已提交
123 124 125
 public:
  using SingleGradOpDescMaker::SingleGradOpDescMaker;

126
 protected:
Y
Yu Yang 已提交
127 128 129
  virtual std::unique_ptr<OpDescBind> Apply() const {
    auto* grad = new OpDescBind();
    grad->SetType(this->GradOpType());
130

Y
Yu Yang 已提交
131
    for (auto& input_param : this->InputNames()) {
Y
Yu Yang 已提交
132
      grad->SetInput(input_param, this->Input(input_param));
133 134
      grad->SetOutput(GradVarName(input_param),
                      this->InputGrad(input_param, DropEmptyIG));
135 136
    }

Y
Yu Yang 已提交
137
    for (auto& output_param : this->OutputNames()) {
Y
Yu Yang 已提交
138 139
      grad->SetInput(output_param, this->Output(output_param));
      grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
140 141
    }

Y
Yu Yang 已提交
142
    grad->SetAttrMap(this->Attrs());
143

Y
Yu Yang 已提交
144
    return std::unique_ptr<OpDescBind>(grad);
145 146 147 148 149 150 151 152 153
  }

  virtual std::string GradOpType() const {
    return this->ForwardOpType() + "_grad";
  }
};

}  // namespace framework
}  // namespace paddle