Commit 18afb77e authored by dzhwinter

polish code for reading. test=develop

Parent 684b5723
@@ -128,7 +128,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
       }
     }
   }
-  graph->ResolveHazard(var_nodes_);
+  // graph->ResolveHazard(var_nodes_);
   return graph;
 }
@@ -324,6 +324,32 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
   }
 }
+
+void MemoryOptimizePass::ClearControlDepVars(ir::Graph* graph) const {
+  for (auto& op : graph->Nodes()) {
+    if (!op->IsOp()) continue;
+    {
+      auto& nodes = op->inputs;
+      nodes.erase(
+          std::remove_if(nodes.begin(), nodes.end(),
+                         [&](ir::Node* var) { return var->IsCtrlVar(); }),
+          nodes.end());
+    }
+    {
+      auto& nodes = op->outputs;
+      nodes.erase(
+          std::remove_if(nodes.begin(), nodes.end(),
+                         [&](ir::Node* var) { return var->IsCtrlVar(); }),
+          nodes.end());
+    }
+  }
+
+  for (auto& node : graph->Nodes()) {
+    if (node->IsCtrlVar()) {
+      graph->RemoveNode(node);
+    }
+  }
+}
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
...
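Note on the new ClearControlDepVars above: it strips control-dependency variables in two steps: first it unlinks them from each op node's inputs and outputs with the erase-remove idiom, then it deletes the ctrl-var nodes from the graph. Below is a minimal, self-contained sketch of that same two-step pattern; MiniNode and MiniGraph are illustrative stand-ins I invented, not Paddle's ir::Node/ir::Graph API.

// Self-contained sketch of the two-step cleanup in ClearControlDepVars.
// MiniNode and MiniGraph are hypothetical stand-ins, not Paddle's API.
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct MiniNode {
  std::string name;
  bool is_op = false;
  bool is_ctrl_var = false;
  std::vector<MiniNode*> inputs;   // edge lists; meaningful for op nodes
  std::vector<MiniNode*> outputs;
};

struct MiniGraph {
  std::vector<std::unique_ptr<MiniNode>> nodes;
  MiniNode* Add(std::string name, bool is_op, bool is_ctrl = false) {
    auto node = std::make_unique<MiniNode>();
    node->name = std::move(name);
    node->is_op = is_op;
    node->is_ctrl_var = is_ctrl;
    nodes.push_back(std::move(node));
    return nodes.back().get();
  }
};

void ClearControlDepVars(MiniGraph* graph) {
  // Step 1: erase-remove ctrl vars from every op's edge lists, mirroring
  // the two blocks over op->inputs and op->outputs in the diff above.
  for (auto& node : graph->nodes) {
    if (!node->is_op) continue;
    auto strip = [](std::vector<MiniNode*>* vars) {
      vars->erase(std::remove_if(vars->begin(), vars->end(),
                                 [](MiniNode* v) { return v->is_ctrl_var; }),
                  vars->end());
    };
    strip(&node->inputs);
    strip(&node->outputs);
  }
  // Step 2: drop the ctrl-var nodes themselves. A single erase-remove
  // avoids erasing from the container while ranging over it.
  graph->nodes.erase(
      std::remove_if(
          graph->nodes.begin(), graph->nodes.end(),
          [](const std::unique_ptr<MiniNode>& n) { return n->is_ctrl_var; }),
      graph->nodes.end());
}

int main() {
  MiniGraph g;
  MiniNode* dep = g.Add("ctrl_dep_0", /*is_op=*/false, /*is_ctrl=*/true);
  MiniNode* x = g.Add("x", /*is_op=*/false);
  MiniNode* op = g.Add("elementwise_add", /*is_op=*/true);
  op->inputs = {x, dep};
  ClearControlDepVars(&g);
  std::cout << "op inputs after cleanup: " << op->inputs.size() << '\n';  // 1
  std::cout << "graph nodes after cleanup: " << g.nodes.size() << '\n';   // 2
}

One detail worth flagging: the Paddle version's second loop calls graph->RemoveNode() while ranging over graph->Nodes(); whether that is safe depends on the iteration guarantees of ir::Graph, which is why this sketch collects the removals into a single erase-remove instead.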
@@ -48,6 +48,7 @@ class MemoryOptimizePass : public ir::Pass {
   void RenameVarInGraphNode(const std::string& var,
                             const std::string& cache_var, size_t idx,
                             ir::Graph* graph) const;
+  void ClearControlDepVars(ir::Graph* graph) const;
   void SubGraphOptimize(OpDesc* op_desc) const;
   // 1. scan op with subblock and collect the output/input vars.
...
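The header hunk adds only the declaration; the call site inside ApplyImpl sits in a region this diff collapses. Here is a toy, runnable sketch of the declare-private-helper, call-from-ApplyImpl pattern, with a stand-in Graph type; the placement of the call before the reuse analysis is my assumption, not something this diff shows.

#include <iostream>
#include <memory>

struct Graph { int ctrl_var_count = 3; };  // toy stand-in for ir::Graph

class MemoryOptimizePassSketch {
 public:
  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
    // Assumption: ctrl deps are stripped before the reuse analysis runs.
    ClearControlDepVars(graph.get());
    // ... memory-reuse analysis would run here ...
    return graph;
  }

 private:
  // Mirrors the private helper declared in the header hunk above.
  void ClearControlDepVars(Graph* graph) const { graph->ctrl_var_count = 0; }
};

int main() {
  MemoryOptimizePassSketch pass;
  auto g = pass.ApplyImpl(std::make_unique<Graph>());
  std::cout << "ctrl vars left: " << g->ctrl_var_count << '\n';  // 0
}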
@@ -121,6 +121,8 @@ class TestMNIST(TestParallelExecutorBase):
                 regularization=fluid.regularizer.L2Decay(1e-6))
             return optimizer

+        # NOTE(dzh):
+        # need to make it compatible with elewise fuse act
         not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
             model,
             feed_dict={"image": img,
@@ -128,6 +130,7 @@ class TestMNIST(TestParallelExecutorBase):
             use_cuda=use_cuda,
             fuse_elewise_add_act_ops=False,
             memory_opt=False,
+            use_ir_memory_optimize=False,
             optimizer=_optimizer)
         fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
             model,
@@ -136,6 +139,7 @@ class TestMNIST(TestParallelExecutorBase):
             use_cuda=use_cuda,
             fuse_elewise_add_act_ops=True,
             memory_opt=False,
+            use_ir_memory_optimize=False,
             optimizer=_optimizer)
         for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
...