提交 18afb77e 编写于 作者: D dzhwinter

polish code for reading. test=develop

上级 684b5723
......@@ -128,7 +128,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
}
}
}
graph->ResolveHazard(var_nodes_);
// graph->ResolveHazard(var_nodes_);
return graph;
}
......@@ -324,6 +324,32 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
}
}
void MemoryOptimizePass::ClearControlDepVars(ir::Graph* graph) const {
for (auto& op : graph->Nodes()) {
if (!op->IsOp()) continue;
{
auto& nodes = op->inputs;
nodes.erase(
std::remove_if(nodes.begin(), nodes.end(),
[&](ir::Node* var) { return var->IsCtrlVar(); }),
nodes.end());
}
{
auto& nodes = op->outputs;
nodes.erase(
std::remove_if(nodes.begin(), nodes.end(),
[&](ir::Node* var) { return var->IsCtrlVar(); }),
nodes.end());
}
}
for (auto& node : graph->Nodes()) {
if (node->IsCtrlVar()) {
graph->RemoveNode(node);
}
}
}
} // namespace details
} // namespace framework
} // namespace paddle
......
......@@ -48,6 +48,7 @@ class MemoryOptimizePass : public ir::Pass {
void RenameVarInGraphNode(const std::string& var,
const std::string& cache_var, size_t idx,
ir::Graph* graph) const;
void ClearControlDepVars(ir::Graph* graph) const;
void SubGraphOptimize(OpDesc* op_desc) const;
// 1. scan op with subblock and collect the output/input vars.
......
......@@ -121,6 +121,8 @@ class TestMNIST(TestParallelExecutorBase):
regularization=fluid.regularizer.L2Decay(1e-6))
return optimizer
# NOTE(dzh):
# need to make it compatible with elewise fuse act
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
......@@ -128,6 +130,7 @@ class TestMNIST(TestParallelExecutorBase):
use_cuda=use_cuda,
fuse_elewise_add_act_ops=False,
memory_opt=False,
use_ir_memory_optimize=False,
optimizer=_optimizer)
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
model,
......@@ -136,6 +139,7 @@ class TestMNIST(TestParallelExecutorBase):
use_cuda=use_cuda,
fuse_elewise_add_act_ops=True,
memory_opt=False,
use_ir_memory_optimize=False,
optimizer=_optimizer)
for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册