/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

class GradOpDescMakerBase {
 public:
26 27 28 29
  explicit GradOpDescMakerBase(
      const OpDescBind& fwd_op,
      const std::unordered_set<std::string>& no_grad_set)
      : fwd_op_(fwd_op), no_grad_set_(no_grad_set) {}
30 31

  virtual ~GradOpDescMakerBase() = default;
Y
Yu Yang 已提交
32
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
33 34

 protected:
35 36
  std::vector<std::string> InputGrad(const std::string& name,
                                     bool drop_empty_grad = true) const {
37
    std::vector<std::string> ret_val;
38
    auto var_names = this->Input(name);
39
    ret_val.reserve(var_names.size());
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
    std::transform(
        var_names.begin(), var_names.end(), std::back_inserter(ret_val),
        [this](const std::string& fwd_var_name) -> std::string {
          auto g_name = GradVarName(fwd_var_name);
          return no_grad_set_.count(g_name) == 0 ? g_name : kEmptyVarName;
        });
    if (!drop_empty_grad) {
      return ret_val;
    }
    std::vector<std::string> dropped_ret_val;
    dropped_ret_val.reserve(ret_val.size());
    std::copy_if(ret_val.begin(), ret_val.end(),
                 std::back_inserter(dropped_ret_val),
                 [](const std::string& str) { return str != kEmptyVarName; });
    return dropped_ret_val;
55 56 57
  }

  std::vector<std::string> OutputGrad(const std::string& name) const {
58 59 60 61 62 63
    std::vector<std::string> ret_val;
    auto onames = this->Output(name);
    ret_val.reserve(onames.size());
    std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val),
                   GradVarName);
    return ret_val;
64 65
  }

Y
Yu Yang 已提交
66 67
  std::vector<std::string> InputNames() const {
    return this->fwd_op_.InputNames();
68 69
  }

Y
Yu Yang 已提交
70 71
  std::vector<std::string> OutputNames() const {
    return this->fwd_op_.OutputNames();
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
  }

  std::vector<std::string> Input(const std::string& name) const {
    return fwd_op_.Input(name);
  }

  std::vector<std::string> Output(const std::string& name) const {
    return fwd_op_.Output(name);
  }

  const std::unordered_map<std::string, Attribute>& Attrs() const {
    return fwd_op_.GetAttrMap();
  }

  const Attribute& GetAttr(const std::string& name) const {
    auto& map = fwd_op_.GetAttrMap();
    auto it = map.find(name);
    PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name);
    return it->second;
  }

  std::string ForwardOpType() const { return this->fwd_op_.Type(); }

 private:
  const OpDescBind& fwd_op_;
97
  const std::unordered_set<std::string>& no_grad_set_;
98 99 100 101
};

class SingleGradOpDescMaker : public GradOpDescMakerBase {
 public:
Y
Yu Yang 已提交
102 103
  using GradOpDescMakerBase::GradOpDescMakerBase;

Y
Yu Yang 已提交
104 105 106 107 108
  std::vector<std::unique_ptr<OpDescBind>> operator()() const {
    std::vector<std::unique_ptr<OpDescBind>> retv;
    retv.emplace_back(this->Apply());
    return retv;
  }
109 110

 protected:
Y
Yu Yang 已提交
111
  virtual std::unique_ptr<OpDescBind> Apply() const = 0;
112 113
};

114
template <bool DropEmptyIG = true>
115
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
Y
Yu Yang 已提交
116 117 118
 public:
  using SingleGradOpDescMaker::SingleGradOpDescMaker;

119
 protected:
Y
Yu Yang 已提交
120 121 122
  virtual std::unique_ptr<OpDescBind> Apply() const {
    auto* grad = new OpDescBind();
    grad->SetType(this->GradOpType());
123

Y
Yu Yang 已提交
124
    for (auto& input_param : this->InputNames()) {
Y
Yu Yang 已提交
125
      grad->SetInput(input_param, this->Input(input_param));
126 127
      grad->SetOutput(GradVarName(input_param),
                      this->InputGrad(input_param, DropEmptyIG));
128 129
    }

Y
Yu Yang 已提交
130
    for (auto& output_param : this->OutputNames()) {
Y
Yu Yang 已提交
131 132
      grad->SetInput(output_param, this->Output(output_param));
      grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
133 134
    }

Y
Yu Yang 已提交
135
    grad->SetAttrMap(this->Attrs());
136

Y
Yu Yang 已提交
137
    return std::unique_ptr<OpDescBind>(grad);
138 139 140 141 142 143 144 145 146
  }

  virtual std::string GradOpType() const {
    return this->ForwardOpType() + "_grad";
  }
};

}  // namespace framework
}  // namespace paddle