提交 6600e239 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #4644 from reyoung/feature/backward_default_values

Fix bug of foward default attribute not passed to backward
......@@ -302,7 +302,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs; // empty vector
}
grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc);
grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get());
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) {
......
......@@ -97,6 +97,11 @@ class OpDescBind {
const VariableNameMap &Outputs() const { return outputs_; }
AttributeMap *MutableAttrMap() {
this->need_update_ = true;
return &this->attrs_;
}
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
......@@ -60,9 +60,14 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
}
std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs(
const OpDescBind& op_desc) {
auto& info = OpInfoMap::Instance().Get(op_desc.Type());
return info.grad_op_maker_(op_desc);
OpDescBind* op_desc) {
auto& info = OpInfoMap::Instance().Get(op_desc->Type());
if (info.Checker() != nullptr) {
info.Checker()->Check(*op_desc->MutableAttrMap());
}
return info.grad_op_maker_(*op_desc);
}
} // namespace framework
......
......@@ -80,7 +80,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs(
const OpDescBind& op_desc);
OpDescBind* op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册