// grad_op_desc_maker.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

class GradOpDescMakerBase {
 public:
  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}

  virtual ~GradOpDescMakerBase() = default;
Y
Yu Yang 已提交
27
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46

 protected:
  static std::vector<std::string> ToGradNames(
      const std::vector<std::string>& var_names) {
    std::vector<std::string> ret_val;
    ret_val.reserve(var_names.size());
    std::transform(var_names.begin(), var_names.end(),
                   std::back_inserter(ret_val), GradVarName);
    return ret_val;
  }

  std::vector<std::string> InputGrad(const std::string& name) const {
    return ToGradNames(fwd_op_.Input(name));
  }

  std::vector<std::string> OutputGrad(const std::string& name) const {
    return ToGradNames(fwd_op_.Output(name));
  }

Y
Yu Yang 已提交
47 48
  std::vector<std::string> InputNames() const {
    return this->fwd_op_.InputNames();
49 50
  }

Y
Yu Yang 已提交
51 52
  std::vector<std::string> OutputNames() const {
    return this->fwd_op_.OutputNames();
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
  }

  std::vector<std::string> Input(const std::string& name) const {
    return fwd_op_.Input(name);
  }

  std::vector<std::string> Output(const std::string& name) const {
    return fwd_op_.Output(name);
  }

  const std::unordered_map<std::string, Attribute>& Attrs() const {
    return fwd_op_.GetAttrMap();
  }

  const Attribute& GetAttr(const std::string& name) const {
    auto& map = fwd_op_.GetAttrMap();
    auto it = map.find(name);
    PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name);
    return it->second;
  }

  std::string ForwardOpType() const { return this->fwd_op_.Type(); }

 private:
  const OpDescBind& fwd_op_;
};

class SingleGradOpDescMaker : public GradOpDescMakerBase {
 public:
Y
Yu Yang 已提交
82 83
  using GradOpDescMakerBase::GradOpDescMakerBase;

Y
Yu Yang 已提交
84 85 86 87 88
  std::vector<std::unique_ptr<OpDescBind>> operator()() const {
    std::vector<std::unique_ptr<OpDescBind>> retv;
    retv.emplace_back(this->Apply());
    return retv;
  }
89 90

 protected:
Y
Yu Yang 已提交
91
  virtual std::unique_ptr<OpDescBind> Apply() const = 0;
92 93 94
};

class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
Y
Yu Yang 已提交
95 96 97
 public:
  using SingleGradOpDescMaker::SingleGradOpDescMaker;

98
 protected:
Y
Yu Yang 已提交
99 100 101
  virtual std::unique_ptr<OpDescBind> Apply() const {
    auto* grad = new OpDescBind();
    grad->SetType(this->GradOpType());
102

Y
Yu Yang 已提交
103
    for (auto& input_param : this->InputNames()) {
Y
Yu Yang 已提交
104 105
      grad->SetInput(input_param, this->Input(input_param));
      grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
106 107
    }

Y
Yu Yang 已提交
108
    for (auto& output_param : this->OutputNames()) {
Y
Yu Yang 已提交
109 110
      grad->SetInput(output_param, this->Output(output_param));
      grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
111 112
    }

Y
Yu Yang 已提交
113
    grad->SetAttrMap(this->Attrs());
114

Y
Yu Yang 已提交
115
    return std::unique_ptr<OpDescBind>(grad);
116 117 118 119 120 121 122 123 124
  }

  virtual std::string GradOpType() const {
    return this->ForwardOpType() + "_grad";
  }
};

}  // namespace framework
}  // namespace paddle