提交 77150f1f 编写于 作者: Y Yu Yang

Merge branch 'develop' of github.com:baidu/Paddle into feature/add_persistable_in_var_desc

......@@ -302,7 +302,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs; // empty vector
}
grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc);
grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get());
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) {
......
......@@ -97,6 +97,11 @@ class OpDescBind {
const VariableNameMap &Outputs() const { return outputs_; }
AttributeMap *MutableAttrMap() {
this->need_update_ = true;
return &this->attrs_;
}
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
......@@ -60,9 +60,14 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
}
std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs(
const OpDescBind& op_desc) {
auto& info = OpInfoMap::Instance().Get(op_desc.Type());
return info.grad_op_maker_(op_desc);
OpDescBind* op_desc) {
auto& info = OpInfoMap::Instance().Get(op_desc->Type());
if (info.Checker() != nullptr) {
info.Checker()->Check(*op_desc->MutableAttrMap());
}
return info.grad_op_maker_(*op_desc);
}
} // namespace framework
......
......@@ -80,7 +80,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs(
const OpDescBind& op_desc);
OpDescBind* op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
};
......
......@@ -22,7 +22,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册