diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 7319fcc88cfd11f125266cc571501a1065a349d8..048864c7004872df7fe6336675b4e3012f41709a 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -13,22 +13,22 @@ express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/grad_op_builder.h" -#include "paddle/framework/framework.pb.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace framework { enum class OpArgType { IN, OUT }; -static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, - const OpArgType& src_type, const OpArgType& dst_type, - bool is_grad) { +using VarNameMap = OperatorBase::VarNameMap; + +static VarNameMap TransOpArg(const OperatorBase* src_op, + const OpArgType& src_type, + const OpArgType& dst_type, bool is_grad) { const auto& src_inout = - src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; - auto& dst_inout = - dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_; + src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); + VarNameMap dst_inout; - const OpProto& proto = OpProtos().at(src_op->type_); + const OpProto& proto = OpProtos().at(src_op->Type()); const auto& src_arg_list = src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); for (const auto& arg : src_arg_list) { @@ -41,17 +41,23 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, dst_inout[dst_name].emplace_back(s); } } + return dst_inout; } OperatorBase* BuildGradOp(const OperatorBase* op) { - std::string grad_op_type = OpRegistry::grad_ops().at(op->type_); - OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); - grad_op->type_ = grad_op_type; - grad_op->attrs_ = op->attrs_; - TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, false); // I - TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, false); // O - TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, true); // OG - TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, true); // IG + std::string grad_op_type = OpRegistry::grad_ops().at(op->Type()); + auto I = TransOpArg(op, OpArgType::IN, OpArgType::IN, false); // I + auto O = TransOpArg(op, OpArgType::OUT, OpArgType::IN, false); // O + auto OG = TransOpArg(op, OpArgType::OUT, OpArgType::IN, true); // OG + auto IG = TransOpArg(op, OpArgType::IN, OpArgType::OUT, true); // IG + // TODO(merge I/O/OG) + VarNameMap GradIn; + GradIn.insert(I.begin(), I.end()); + GradIn.insert(O.begin(), O.end()); + GradIn.insert(OG.begin(), OG.end()); + + OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)( + grad_op_type, GradIn, IG, op->Attrs()); return grad_op; } diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 55cf7fbe31f86306a5456465b4232bedf525499a..ffd48160b8af385961440873c8e9b525e1f5618b 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -128,7 +128,11 @@ class OpRegistry { public: template static void RegisterOp(const std::string& op_type) { - op_creators()[op_type] = [] { return new OpType; }; + op_creators()[op_type] = []( + const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const AttributeMap& attrs) { + return new OpType(type, inputs, outputs, attrs); + }; OpAttrChecker& op_checker = op_checkers()[op_type]; OpProto& op_proto = OpProtos()[op_type]; auto maker = ProtoMakerType(&op_proto, &op_checker); @@ -143,7 +147,11 @@ class OpRegistry { template static void RegisterGradOp(const std::string& op_type, const std::string& grad_op_type) { - op_creators()[grad_op_type] = [] { return new GradOpType; }; + op_creators()[grad_op_type] = []( + const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const AttributeMap& attrs) { + return new GradOpType(type, inputs, outputs, attrs); + }; grad_ops()[op_type] = grad_op_type; }