提交 2a5ecb68 编写于 作者: D dzhwinter

follow comment. test=develop

上级 9f693fca
...@@ -266,11 +266,13 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op, ...@@ -266,11 +266,13 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
VLOG(4) << "Try to inplace op " << op->Name(); VLOG(4) << "Try to inplace op " << op->Name();
PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr, PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
"op_desc is nullptr"); "op_desc is nullptr");
// 4 pre-requirments need to meet if the op want to inplaced. // some pre-requirments need to meet if the op want to inplaced.
// 1. infer_inplace_ is registered.
auto* op_desc = op->Op(); auto* op_desc = op->Op();
auto& infer_inplace = auto& infer_inplace =
OpInfoMap::Instance().Get(op_desc->Type()).infer_inplace_; OpInfoMap::Instance().Get(op_desc->Type()).infer_inplace_;
// 1. infer_inplace_ is registered.
if (!static_cast<bool>(infer_inplace)) return; if (!static_cast<bool>(infer_inplace)) return;
PADDLE_ENFORCE(static_cast<bool>(infer_inplace), PADDLE_ENFORCE(static_cast<bool>(infer_inplace),
"%s's infer_inplace has not been registered", op_desc->Type()); "%s's infer_inplace has not been registered", op_desc->Type());
...@@ -399,7 +401,7 @@ void GraphView::Build(ir::Graph* g) { ...@@ -399,7 +401,7 @@ void GraphView::Build(ir::Graph* g) {
} }
} }
const std::vector<ir::Node*> GraphView::AllOps() { return ops_; } const& std::vector<ir::Node*> GraphView::AllOps() { return ops_; }
bool GraphView::ReusedInPythonMemOpt(const std::string& var) const { bool GraphView::ReusedInPythonMemOpt(const std::string& var) const {
return dup_nodes_.count(var); return dup_nodes_.count(var);
......
...@@ -33,7 +33,7 @@ class GraphView { ...@@ -33,7 +33,7 @@ class GraphView {
void Build(ir::Graph* g); void Build(ir::Graph* g);
const std::vector<ir::Node*> AllOps(); const& std::vector<ir::Node*> AllOps();
ir::Node* GetNodeByName(const std::string& name, ir::Node* GetNodeByName(const std::string& name,
const std::vector<ir::Node*>& nodes) const; const std::vector<ir::Node*>& nodes) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册