From 8156fedf5676c7886709bf7aaf1a4597e7cdd369 Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Tue, 29 Jan 2019 16:49:07 +0800
Subject: [PATCH] merge develop branch. test=develop

---
 .../framework/details/inplace_op_pass.cc      | 133 +++++-------------
 .../fluid/framework/details/inplace_op_pass.h |  18 ++-
 .../unittests/parallel_executor_test_base.py  |   2 +-
 .../tests/unittests/test_ir_inplace_pass.py   |   7 -
 4 files changed, 46 insertions(+), 114 deletions(-)

diff --git a/paddle/fluid/framework/details/inplace_op_pass.cc b/paddle/fluid/framework/details/inplace_op_pass.cc
index d8a6be8573..208c353093 100644
--- a/paddle/fluid/framework/details/inplace_op_pass.cc
+++ b/paddle/fluid/framework/details/inplace_op_pass.cc
@@ -199,15 +199,17 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
   }
 }
 
-const SSANodeVector InplacePass::TryInplaceModifyVar(
-    const std::string& var, const std::string& cache_var, const size_t& idx,
-    ir::Graph* graph) const {
+const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
+                                                   const std::string& cache_var,
+                                                   const size_t& idx,
+                                                   ir::Graph* graph) const {
   PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
                  var_nodes_[var].at(0)->Var() != nullptr);
   std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
   var_desc->SetName(cache_var);
 
-  SSANodeVector swap_nodes;
+  SSANodePair swap_nodes;
+
   for (size_t i = idx; i < view_.AllOps().size(); ++i) {
     auto* op = view_.AllOps()[i];
 
@@ -215,6 +217,7 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
     for (auto* node : op->inputs) {
       if (node->Name() == var) {
         ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
+
         // swap node to cache_node
         cache_node->outputs.insert(cache_node->outputs.end(),
                                    node->outputs.begin(), node->outputs.end());
@@ -228,13 +231,15 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
                        cache_node);
         }
 
-        swap_nodes[node].emplace_back(cache_node);
+        swap_nodes.emplace_back(std::make_pair(node, cache_node));
       }
     }
+
+    // if we need to rename the output,
+    // always create a newer version of cache_var
     for (auto* node : op->outputs) {
       if (node->Name() == var) {
         ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
-        var_nodes_[cache_var].emplace_back(cache_node);
         // swap node to cache node
         cache_node->outputs.insert(cache_node->outputs.end(),
                                    node->outputs.begin(), node->outputs.end());
@@ -244,108 +249,43 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
           std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                        cache_node);
         }
-        swap_nodes[node].emplace_back(cache_node);
+
+        swap_nodes.emplace_back(std::make_pair(node, cache_node));
       }
     }
   }
+
   return swap_nodes;
 }
 
-void InplacePass::CommitModify(const SSANodeVector& swap_nodes,
+void InplacePass::CommitModify(const SSANodePair& swap_nodes,
                                ir::Graph* graph) const {
   for (auto& pair : swap_nodes) {
-    auto* node = pair.first;
-    const std::string var = node->Name();
-    for (auto* cache_node : pair.second) {
-      const std::string cache_var = cache_node->Name();
-      var_nodes_[cache_var].emplace_back(cache_node);
-    }
+    auto *node = pair.first, *cache_node = pair.second;
+    const std::string var = node->Name(), cache_var = cache_node->Name();
+    var_nodes_[cache_var].emplace_back(cache_node);
+
+    graph->RemoveNode(node);
     auto& nodes = var_nodes_.at(var);
+    // release the unused var in the graph. Because the python side memory
+    // optimization may reuse vars of the same name, we only clear the var
+    // nodes after the current inplaced index.
     nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
-    graph->RemoveNode(node);
   }
 }
 
-void InplacePass::WithDrawModify(const SSANodeVector& nodes,
+void InplacePass::WithdrawModify(const SSANodePair& nodes,
                                  ir::Graph* graph) const {
   for (auto& pair : nodes) {
-    auto* node = pair.first;
-    const std::string var = node->Name();
-    for (auto* cache_node : pair.second) {
-      const std::string cache_var = cache_node->Name();
-      auto* prev_op = node->inputs[0];
-      std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
+    auto *node = pair.first, *cache_node = pair.second;
+    const std::string var = node->Name(), cache_var = cache_node->Name();
+    auto* prev_op = node->inputs[0];
+    std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
+                 node);
+    for (auto* next_op : node->outputs) {
+      std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
                    node);
-      for (auto* next_op : node->outputs) {
-        std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
-                     node);
-      }
-      graph->RemoveNode(cache_node);
-    }
-  }
-}
-
-void InplacePass::InplaceModifyVar(const std::string& var,
-                                   const std::string& cache_var,
-                                   const size_t& idx, ir::Graph* graph) const {
-  PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
-                 var_nodes_[var].at(0)->Var() != nullptr);
-  std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
-  var_desc->SetName(cache_var);
-
-  for (size_t i = idx; i < view_.AllOps().size(); ++i) {
-    auto* op = view_.AllOps()[i];
-
-    // redirect the input to the latest version of cache_var
-    for (auto* node : op->inputs) {
-      if (node->Name() == var) {
-        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
-        var_nodes_[cache_var].emplace_back(cache_node);
-
-        // swap node to cache_node
-        cache_node->outputs.insert(cache_node->outputs.end(),
-                                   node->outputs.begin(), node->outputs.end());
-        PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
-        auto* prev_op = node->inputs[0];
-        std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
-                     cache_node);
-        cache_node->inputs.emplace_back(prev_op);
-        for (auto* next_op : node->outputs) {
-          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
-                       cache_node);
-        }
-
-        // release unused var in graph. Because python side memory optimize
-        // may reused the var in same name, so we only clear the var node
-        // after current inplaced index.
-        graph->RemoveNode(node);
-        auto& nodes = var_nodes_.at(var);
-        nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
-      }
-    }
-
-    // if we need to rename the output,
-    // always create a newer version of cache_var
-    for (auto* node : op->outputs) {
-      if (node->Name() == var) {
-        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
-        var_nodes_[cache_var].emplace_back(cache_node);
-        // swap node to cache node
-        cache_node->outputs.insert(cache_node->outputs.end(),
-                                   node->outputs.begin(), node->outputs.end());
-        cache_node->inputs.emplace_back(op);
-        std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
-        for (auto* next_op : node->outputs) {
-          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
-                       cache_node);
-        }
-
-        // release unsed var in graph
-        graph->RemoveNode(node);
-        auto& nodes = var_nodes_.at(var);
-        nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
-      }
     }
+    graph->RemoveNode(cache_node);
   }
 }
@@ -413,22 +353,23 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
       continue;
     }
 
+    // NOTE(dzhwinter):
+    // Two-stage commit of the inplace process: if the inplace modification
+    // generates a circle in the graph,
+    // then withdraw the changes. Otherwise, safely add the node.
    auto swap_nodes =
        TryInplaceModifyVar(out_var_name, in_var_name, idx, graph);
 
-    // NOTE(dzhwinter):
-    // two stage commit of inplaced op. If add such node generate a circle,
-    // then withdraw the changes. Otherwise, safely add the node.
     if (!ir::HasCircle(*graph)) {
       VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
                                  out_var_name, in_var_name);
-      CommitModify(swap_nodes, graph);
       InplaceModifyDesc(out_var_name, in_var_name, idx);
+      CommitModify(swap_nodes, graph);
     } else {
       VLOG(3) << string::Sprintf(
           "Skiped pair %s => %s, inplace will generate a circle. "
withdraw %s", out_var_name, in_var_name, op->Name()); - WithDrawModify(swap_nodes, graph); + WithdrawModify(swap_nodes, graph); } } } diff --git a/paddle/fluid/framework/details/inplace_op_pass.h b/paddle/fluid/framework/details/inplace_op_pass.h index cf1099323a..203ffe6e24 100644 --- a/paddle/fluid/framework/details/inplace_op_pass.h +++ b/paddle/fluid/framework/details/inplace_op_pass.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/ir/graph.h" @@ -54,7 +55,7 @@ class GraphView { std::map> adj_list_; }; -typedef std::unordered_map> SSANodeVector; +typedef std::vector> SSANodePair; class InplacePass : public ir::Pass { public: InplacePass(); @@ -66,17 +67,14 @@ class InplacePass : public ir::Pass { void InitSSAGraphNodes() const; private: - void InplaceModifyVar(const std::string& in_var, const std::string& out_var, - const size_t& idx, ir::Graph* graph) const; + const SSANodePair TryInplaceModifyVar(const std::string& var, + const std::string& cache_var, + const size_t& idx, + ir::Graph* graph) const; - const SSANodeVector TryInplaceModifyVar(const std::string& var, - const std::string& cache_var, - const size_t& idx, - ir::Graph* graph) const; + void CommitModify(const SSANodePair&, ir::Graph* graph) const; - void CommitModify(const SSANodeVector&, ir::Graph* graph) const; - - void WithDrawModify(const SSANodeVector& nodes, ir::Graph* graph) const; + void WithdrawModify(const SSANodePair& nodes, ir::Graph* graph) const; void InplaceModifyDesc(const std::string& in_var, const std::string& out_var, const size_t& idx) const; diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index eaf2ebb62f..c429c8af7d 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -32,7 +32,7 @@ class TestParallelExecutorBase(unittest.TestCase): def check_network_convergence(self, method, use_cuda=True, - memory_opt=False, + memory_opt=True, iter=50, batch_size=None, allow_op_delay=False, diff --git a/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py index b87407e31e..2770afd605 100644 --- a/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py +++ b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py @@ -70,10 +70,3 @@ class TestIrInplace(TestParallelExecutorBase): self.assertAlmostEqual(loss00, loss10, delta=delta) self.assertAlmostEqual(loss00, loss01, delta=delta) self.assertAlmostEqual(loss00, loss11, delta=delta) - - def test_fc_with_batchnorm_memory_opt(self, delta=1e-3): - loss00 = self._fc_with_batchnorm(False, True, False) - loss10 = self._fc_with_batchnorm(False, True, True) - loss10 = self._fc_with_batchnorm(True, True, True) - self.assertAlmostEqual(loss00, loss10, delta=delta) - self.assertAlmostEqual(loss00, loss01, delta=delta) -- GitLab