Commit 8156fedf authored by dzhwinter

merge develop branch. test=develop

Parent ee3aae56
@@ -199,15 +199,17 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
   }
 }
-const SSANodeVector InplacePass::TryInplaceModifyVar(
-    const std::string& var, const std::string& cache_var, const size_t& idx,
-    ir::Graph* graph) const {
+const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
+                                                   const std::string& cache_var,
+                                                   const size_t& idx,
+                                                   ir::Graph* graph) const {
   PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
                  var_nodes_[var].at(0)->Var() != nullptr);
   std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
   var_desc->SetName(cache_var);
-  SSANodeVector swap_nodes;
+  SSANodePair swap_nodes;
   for (size_t i = idx; i < view_.AllOps().size(); ++i) {
     auto* op = view_.AllOps()[i];
@@ -215,6 +217,7 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
     for (auto* node : op->inputs) {
       if (node->Name() == var) {
         ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
+        // swap node to cache_node
         cache_node->outputs.insert(cache_node->outputs.end(),
                                    node->outputs.begin(), node->outputs.end());
@@ -228,13 +231,15 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
                        cache_node);
         }
-        swap_nodes[node].emplace_back(cache_node);
+        swap_nodes.emplace_back(std::make_pair(node, cache_node));
       }
     }
     // if we need to rename the output,
     // always create a newer version of cache_var
     for (auto* node : op->outputs) {
       if (node->Name() == var) {
         ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
+        var_nodes_[cache_var].emplace_back(cache_node);
+        // swap node to cache node
         cache_node->outputs.insert(cache_node->outputs.end(),
                                    node->outputs.begin(), node->outputs.end());
@@ -244,108 +249,43 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
           std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                        cache_node);
         }
-        swap_nodes[node].emplace_back(cache_node);
+        swap_nodes.emplace_back(std::make_pair(node, cache_node));
       }
     }
   }
   return swap_nodes;
 }
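
Note: each swap recorded above re-points graph edges from the original var node to a freshly created cache node via `std::replace`. Below is a self-contained sketch of that redirection step; the plain `Node` struct and the names `fc`, `fc_out`, and `fc_out@inplace` are illustrative stand-ins, not Paddle's `ir::Node` API.

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for ir::Node: an op or variable with edge lists.
struct Node {
  std::string name;
  std::vector<Node*> inputs, outputs;
};

int main() {
  // prev_op -> var -> next_op, as in the pass: fc writes fc_out, relu reads it.
  Node prev_op{"fc"}, var{"fc_out"}, cache{"fc_out@inplace"}, next_op{"relu"};
  prev_op.outputs = {&var};
  var.inputs = {&prev_op};
  var.outputs = {&next_op};
  next_op.inputs = {&var};

  // Redirect edges so cache takes var's place, mirroring the loop above:
  // the producer now writes cache, and every consumer reads cache.
  cache.outputs.insert(cache.outputs.end(), var.outputs.begin(),
                       var.outputs.end());
  std::replace(prev_op.outputs.begin(), prev_op.outputs.end(), &var, &cache);
  cache.inputs.push_back(&prev_op);
  for (Node* op : var.outputs)
    std::replace(op->inputs.begin(), op->inputs.end(), &var, &cache);

  std::cout << prev_op.outputs[0]->name << "\n";  // fc_out@inplace
  std::cout << next_op.inputs[0]->name << "\n";   // fc_out@inplace
}
```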
-void InplacePass::CommitModify(const SSANodeVector& swap_nodes,
+void InplacePass::CommitModify(const SSANodePair& swap_nodes,
                                ir::Graph* graph) const {
   for (auto& pair : swap_nodes) {
-    auto* node = pair.first;
-    const std::string var = node->Name();
-    for (auto* cache_node : pair.second) {
-      const std::string cache_var = cache_node->Name();
-      var_nodes_[cache_var].emplace_back(cache_node);
-    }
+    auto *node = pair.first, *cache_node = pair.second;
+    const std::string var = node->Name(), cache_var = cache_node->Name();
+    var_nodes_[cache_var].emplace_back(cache_node);
+    graph->RemoveNode(node);
     auto& nodes = var_nodes_.at(var);
     // release unused var in graph. Because python side memory optimize
     // may reused the var in same name, so we only clear the var node
     // after current inplaced index.
     nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
-    graph->RemoveNode(node);
   }
 }
-void InplacePass::WithDrawModify(const SSANodeVector& nodes,
+void InplacePass::WithdrawModify(const SSANodePair& nodes,
                                  ir::Graph* graph) const {
   for (auto& pair : nodes) {
-    auto* node = pair.first;
-    const std::string var = node->Name();
-    for (auto* cache_node : pair.second) {
-      const std::string cache_var = cache_node->Name();
-      auto* prev_op = node->inputs[0];
-      std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
+    auto *node = pair.first, *cache_node = pair.second;
+    const std::string var = node->Name(), cache_var = cache_node->Name();
+    auto* prev_op = node->inputs[0];
+    std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
                  node);
-      for (auto* next_op : node->outputs) {
-        std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
-                     node);
+    for (auto* next_op : node->outputs) {
+      std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
+                   node);
     }
     graph->RemoveNode(cache_node);
-    }
   }
 }
-void InplacePass::InplaceModifyVar(const std::string& var,
-                                   const std::string& cache_var,
-                                   const size_t& idx, ir::Graph* graph) const {
-  PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
-                 var_nodes_[var].at(0)->Var() != nullptr);
-  std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
-  var_desc->SetName(cache_var);
-  for (size_t i = idx; i < view_.AllOps().size(); ++i) {
-    auto* op = view_.AllOps()[i];
-    // redirect the input to the latest version of cache_var
-    for (auto* node : op->inputs) {
-      if (node->Name() == var) {
-        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
-        var_nodes_[cache_var].emplace_back(cache_node);
-        // swap node to cache_node
-        cache_node->outputs.insert(cache_node->outputs.end(),
-                                   node->outputs.begin(), node->outputs.end());
-        PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
-        auto* prev_op = node->inputs[0];
-        std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
-                     cache_node);
-        cache_node->inputs.emplace_back(prev_op);
-        for (auto* next_op : node->outputs) {
-          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
-                       cache_node);
-        }
-        // release unused var in graph. Because python side memory optimize
-        // may reused the var in same name, so we only clear the var node
-        // after current inplaced index.
-        graph->RemoveNode(node);
-        auto& nodes = var_nodes_.at(var);
-        nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
-      }
-    }
-    // if we need to rename the output,
-    // always create a newer version of cache_var
-    for (auto* node : op->outputs) {
-      if (node->Name() == var) {
-        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
-        var_nodes_[cache_var].emplace_back(cache_node);
-        // swap node to cache node
-        cache_node->outputs.insert(cache_node->outputs.end(),
-                                   node->outputs.begin(), node->outputs.end());
-        cache_node->inputs.emplace_back(op);
-        std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
-        for (auto* next_op : node->outputs) {
-          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
-                       cache_node);
-        }
-        // release unsed var in graph
-        graph->RemoveNode(node);
-        auto& nodes = var_nodes_.at(var);
-        nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
-      }
-    }
-    graph->RemoveNode(cache_node);
-  }
-}
@@ -413,22 +353,23 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
       continue;
     }
-    // NOTE(dzhwinter):
-    // two stage commit of inplaced process. if after inplace happens generate a
-    // circle,
-    // then withdraw the changes. Otherwise, safely add the node.
     auto swap_nodes =
         TryInplaceModifyVar(out_var_name, in_var_name, idx, graph);
+    // NOTE(dzhwinter):
+    // two stage commit of inplaced op. If add such node generate a circle,
+    // then withdraw the changes. Otherwise, safely add the node.
     if (!ir::HasCircle(*graph)) {
       VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
                                  out_var_name, in_var_name);
-      CommitModify(swap_nodes, graph);
       InplaceModifyDesc(out_var_name, in_var_name, idx);
+      CommitModify(swap_nodes, graph);
     } else {
       VLOG(3) << string::Sprintf(
          "Skiped pair %s => %s, inplace will generate a circle. withdraw %s",
          out_var_name, in_var_name, op->Name());
-      WithDrawModify(swap_nodes, graph);
+      WithdrawModify(swap_nodes, graph);
    }
  }
 }
......
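
The NOTE(dzhwinter) comment above captures the control flow of this file: apply the inplace rewrite tentatively, run a cycle check, then either commit or withdraw. Below is a minimal sketch of that two-stage pattern on a toy adjacency-list graph; the `int` node ids and the DFS-based `HasCycle` are illustrative stand-ins for `ir::Node` and `ir::HasCircle`, not Paddle's API.

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <utility>
#include <vector>

using Graph = std::map<int, std::set<int>>;  // node id -> successor ids

// DFS three-color cycle test; a stand-in for ir::HasCircle.
bool HasCycle(const Graph& g) {
  std::map<int, int> color;  // 0 = unvisited, 1 = on stack, 2 = done
  std::function<bool(int)> dfs = [&](int u) {
    color[u] = 1;
    auto it = g.find(u);
    if (it != g.end()) {
      for (int v : it->second) {
        if (color[v] == 1) return true;  // back edge -> cycle
        if (color[v] == 0 && dfs(v)) return true;
      }
    }
    color[u] = 2;
    return false;
  };
  for (const auto& kv : g)
    if (color[kv.first] == 0 && dfs(kv.first)) return true;
  return false;
}

int main() {
  Graph g = {{1, {2}}, {2, {3}}, {3, {}}};

  // Stage 1: tentatively rewrite the graph, remembering what changed
  // (here a single edge; InplacePass remembers (node, cache_node) pairs).
  std::vector<std::pair<int, int>> swaps = {{3, 1}};
  for (auto& e : swaps) g[e.first].insert(e.second);

  // Stage 2: commit if the graph stayed acyclic, otherwise withdraw.
  if (!HasCycle(g)) {
    std::cout << "commit\n";  // keep the rewrite
  } else {
    for (auto& e : swaps) g[e.first].erase(e.second);  // roll back
    std::cout << "withdraw\n";
  }
}
```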
@@ -17,6 +17,7 @@
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -54,7 +55,7 @@ class GraphView {
   std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
 };
-typedef std::unordered_map<ir::Node*, std::vector<ir::Node*>> SSANodeVector;
+typedef std::vector<std::pair<ir::Node*, ir::Node*>> SSANodePair;
 class InplacePass : public ir::Pass {
  public:
   InplacePass();
@@ -66,17 +67,14 @@ class InplacePass : public ir::Pass {
   void InitSSAGraphNodes() const;
  private:
-  void InplaceModifyVar(const std::string& in_var, const std::string& out_var,
-                        const size_t& idx, ir::Graph* graph) const;
+  const SSANodePair TryInplaceModifyVar(const std::string& var,
+                                        const std::string& cache_var,
+                                        const size_t& idx,
+                                        ir::Graph* graph) const;
-  const SSANodeVector TryInplaceModifyVar(const std::string& var,
-                                          const std::string& cache_var,
-                                          const size_t& idx,
-                                          ir::Graph* graph) const;
+  void CommitModify(const SSANodePair&, ir::Graph* graph) const;
-  void CommitModify(const SSANodeVector&, ir::Graph* graph) const;
-  void WithDrawModify(const SSANodeVector& nodes, ir::Graph* graph) const;
+  void WithdrawModify(const SSANodePair& nodes, ir::Graph* graph) const;
   void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
                          const size_t& idx) const;
......
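
The typedef swap above is the core of this change: `SSANodeVector` groups all cache nodes under their original node in an `unordered_map`, while `SSANodePair` records one flat (node, cache_node) entry per swap in the order the swaps happened, which is what lets the new `CommitModify` and `WithdrawModify` walk a single loop. A tiny illustration of the two shapes, with `std::string` labels standing in for `ir::Node*`:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// The two bookkeeping shapes from the diff, with std::string labels
// standing in for ir::Node* purely for illustration.
using SSANodeVector =
    std::unordered_map<std::string, std::vector<std::string>>;
using SSANodePair = std::vector<std::pair<std::string, std::string>>;

int main() {
  // Map form (removed): cache nodes are grouped per original node, so
  // consumers need a nested loop, and the order across different
  // variables is not preserved by the unordered_map.
  SSANodeVector grouped;
  grouped["fc_out"] = {"fc_out@0", "fc_out@1"};
  for (auto& kv : grouped)
    for (auto& cache : kv.second)
      std::cout << kv.first << " -> " << cache << "\n";

  // Pair form (added): one entry per swap, in the order the swaps
  // happened, so commit/withdraw walk a single flat list.
  SSANodePair flat = {{"fc_out", "fc_out@0"}, {"fc_out", "fc_out@1"}};
  for (auto& p : flat) std::cout << p.first << " -> " << p.second << "\n";
}
```

The flat form also makes withdrawal order explicit: entries can be undone exactly as they were applied, without reconstructing per-variable grouping.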
@@ -32,7 +32,7 @@ class TestParallelExecutorBase(unittest.TestCase):
     def check_network_convergence(self,
                                   method,
                                   use_cuda=True,
-                                  memory_opt=False,
+                                  memory_opt=True,
                                   iter=50,
                                   batch_size=None,
                                   allow_op_delay=False,
......
@@ -70,10 +70,3 @@ class TestIrInplace(TestParallelExecutorBase):
         self.assertAlmostEqual(loss00, loss10, delta=delta)
         self.assertAlmostEqual(loss00, loss01, delta=delta)
         self.assertAlmostEqual(loss00, loss11, delta=delta)
-
-    def test_fc_with_batchnorm_memory_opt(self, delta=1e-3):
-        loss00 = self._fc_with_batchnorm(False, True, False)
-        loss10 = self._fc_with_batchnorm(False, True, True)
-        loss10 = self._fc_with_batchnorm(True, True, True)
-        self.assertAlmostEqual(loss00, loss10, delta=delta)
-        self.assertAlmostEqual(loss00, loss01, delta=delta)