提交 6600e239 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #4644 from reyoung/feature/backward_default_values

Fix bug of foward default attribute not passed to backward
...@@ -302,7 +302,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad( ...@@ -302,7 +302,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs; // empty vector return grad_op_descs; // empty vector
} }
grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc); grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get());
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops; std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) { for (auto& desc : grad_op_descs) {
......
...@@ -97,6 +97,11 @@ class OpDescBind { ...@@ -97,6 +97,11 @@ class OpDescBind {
const VariableNameMap &Outputs() const { return outputs_; } const VariableNameMap &Outputs() const { return outputs_; }
AttributeMap *MutableAttrMap() {
this->need_update_ = true;
return &this->attrs_;
}
private: private:
template <typename MapType> template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) { static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
...@@ -60,9 +60,14 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) { ...@@ -60,9 +60,14 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
} }
std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs( std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs(
const OpDescBind& op_desc) { OpDescBind* op_desc) {
auto& info = OpInfoMap::Instance().Get(op_desc.Type()); auto& info = OpInfoMap::Instance().Get(op_desc->Type());
return info.grad_op_maker_(op_desc);
if (info.Checker() != nullptr) {
info.Checker()->Check(*op_desc->MutableAttrMap());
}
return info.grad_op_maker_(*op_desc);
} }
} // namespace framework } // namespace framework
......
...@@ -80,7 +80,7 @@ class OpRegistry { ...@@ -80,7 +80,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs( static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs(
const OpDescBind& op_desc); OpDescBind* op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册