提交 ebbbaee0 编写于 作者: Y Yu Yang

Follow comments

上级 2594a502
......@@ -39,9 +39,9 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
std::transform(grad_descs.begin(), grad_descs.end(),
std::back_inserter(grad_ops),
[](const std::unique_ptr<OpDescBind>& grad_desc) {
return OpRegistry::CreateOp(grad_desc.get());
return OpRegistry::CreateOp(*grad_desc);
});
PADDLE_ENFORCE_GT(grad_ops.size(), 0);
PADDLE_ENFORCE(!grad_ops.empty());
if (grad_ops.size() == 1) {
return std::move(grad_ops[0]);
} else {
......
......@@ -54,9 +54,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(OpDescBind* op_desc) {
return CreateOp(op_desc->Type(), op_desc->Inputs(), op_desc->Outputs(),
op_desc->GetAttrMap());
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(),
op_desc.GetAttrMap());
}
} // namespace framework
......
......@@ -79,7 +79,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(OpDescBind* op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
};
template <typename OpType, typename ProtoMakerType, typename GradOpType>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册