提交 77150f1f 编写于 作者: Y Yu Yang

Merge branch 'develop' of github.com:baidu/Paddle into feature/add_persistable_in_var_desc

...@@ -302,7 +302,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad( ...@@ -302,7 +302,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs; // empty vector return grad_op_descs; // empty vector
} }
grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc); grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get());
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops; std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) { for (auto& desc : grad_op_descs) {
......
...@@ -97,6 +97,11 @@ class OpDescBind { ...@@ -97,6 +97,11 @@ class OpDescBind {
const VariableNameMap &Outputs() const { return outputs_; } const VariableNameMap &Outputs() const { return outputs_; }
AttributeMap *MutableAttrMap() {
this->need_update_ = true;
return &this->attrs_;
}
private: private:
template <typename MapType> template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) { static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
...@@ -60,9 +60,14 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) { ...@@ -60,9 +60,14 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
} }
std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs( std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs(
const OpDescBind& op_desc) { OpDescBind* op_desc) {
auto& info = OpInfoMap::Instance().Get(op_desc.Type()); auto& info = OpInfoMap::Instance().Get(op_desc->Type());
return info.grad_op_maker_(op_desc);
if (info.Checker() != nullptr) {
info.Checker()->Check(*op_desc->MutableAttrMap());
}
return info.grad_op_maker_(*op_desc);
} }
} // namespace framework } // namespace framework
......
...@@ -80,7 +80,7 @@ class OpRegistry { ...@@ -80,7 +80,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs( static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs(
const OpDescBind& op_desc); OpDescBind* op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
}; };
......
...@@ -22,7 +22,7 @@ class AdamaxOp : public framework::OperatorWithKernel { ...@@ -22,7 +22,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContextBase *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"), PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdamaxOp should not be null."); "Input(Param) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"), PADDLE_ENFORCE(ctx->HasInput("Grad"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册