grad_op_builder.h 1.3 KB
Newer Older
1 2 3 4 5 6 7 8 9
#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {
class OpRegistry;

F
fengjiayi 已提交
10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
// Whether an argument belongs to an operator's inputs or to its outputs.
// Kept as an unscoped enum so existing code can keep writing bare IN / OUT.
enum InOutType { IN, OUT };

// Describes one named argument slot of a forward operator.
//
// NOTE(review): `begin_idx_` / `end_idx_` presumably delimit a half-open
// [begin, end) range of variables in the operator's input or output list;
// confirm against the code that builds these records.
struct OpInOutArg {
  // `proto_name` is taken by value and moved into the member: callers passing
  // a temporary no longer pay for a copy, callers passing an lvalue still get
  // exactly one copy as before. `type` is a cheap enum, so it is passed by
  // value rather than by const reference.
  OpInOutArg(std::string proto_name, InOutType type, bool needed_in_grad,
             size_t begin_idx, size_t end_idx)
      : proto_name_(std::move(proto_name)),
        type_(type),
        needed_in_grad_(needed_in_grad),
        begin_idx_(begin_idx),
        end_idx_(end_idx) {}

  // Argument name (presumably as declared in the operator's proto).
  std::string proto_name_;
  // Whether this argument is an input (IN) or an output (OUT).
  InOutType type_;
  // Flag recording whether the argument is needed by the gradient operator
  // (meaning inferred from the name — confirm where it is consumed).
  bool needed_in_grad_;
  // Start of this argument's variable range (see NOTE on the struct).
  size_t begin_idx_;
  // End of this argument's variable range (see NOTE on the struct).
  size_t end_idx_;
};

28
class GradOpBuilder {
F
fengjiayi 已提交
29 30
  using VarIndexMap = std::unordered_map<std::string, int>;

31
 public:
32 33
  GradOpBuilder(const OperatorBase* op) : op_(op) {}
  OperatorBase* Build();
34 35 36

 private:
  OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
F
fengjiayi 已提交
37
                       const std::vector<int>& format, InOutType type);
38
  void BuildOpInOutArgList();
F
fengjiayi 已提交
39 40
  void AddArgIntoGradOp(const OpInOutArg* arg, std::vector<std::string>& in_out,
                        std::vector<int>& format, VarIndexMap* varmap, int& idx,
F
fengjiayi 已提交
41
                        bool is_grad) const;
42 43 44
  void CompleteGradOp(OperatorBase* grad_op) const;
  const OperatorBase* op_;
  std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
F
fengjiayi 已提交
45
};
46 47 48

}  // namespace framework
}  // namespace paddle