提交 d376cf71 编写于 作者: D dzhwinter

polish code for reading. test=develop

上级 283573c6
...@@ -240,7 +240,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -240,7 +240,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
continue; continue;
} }
} }
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
VLOG(3) << "Finish Apply Pass " << pass->Type();
} }
return graph; return graph;
} }
......
...@@ -268,10 +268,15 @@ bool OrderedSet::Has(ir::Node* var) const { ...@@ -268,10 +268,15 @@ bool OrderedSet::Has(ir::Node* var) const {
return false; return false;
} }
void OrderedSet::Erase(const std::string& var) {
PADDLE_ENFORCE(mark_table_.count(var));
nodes_.erase(mark_table_[var]);
mark_table_.erase(var);
}
void OrderedSet::Erase(ir::Node* var) { void OrderedSet::Erase(ir::Node* var) {
PADDLE_ENFORCE(mark_table_.count(var->Name())); PADDLE_ENFORCE(var != nullptr);
nodes_.erase(mark_table_[var->Name()]); Erase(var->Name());
mark_table_.erase(var->Name());
} }
std::string OrderedSet::ToString() const { std::string OrderedSet::ToString() const {
...@@ -509,7 +514,9 @@ ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name, ...@@ -509,7 +514,9 @@ ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name,
for (auto* node : ops_) { for (auto* node : ops_) {
if (node == op) break; if (node == op) break;
for (auto& output : node->outputs) { for (auto& output : node->outputs) {
if (output->Name() == name) { PADDLE_ENFORCE((output != nullptr && output->IsVar()),
"Output is empty!");
if (output->Var() && output->Name() == name) {
found_node = output; found_node = output;
} }
} }
......
...@@ -55,6 +55,7 @@ class OrderedSet { ...@@ -55,6 +55,7 @@ class OrderedSet {
void Insert(ir::Node* var); void Insert(ir::Node* var);
void Erase(ir::Node* var); void Erase(ir::Node* var);
void Erase(const std::string& var);
bool Has(ir::Node* var) const; bool Has(ir::Node* var) const;
void Clear() { void Clear() {
mark_table_.clear(); mark_table_.clear();
......
...@@ -107,6 +107,52 @@ TEST(OrderedSet, Normal) { ...@@ -107,6 +107,52 @@ TEST(OrderedSet, Normal) {
ASSERT_EQ(pool.GetNodeIndexInPool(cache), 5); // match 4:[5,2] ASSERT_EQ(pool.GetNodeIndexInPool(cache), 5); // match 4:[5,2]
} }
} }
TEST(OrderedSet, FindBestFitNode) {
OrderedSet pool;
std::vector<std::unique_ptr<ir::Node>> nodes;
ProgramDesc prog;
BlockDesc* block_desc = prog.MutableBlock(0);
auto* op_desc = block_desc->AppendOp();
op_desc->SetType("dummy");
std::unique_ptr<ir::Node> op = ir::CreateNodeForTest(op_desc);
{
auto desc = block_desc->Var("a");
desc->SetShape({128, 128});
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
{
auto desc = block_desc->Var("b");
desc->SetShape({128, 129});
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
{
auto desc = block_desc->Var("c");
desc->SetShape({128, 128});
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
for (auto& node : nodes) {
pool.Insert(node.get());
}
// FindNextBestFitNode
auto* n = nodes[0].get();
auto* cache = pool.FindBestFitNode(n);
PADDLE_ENFORCE(cache->Name() == "a");
cache = pool.FindNextBestFitNode(n, cache);
PADDLE_ENFORCE(cache->Name() == "c");
cache = pool.FindNextBestFitNode(n, cache);
PADDLE_ENFORCE(cache->Name() == "b");
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -69,7 +69,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -69,7 +69,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
} }
for (auto& var : op->outputs) { for (auto& var : op->outputs) {
if (skip_set_.count(var->Name())) { if (var->IsVar() && !var->IsCtrlVar() && skip_set_.count(var->Name())) {
VLOG(3) << "Skip set contains variable of " << var->Name() VLOG(3) << "Skip set contains variable of " << var->Name()
<< "disable reuse on it. skipped"; << "disable reuse on it. skipped";
continue; continue;
...@@ -77,8 +77,8 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -77,8 +77,8 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
if (NodeCanReused(var) && cfg_->Use(op).count(var->Name()) == 0) { if (NodeCanReused(var) && cfg_->Use(op).count(var->Name()) == 0) {
ir::Node* cache = pool_.FindBestFitNode(var); ir::Node* cache = pool_.FindBestFitNode(var);
while (cache != nullptr && var->Name() == cache->Name()) { while (cache != nullptr && var->Name() == cache->Name()) {
VLOG(3) << "The same cache variable is cascade reused." << var->Name() VLOG(3) << "The same cache variable is cascade reused. "
<< " is re-filled to the pool after" << var->Name() << " is re-filled to the pool after"
<< "the reused op is finished. Current op can not " << "the reused op is finished. Current op can not "
<< "replace it again. Skip this candidate."; << "replace it again. Skip this candidate.";
cache = pool_.FindNextBestFitNode(var, cache); cache = pool_.FindNextBestFitNode(var, cache);
...@@ -107,11 +107,13 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -107,11 +107,13 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
// //
// CFG Graph store the liveness information, when reuse happens // CFG Graph store the liveness information, when reuse happens
// we also need to update the variable liveness. // we also need to update the variable liveness.
cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx); const std::string var_name = var->Name();
RenameVarInGraphDesc(var->Name(), cache->Name(), idx); const std::string cache_name = cache->Name();
RenameVarInGraphNode(var->Name(), cache->Name(), idx, graph.get());
pool_.Erase(cache); cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
RenameVarInGraphDesc(var_name, cache_name, idx);
RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
pool_.Erase(cache_name);
} }
} }
} }
...@@ -119,7 +121,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -119,7 +121,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
for (auto var : cfg_->LiveIn(op)) { for (auto var : cfg_->LiveIn(op)) {
if (cfg_->LiveOut(op).count(var) == 0) { if (cfg_->LiveOut(op).count(var) == 0) {
ir::Node* var_node = cfg_->GetNodeByName(var, op); ir::Node* var_node = cfg_->GetNodeByName(var, op);
if (var_node == nullptr) continue; if (var_node == nullptr || var_node->IsCtrlVar()) continue;
if (NodeCanReused(var_node) && !pool_.Has(var_node)) { if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
pool_.Insert(var_node); pool_.Insert(var_node);
} }
...@@ -275,8 +277,7 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, ...@@ -275,8 +277,7 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
// redirect the input to the latest version of cache_var // redirect the input to the latest version of cache_var
for (auto* node : op->inputs) { for (auto* node : op->inputs) {
if (node->Name() == var) { if (node->Name() == var) {
ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); ir::Node* cache_node = var_nodes_[cache_var].back();
var_nodes_[cache_var].emplace_back(cache_node);
// swap node to cache_node // swap node to cache_node
cache_node->outputs.insert(cache_node->outputs.end(), cache_node->outputs.insert(cache_node->outputs.end(),
...@@ -285,11 +286,15 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, ...@@ -285,11 +286,15 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
auto* prev_op = node->inputs[0]; auto* prev_op = node->inputs[0];
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node, std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
cache_node); cache_node);
cache_node->inputs.emplace_back(prev_op);
for (auto* next_op : node->outputs) { for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node); cache_node);
} }
// erase unused node
auto& nodes = var_nodes_.at(var);
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
graph->RemoveNode(node);
} }
} }
...@@ -309,15 +314,14 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, ...@@ -309,15 +314,14 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node); cache_node);
} }
}
}
}
// release node of unused var in graph // erase unused node
for (auto* node : var_nodes_[var]) { auto& nodes = var_nodes_.at(var);
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
graph->RemoveNode(node); graph->RemoveNode(node);
} }
var_nodes_.at(var).clear(); }
}
} }
} // namespace details } // namespace details
......
...@@ -79,7 +79,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -79,7 +79,7 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.memory_optimize = use_ir_memory_optimize build_strategy.memory_optimize = False if memory_opt else use_ir_memory_optimize
# python memory optimization is conflict with inplace pass. # python memory optimization is conflict with inplace pass.
# Use ir graph memory optimization after inplace pass is the correct way. # Use ir graph memory optimization after inplace pass is the correct way.
build_strategy.enable_inplace = False if memory_opt else enable_inplace build_strategy.enable_inplace = False if memory_opt else enable_inplace
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
os.environ['FLAGS_fast_eager_deletion_mode'] = True
os.environ[
'RECORDIO_FILENAME'] = '/tmp/ir_memory_optimize_transformer.wmt16.recordio'
from test_parallel_executor_transformer import TestTransformer
# NOTE(dzhwinter): test diferent strategy colisions.
# open the eager delete tensor strategy by default.
class TestTransformerWithIR(TestTransformer):
def test_main(self):
if core.is_compiled_with_cuda():
# check python transpiler
self.check_network_convergence(
transformer,
use_cuda=True,
memory_opt=True,
use_ir_memory_optimize=False)
# check IR memory optimize
self.check_network_convergence(
transformer,
use_cuda=True,
memory_opt=False,
use_ir_memory_optimize=True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册