diff --git a/paddle/framework/grad_op_creator.cc b/paddle/framework/grad_op_creator.cc
new file mode 100644
index 0000000000000000000000000000000000000000..dbc10d5ad530990504d37bd9d5c15fb47e26af73
--- /dev/null
+++ b/paddle/framework/grad_op_creator.cc
@@ -0,0 +1,97 @@
+#include "paddle/framework/grad_op_creator.h"
+
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace framework {
+
+OperatorBase* GradOpCreator::Create() {
+  BuildOpInOutArgList();
+  OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
+  CompleteGradOp(grad_op);
+  return grad_op;
+}
+
+GradOpCreator::OpInOutArg* GradOpCreator::BuildArg(
+    const VarProto& var, const VarIndexMap& var_map,
+    const std::vector<int>& format, InOutType type) {
+  int idx = var_map.at(var.name());
+  int begin_idx = format.empty() ? idx : format.at(idx);
+  int end_idx = format.empty() ? idx + 1 : format.at(idx + 1);
+  return new OpInOutArg(var.name(), type, !var.ignore_gradient(), begin_idx,
+                        end_idx);
+}
+
+void GradOpCreator::BuildOpInOutArgList() {
+  const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
+  const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
+  const std::vector<int>& in_format =
+      op_->attrs_.count("input_format")
+          ? op_->GetAttr<std::vector<int>>("input_format")
+          : std::vector<int>();
+  const std::vector<int>& out_format =
+      op_->attrs_.count("output_format")
+          ? op_->GetAttr<std::vector<int>>("output_format")
+          : std::vector<int>();
+  for (const auto& var : op_proto.inputs()) {
+    arg_list_.emplace_back(
+        std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, in_format, IN)));
+  }
+  for (const auto& var : op_proto.outputs()) {
+    arg_list_.emplace_back(
+        std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, out_format, OUT)));
+  }
+}
+
+void GradOpCreator::PushArgIntoGradOp(const OpInOutArg* arg,
+                                      std::vector<std::string>& in_out,
+                                      std::vector<int>& format,
+                                      VarIndexMap* varmap, int& idx,
+                                      bool is_grad) const {
+  std::string var_name = arg->proto_name_;
+  if (is_grad) {
+    var_name += OperatorBase::GRAD_VAR_SUFFIX();
+  }
+  (*varmap)[var_name] = idx++;
+  size_t pre_sz = in_out.size();
+  auto base_it =
+      arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
+  std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
+            std::back_inserter(in_out));
+  if (is_grad) {
+    for (size_t i = pre_sz; i < in_out.size(); ++i) {
+      in_out[i] += OperatorBase::GRAD_VAR_SUFFIX();
+    }
+  }
+  format.push_back(in_out.size());
+}
+
+void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
+  grad_op->type_ = op_->type_ + "@GRAD";  // not necessary
+  grad_op->attrs_ = op_->attrs_;
+  grad_op->attrs_.erase("input_format");
+  grad_op->attrs_.erase("output_format");
+  VarIndexMap* grad_varmap = new VarIndexMap();
+  int in_idx = 0;
+  int out_idx = 0;
+  std::vector<int> in_format({0});
+  std::vector<int> out_format({0});
+  for (const auto& arg : arg_list_) {
+    // op_'s inputs_ and outputs_
+    if (arg->needed_in_grad_) {
+      PushArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
+                        in_idx, false);
+    }
+    if (arg->type_ == IN) {
+      // gradients of op_'s inputs_
+      PushArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
+                        out_idx, true);
+    } else {
+      // gradients of op_'s outputs_
+      PushArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
+                        in_idx, true);
+    }
+  }
+  grad_op->attrs_["input_format"] = in_format;
+  grad_op->attrs_["output_format"] = out_format;
+  grad_op->in_out_idxs_.reset(grad_varmap);
+}
+
+}  // namespace framework
+}  // namespace paddle
\ No newline at end of file
diff --git a/paddle/framework/grad_op_creator.h b/paddle/framework/grad_op_creator.h
new file mode 100644
index 0000000000000000000000000000000000000000..441aae4979476e59267cecb842e23f6fece3e88f
--- /dev/null
+++ b/paddle/framework/grad_op_creator.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include "paddle/framework/op_proto.pb.h"
+#include "paddle/framework/operator.h"
+
+namespace paddle {
+namespace framework {
+class OpRegistry;
+
+class GradOpCreator {
+ public:
+  GradOpCreator(const OperatorBase* op) : op_(op) {}
+  OperatorBase* Create();
+
+ private:
+  enum InOutType { IN, OUT };
+
+  struct OpInOutArg {
+    OpInOutArg(const std::string& proto_name, const InOutType& type,
+               bool needed_in_grad, size_t begin_idx, size_t end_idx)
+        : proto_name_(proto_name),
+          type_(type),
+          needed_in_grad_(needed_in_grad),
+          begin_idx_(begin_idx),
+          end_idx_(end_idx) {}
+
+    std::string proto_name_;
+    InOutType type_;
+    bool needed_in_grad_;
+    size_t begin_idx_;
+    size_t end_idx_;
+  };
+
+  OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
+                       const std::vector<int>& format, InOutType type);
+  void BuildOpInOutArgList();
+  void PushArgIntoGradOp(const OpInOutArg* arg,
+                         std::vector<std::string>& in_out,
+                         std::vector<int>& format, VarIndexMap* varmap,
+                         int& idx, bool is_grad) const;
+  void CompleteGradOp(OperatorBase* grad_op) const;
+  const OperatorBase* op_;
+  std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
+};
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 4a197102d6e9937c341b8bfdf1afcc863d7ff6d8..fcb529bbac4908a61fa41bffc444aae09c60dad3 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -6,9 +6,8 @@
 #include <unordered_map>
 #include <unordered_set>
 #include "paddle/framework/attr_checker.h"
+#include "paddle/framework/grad_op_creator.h"
 #include "paddle/framework/op_desc.pb.h"
-#include "paddle/framework/op_proto.pb.h"
-#include "paddle/framework/operator.h"
 #include "paddle/framework/scope.h"
 
 namespace paddle {
@@ -286,13 +285,8 @@ class OpRegistry {
   }
 
   static OperatorPtr CreateGradOp(OperatorPtr op) {
-    OperatorPtr grad_op(grad_creators().at(op->type_)());
-    grad_op->type_ = op->type_;
-
-    AssembleGradInOut(op, grad_op);
-    GenerateGradArgOffset(op, grad_op);
-    GenerateGradAttr(op, grad_op);
-
+    GradOpCreator creator(op.get());
+    OperatorPtr grad_op(creator.Create());
     grad_op->Init();
     return grad_op;
   }
@@ -302,13 +296,18 @@ class OpRegistry {
     return protos_;
   };
 
- private:
+  static std::unordered_map<std::string, OpCreator>& grad_creators() {
+    static std::unordered_map<std::string, OpCreator> grad_creators_;
+    return grad_creators_;
+  }
+
   static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
   VarIndexMaps() {
     static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
     return maps_;
   }
 
+ private:
   static std::unordered_map<std::string, OpCreator>& creators() {
     static std::unordered_map<std::string, OpCreator> creators_;
     return creators_;
   }
@@ -319,11 +318,6 @@ class OpRegistry {
     return op_checkers_;
   };
 
-  static std::unordered_map<std::string, OpCreator>& grad_creators() {
-    static std::unordered_map<std::string, OpCreator> grad_creators_;
-    return grad_creators_;
-  }
-
   static void GenerateTempVariableName(OperatorBase* op) {
     static std::atomic<size_t> gUniqId(0UL);
     for (auto& outname : op->outputs_) {
@@ -334,100 +328,6 @@ class OpRegistry {
       }
     }
   }
-
-  static void AssembleGradInOut(OperatorPtr op, OperatorPtr grad_op) {
-    size_t in_sz = op->inputs_.size() + op->outputs_.size() * 2;
-    grad_op->inputs_.reserve(in_sz);
-    size_t out_sz = op->inputs_.size();
-    grad_op->outputs_.reserve(out_sz);
-    // copy op->inputs_ to grad_op->inputs_
-    std::copy(op->inputs_.begin(), op->inputs_.end(),
-              std::back_inserter(grad_op->inputs_));
-    // copy op->outputs_ to grad_op->inputs_
-    std::copy(op->outputs_.begin(), op->outputs_.end(),
-              std::back_inserter(grad_op->inputs_));
-    // add gradients of op->outputs_ to grad_op->inputs_
-    for (const std::string& name : op->outputs_) {
-      grad_op->inputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX());
-    }
-    // add gradients of op->inputs_ to grad_op->outputs_
-    for (const std::string& name : op->inputs_) {
-      grad_op->outputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX());
-    }
-  }
-
-  static void GenerateGradArgOffset(OperatorPtr op, OperatorPtr grad_op) {
-    VarIndexMap* grad_varmap = new VarIndexMap();
-    const OpProto& op_proto = protos()[op->type_];
-    int idx = 0;
-    // offset of op's inputs
-    for (const auto& var : op_proto.inputs()) {
-      (*grad_varmap)[var.name()] = idx++;
-    }
-    // offset of op's outputs
-    for (const auto& var : op_proto.outputs()) {
-      (*grad_varmap)[var.name()] = idx++;
-    }
-    // offset of gradients of op's output
-    for (const auto& var : op_proto.outputs()) {
-      (*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
-    }
-    idx = 0;
-    // offset of gradients of op's input
-    for (const auto& var : op_proto.inputs()) {
-      (*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
-    }
-    grad_op->in_out_idxs_.reset(grad_varmap);
-  }
-
-  static void GenerateGradAttr(OperatorPtr op, OperatorPtr grad_op) {
-    const OpProto& op_proto = protos()[op->type_];
-    grad_op->attrs_ = op->attrs_;
-    grad_op->attrs_.erase("input_format");
-    grad_op->attrs_.erase("output_format");
-    bool has_in_format = op->attrs_.count("input_format");
-    bool has_out_format = op->attrs_.count("output_format");
-    // grad_op's inputs_ contains op's inputs_, outputs_ and gradients of
-    // outputs_. So grad_op's input_format is necessary when op has
-    // either input_format or output_format.
-    if (has_in_format || has_out_format) {
-      std::vector<int> old_in_format;
-      std::vector<int> old_out_format;
-      has_in_format
-          ? old_in_format = op->GetAttr<std::vector<int>>("input_format")
-          : old_in_format = std::vector<int>(op_proto.inputs_size()),
-            std::iota(old_in_format.begin(), old_in_format.end(), 0);
-      has_out_format
-          ? old_out_format = op->GetAttr<std::vector<int>>("output_format")
-          : old_out_format = std::vector<int>(op_proto.outputs_size()),
-            std::iota(old_out_format.begin(), old_out_format.end(), 0);
-
-      std::vector<int> in_format;
-      in_format.reserve(old_in_format.size() + old_out_format.size() * 2);
-      int base = 0;
-      for (const int& idx : old_in_format) {
-        in_format.emplace_back(idx + base);
-      }
-      base += op->inputs_.size();
-      for (const int& idx : old_out_format) {
-        in_format.emplace_back(idx + base);
-      }
-      base += op->outputs_.size();
-      for (const int& idx : old_in_format) {
-        in_format.emplace_back(idx + base);
-      }
-      grad_op->attrs_["input_format"] = in_format;
-      // grad_op's outputs_ contains gradients of op's inputs_. So grad_op's
-      // output_format is necessary only when op has input_format.
-      if (has_in_format) {
-        std::vector<int> out_format;
-        out_format.reserve(op_proto.inputs_size());
-        std::copy(old_in_format.begin(), old_in_format.end(),
-                  std::back_inserter(out_format));
-        grad_op->attrs_["output_format"] = out_format;
-      }
-    }
-  }
 };
 
 template <typename OpType, typename ProtoMakerType>