From c464ec21d8b0a1e7ad6da7115b78cd047d9a2041 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 9 Oct 2017 12:09:39 -0700 Subject: [PATCH] Fix bug of foward default attribute not passed to backward --- paddle/framework/backward.cc | 2 +- paddle/framework/op_desc.h | 5 +++++ paddle/framework/op_registry.cc | 11 ++++++++--- paddle/framework/op_registry.h | 2 +- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index c970e01dd1..0a4688db9c 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -302,7 +302,7 @@ std::vector> MakeOpGrad( return grad_op_descs; // empty vector } - grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc); + grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get()); std::list> pending_fill_zeros_ops; for (auto& desc : grad_op_descs) { diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index b39808dad1..b729029412 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -97,6 +97,11 @@ class OpDescBind { const VariableNameMap &Outputs() const { return outputs_; } + AttributeMap *MutableAttrMap() { + this->need_update_ = true; + return &this->attrs_; + } + private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 66043f6e04..b118edae17 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -60,9 +60,14 @@ std::unique_ptr OpRegistry::CreateOp(const OpDescBind& op_desc) { } std::vector> OpRegistry::CreateGradOpDescs( - const OpDescBind& op_desc) { - auto& info = OpInfoMap::Instance().Get(op_desc.Type()); - return info.grad_op_maker_(op_desc); + OpDescBind* op_desc) { + auto& info = OpInfoMap::Instance().Get(op_desc->Type()); + + if (info.Checker() != nullptr) { + info.Checker()->Check(*op_desc->MutableAttrMap()); + } + + return info.grad_op_maker_(*op_desc); } } // namespace framework diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index cce3605fd4..5ca3af52a6 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -80,7 +80,7 @@ class OpRegistry { static std::unique_ptr CreateOp(const OpDesc& op_desc); static std::vector> CreateGradOpDescs( - const OpDescBind& op_desc); + OpDescBind* op_desc); static std::unique_ptr CreateOp(const OpDescBind& op_desc); }; -- GitLab