// grad_op_creator.h
#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {
class OpRegistry;

class GradOpCreator {
F
fengjiayi 已提交
11 12
  using VarIndexMap = std::unordered_map<std::string, int>;

13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
 public:
  GradOpCreator(const OperatorBase* op) : op_(op) {}
  OperatorBase* Create();

 private:
  enum InOutType { IN, OUT };

  struct OpInOutArg {
    OpInOutArg(const std::string& proto_name, const InOutType& type,
               bool needed_in_grad, size_t begin_idx, size_t end_idx)
        : proto_name_(proto_name),
          type_(type),
          needed_in_grad_(needed_in_grad),
          begin_idx_(begin_idx),
          end_idx_(end_idx) {}

    std::string proto_name_;
    InOutType type_;
    bool needed_in_grad_;
    size_t begin_idx_;
    size_t end_idx_;
  };

  OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
F
fengjiayi 已提交
37
                       const std::vector<int>& format, InOutType type);
38
  void BuildOpInOutArgList();
F
fengjiayi 已提交
39 40 41
  void AddArgIntoGradOp(const OpInOutArg* arg, std::vector<std::string>& in_out,
                        std::vector<int>& format, VarIndexMap* varmap, int& idx,
                        bool is_grad);
42 43 44
  void CompleteGradOp(OperatorBase* grad_op) const;
  const OperatorBase* op_;
  std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
F
fengjiayi 已提交
45
};
46 47 48

}  // namespace framework
}  // namespace paddle