Commit 8156fedf authored by dzhwinter

merge develop branch. test=develop

Parent ee3aae56
@@ -199,15 +199,17 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
   }
 }
 
-const SSANodeVector InplacePass::TryInplaceModifyVar(
-    const std::string& var, const std::string& cache_var, const size_t& idx,
-    ir::Graph* graph) const {
+const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
+                                                   const std::string& cache_var,
+                                                   const size_t& idx,
+                                                   ir::Graph* graph) const {
   PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
                  var_nodes_[var].at(0)->Var() != nullptr);
   std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
   var_desc->SetName(cache_var);
-  SSANodeVector swap_nodes;
+  SSANodePair swap_nodes;
   for (size_t i = idx; i < view_.AllOps().size(); ++i) {
     auto* op = view_.AllOps()[i];
@@ -215,6 +217,7 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
     for (auto* node : op->inputs) {
       if (node->Name() == var) {
         ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
         // swap node to cache_node
         cache_node->outputs.insert(cache_node->outputs.end(),
                                    node->outputs.begin(), node->outputs.end());
@@ -228,13 +231,15 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
                      cache_node);
         }
-        swap_nodes[node].emplace_back(cache_node);
+        swap_nodes.emplace_back(std::make_pair(node, cache_node));
       }
     }
+    // if we need to rename the output,
+    // always create a newer version of cache_var
     for (auto* node : op->outputs) {
       if (node->Name() == var) {
         ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
-        var_nodes_[cache_var].emplace_back(cache_node);
         // swap node to cache node
         cache_node->outputs.insert(cache_node->outputs.end(),
                                    node->outputs.begin(), node->outputs.end());
@@ -244,108 +249,43 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
           std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                        cache_node);
         }
-        swap_nodes[node].emplace_back(cache_node);
+        swap_nodes.emplace_back(std::make_pair(node, cache_node));
       }
     }
   }
 
   return swap_nodes;
 }
-void InplacePass::CommitModify(const SSANodeVector& swap_nodes,
+void InplacePass::CommitModify(const SSANodePair& swap_nodes,
                                ir::Graph* graph) const {
   for (auto& pair : swap_nodes) {
-    auto* node = pair.first;
-    const std::string var = node->Name();
-    for (auto* cache_node : pair.second) {
-      const std::string cache_var = cache_node->Name();
-      var_nodes_[cache_var].emplace_back(cache_node);
-    }
+    auto *node = pair.first, *cache_node = pair.second;
+    const std::string var = node->Name(), cache_var = cache_node->Name();
+    var_nodes_[cache_var].emplace_back(cache_node);
+    graph->RemoveNode(node);
     auto& nodes = var_nodes_.at(var);
+    // release unused var in graph. Because python side memory optimize
+    // may reused the var in same name, so we only clear the var node
+    // after current inplaced index.
     nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
-    graph->RemoveNode(node);
   }
 }
 
-void InplacePass::WithDrawModify(const SSANodeVector& nodes,
+void InplacePass::WithdrawModify(const SSANodePair& nodes,
                                  ir::Graph* graph) const {
   for (auto& pair : nodes) {
-    auto* node = pair.first;
-    const std::string var = node->Name();
-    for (auto* cache_node : pair.second) {
-      const std::string cache_var = cache_node->Name();
-      auto* prev_op = node->inputs[0];
-      std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
-                   node);
-      for (auto* next_op : node->outputs) {
-        std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
-                     node);
-      }
-      graph->RemoveNode(cache_node);
-    }
+    auto *node = pair.first, *cache_node = pair.second;
+    const std::string var = node->Name(), cache_var = cache_node->Name();
+    auto* prev_op = node->inputs[0];
+    std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
+                 node);
+    for (auto* next_op : node->outputs) {
+      std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
+                   node);
+    }
+    graph->RemoveNode(cache_node);
   }
 }
 
-void InplacePass::InplaceModifyVar(const std::string& var,
-                                   const std::string& cache_var,
-                                   const size_t& idx, ir::Graph* graph) const {
-  PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
-                 var_nodes_[var].at(0)->Var() != nullptr);
-  std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
-  var_desc->SetName(cache_var);
-  for (size_t i = idx; i < view_.AllOps().size(); ++i) {
-    auto* op = view_.AllOps()[i];
-    // redirect the input to the latest version of cache_var
-    for (auto* node : op->inputs) {
-      if (node->Name() == var) {
-        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
-        var_nodes_[cache_var].emplace_back(cache_node);
-        // swap node to cache_node
-        cache_node->outputs.insert(cache_node->outputs.end(),
-                                   node->outputs.begin(), node->outputs.end());
-        PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
-        auto* prev_op = node->inputs[0];
-        std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
-                     cache_node);
-        cache_node->inputs.emplace_back(prev_op);
-        for (auto* next_op : node->outputs) {
-          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
-                       cache_node);
-        }
-        // release unused var in graph. Because python side memory optimize
-        // may reused the var in same name, so we only clear the var node
-        // after current inplaced index.
-        graph->RemoveNode(node);
-        auto& nodes = var_nodes_.at(var);
-        nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
-      }
-    }
-    // if we need to rename the output,
-    // always create a newer version of cache_var
-    for (auto* node : op->outputs) {
-      if (node->Name() == var) {
-        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
-        var_nodes_[cache_var].emplace_back(cache_node);
-        // swap node to cache node
-        cache_node->outputs.insert(cache_node->outputs.end(),
-                                   node->outputs.begin(), node->outputs.end());
-        cache_node->inputs.emplace_back(op);
-        std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
-        for (auto* next_op : node->outputs) {
-          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
-                       cache_node);
-        }
-        // release unsed var in graph
-        graph->RemoveNode(node);
-        auto& nodes = var_nodes_.at(var);
-        nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
-      }
-    }
-  }
-}
@@ -413,22 +353,23 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
       continue;
     }
 
-    // NOTE(dzhwinter):
-    // two stage commit of inplaced process. if after inplace happens generate a
-    // circle,
-    // then withdraw the changes. Otherwise, safely add the node.
     auto swap_nodes =
         TryInplaceModifyVar(out_var_name, in_var_name, idx, graph);
+    // NOTE(dzhwinter):
+    // two stage commit of inplaced op. If add such node generate a circle,
+    // then withdraw the changes. Otherwise, safely add the node.
     if (!ir::HasCircle(*graph)) {
       VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
                                  out_var_name, in_var_name);
-      CommitModify(swap_nodes, graph);
       InplaceModifyDesc(out_var_name, in_var_name, idx);
+      CommitModify(swap_nodes, graph);
     } else {
       VLOG(3) << string::Sprintf(
          "Skiped pair %s => %s, inplace will generate a circle. withdraw %s",
          out_var_name, in_var_name, op->Name());
-      WithDrawModify(swap_nodes, graph);
+      WithdrawModify(swap_nodes, graph);
     }
   }
 }
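The NOTE above is the heart of this change: rewrite tentatively, validate, then either commit or withdraw. The following is a minimal stand-alone sketch of that two-stage pattern, not PaddlePaddle code: Node, TryModify, and Withdraw are invented stand-ins for ir::Node, TryInplaceModifyVar, and WithdrawModify, and a plain bool stands in for the ir::HasCircle(*graph) check. It shows why recording each swap as an ordered (original, replacement) pair makes withdrawal a mechanical reversal.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string name;
};

// Same shape as the new SSANodePair: one (original, replacement) entry per
// swap, in the order the swaps were made.
using SwapList = std::vector<std::pair<Node*, Node*>>;

// Stage 1: tentatively redirect every edge that points at `from`, recording
// each swap so it can be undone later. (The real pass creates a distinct
// cache node per occurrence; one shared stand-in is enough here.)
SwapList TryModify(std::vector<Node*>* edges, Node* from, Node* to) {
  SwapList swaps;
  for (Node*& slot : *edges) {
    if (slot == from) {
      slot = to;
      swaps.emplace_back(from, to);
    }
  }
  return swaps;
}

// Stage 2, failure path: walk the recorded pairs in reverse and re-point
// every edge at the original node, the way WithdrawModify restores the
// prev_op/next_op adjacency lists.
void Withdraw(std::vector<Node*>* edges, const SwapList& swaps) {
  for (auto it = swaps.rbegin(); it != swaps.rend(); ++it) {
    for (Node*& slot : *edges) {
      if (slot == it->second) slot = it->first;
    }
  }
}

int main() {
  Node var{"var"}, other{"other"}, cache{"cache_var"};
  std::vector<Node*> edges = {&var, &other, &var};

  SwapList swaps = TryModify(&edges, &var, &cache);
  bool introduces_cycle = true;  // stand-in for ir::HasCircle(*graph)
  if (introduces_cycle) {
    Withdraw(&edges, swaps);  // withdraw: edges are {var, other, var} again
  }  // on the success path a commit step would keep the cache nodes

  for (Node* n : edges) std::cout << n->name << ' ';  // prints: var other var
  std::cout << '\n';
  return 0;
}

In the real pass the commit path is where the bookkeeping lives: only CommitModify registers the cache node in var_nodes_ and removes the original from the graph, which matches the var_nodes_[cache_var].emplace_back(cache_node) registration this diff deletes from TryInplaceModifyVar's output loop.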
......
@@ -17,6 +17,7 @@
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -54,7 +55,7 @@ class GraphView {
   std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
 };
 
-typedef std::unordered_map<ir::Node*, std::vector<ir::Node*>> SSANodeVector;
+typedef std::vector<std::pair<ir::Node*, ir::Node*>> SSANodePair;
 
 class InplacePass : public ir::Pass {
  public:
  InplacePass();
@@ -66,17 +67,14 @@ class InplacePass : public ir::Pass {
   void InitSSAGraphNodes() const;
 
  private:
-  void InplaceModifyVar(const std::string& in_var, const std::string& out_var,
-                        const size_t& idx, ir::Graph* graph) const;
-
-  const SSANodeVector TryInplaceModifyVar(const std::string& var,
-                                          const std::string& cache_var,
-                                          const size_t& idx,
-                                          ir::Graph* graph) const;
+  const SSANodePair TryInplaceModifyVar(const std::string& var,
+                                        const std::string& cache_var,
+                                        const size_t& idx,
+                                        ir::Graph* graph) const;
 
-  void CommitModify(const SSANodeVector&, ir::Graph* graph) const;
+  void CommitModify(const SSANodePair&, ir::Graph* graph) const;
 
-  void WithDrawModify(const SSANodeVector& nodes, ir::Graph* graph) const;
+  void WithdrawModify(const SSANodePair& nodes, ir::Graph* graph) const;
 
   void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
                          const size_t& idx) const;
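Some context for the typedef swap above, sketched here with stand-in types rather than the real ir::Node: the old SSANodeVector mapped each original node to all of its cache nodes, so commit and withdraw needed a nested loop over an unordered container, while the new SSANodePair holds one (original, cache) pair per swap in creation order, which is the shape the rewritten CommitModify and WithdrawModify traverse with a single loop.

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Node {
  std::string name;  // stand-in for ir::Node
};

// Old shape: every original node maps to *all* of its cache nodes.
using SSANodeVector = std::unordered_map<Node*, std::vector<Node*>>;
// New shape: one (original, cache) pair per swap, kept in creation order.
using SSANodePair = std::vector<std::pair<Node*, Node*>>;

// Old container: a nested loop, with no guarantee about the order in which
// different originals are visited.
void CommitOld(const SSANodeVector& swaps) {
  for (const auto& kv : swaps) {
    for (Node* cache : kv.second) {
      (void)cache;  // register the cache node, retire kv.first
    }
  }
}

// New container: each swap is self-contained, so a single ordered loop does
// the job; WithdrawModify can likewise undo pair by pair.
void CommitNew(const SSANodePair& swaps) {
  for (const auto& pair : swaps) {
    (void)pair;  // register pair.second, retire pair.first
  }
}

int main() {
  Node a{"var"}, b{"cache_var"};
  SSANodeVector old_swaps = {{&a, {&b}}};
  SSANodePair new_swaps = {{&a, &b}};
  CommitOld(old_swaps);
  CommitNew(new_swaps);
  return 0;
}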
......
@@ -32,7 +32,7 @@ class TestParallelExecutorBase(unittest.TestCase):
     def check_network_convergence(self,
                                   method,
                                   use_cuda=True,
-                                  memory_opt=False,
+                                  memory_opt=True,
                                   iter=50,
                                   batch_size=None,
                                   allow_op_delay=False,
......
@@ -70,10 +70,3 @@ class TestIrInplace(TestParallelExecutorBase):
         self.assertAlmostEqual(loss00, loss10, delta=delta)
         self.assertAlmostEqual(loss00, loss01, delta=delta)
         self.assertAlmostEqual(loss00, loss11, delta=delta)
-
-    def test_fc_with_batchnorm_memory_opt(self, delta=1e-3):
-        loss00 = self._fc_with_batchnorm(False, True, False)
-        loss10 = self._fc_with_batchnorm(False, True, True)
-        loss10 = self._fc_with_batchnorm(True, True, True)
-        self.assertAlmostEqual(loss00, loss10, delta=delta)
-        self.assertAlmostEqual(loss00, loss01, delta=delta)
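One detail worth flagging in the deletion above: test_fc_with_batchnorm_memory_opt was broken as written. It assigned loss10 twice, so the result of the second _fc_with_batchnorm call was silently overwritten, and its final assertion compared against loss01, which is never defined in that method. Removing the test rather than repairing it looks deliberate, though the commit message does not say.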