未验证 提交 28609b34 编写于 作者: D dzhwinter 提交者: GitHub

Merge pull request #15696 from dzhwinter/cherry-pick/memory

cherry picked modifies.
...@@ -50,7 +50,12 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_ ...@@ -50,7 +50,12 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper) if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass) cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info) cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
......
...@@ -240,7 +240,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -240,7 +240,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
continue; continue;
} }
} }
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
VLOG(3) << "Finish Apply Pass " << pass->Type();
} }
return graph; return graph;
} }
......
...@@ -49,7 +49,7 @@ DEFINE_bool( ...@@ -49,7 +49,7 @@ DEFINE_bool(
"If this option turns on, only these op in whitelist can be inplaced." "If this option turns on, only these op in whitelist can be inplaced."
"If it turns off, all of the running op can be candidate of inplaced op." "If it turns off, all of the running op can be candidate of inplaced op."
"Such as scale, elementwise_add" "Such as scale, elementwise_add"
"By default, it's turned on"); "By default, it's turned off");
DECLARE_string(memory_optimize_debug); DECLARE_string(memory_optimize_debug);
......
...@@ -13,13 +13,19 @@ ...@@ -13,13 +13,19 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <algorithm>
#include <deque> #include <deque>
#include <functional> #include <functional>
#include <iostream> #include <iterator>
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/gpu_info.h"
#endif // PADDLE_WITH_CUDA
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -166,6 +172,11 @@ struct NodeComparator { ...@@ -166,6 +172,11 @@ struct NodeComparator {
bool operator()(ir::Node* lhs, ir::Node* rhs) const { bool operator()(ir::Node* lhs, ir::Node* rhs) const {
auto* lhs_desc = FindVarDescInBlock(lhs); auto* lhs_desc = FindVarDescInBlock(lhs);
auto* rhs_desc = FindVarDescInBlock(rhs); auto* rhs_desc = FindVarDescInBlock(rhs);
// match data type
if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) {
return false;
}
// match shape
auto lhs_shape = lhs_desc->GetShape(); auto lhs_shape = lhs_desc->GetShape();
auto rhs_shape = rhs_desc->GetShape(); auto rhs_shape = rhs_desc->GetShape();
if ((lhs_shape[0] == -1 && rhs_shape[0] == -1) || if ((lhs_shape[0] == -1 && rhs_shape[0] == -1) ||
...@@ -230,6 +241,27 @@ ir::Node* OrderedSet::FindBestFitNode(ir::Node* var) const { ...@@ -230,6 +241,27 @@ ir::Node* OrderedSet::FindBestFitNode(ir::Node* var) const {
return found_node; return found_node;
} }
ir::Node* OrderedSet::FindNextBestFitNode(ir::Node* var, ir::Node* prev) const {
ir::Node* found_node = nullptr;
NodeComparator functor;
auto it =
std::find_if(nodes_.begin(), nodes_.end(), [&](const NodeVector& v) {
if (v.front() == prev)
return true;
else
return false;
});
PADDLE_ENFORCE(it != nodes_.end(), "Not found previous in node list!");
for (it = std::next(it); it != nodes_.end(); ++it) {
auto& candidate = it->front();
if (functor(var, candidate)) {
found_node = candidate;
break;
}
}
return found_node;
}
bool OrderedSet::Has(ir::Node* var) const { bool OrderedSet::Has(ir::Node* var) const {
if (mark_table_.count(var->Name())) { if (mark_table_.count(var->Name())) {
auto& node_in_samename = mark_table_.at(var->Name()); auto& node_in_samename = mark_table_.at(var->Name());
...@@ -241,10 +273,15 @@ bool OrderedSet::Has(ir::Node* var) const { ...@@ -241,10 +273,15 @@ bool OrderedSet::Has(ir::Node* var) const {
return false; return false;
} }
void OrderedSet::Erase(const std::string& var) {
PADDLE_ENFORCE(mark_table_.count(var));
nodes_.erase(mark_table_[var]);
mark_table_.erase(var);
}
void OrderedSet::Erase(ir::Node* var) { void OrderedSet::Erase(ir::Node* var) {
PADDLE_ENFORCE(mark_table_.count(var->Name())); PADDLE_ENFORCE(var != nullptr);
nodes_.erase(mark_table_[var->Name()]); Erase(var->Name());
mark_table_.erase(var->Name());
} }
std::string OrderedSet::ToString() const { std::string OrderedSet::ToString() const {
...@@ -274,14 +311,35 @@ bool NodeCanReused(ir::Node* node) { ...@@ -274,14 +311,35 @@ bool NodeCanReused(ir::Node* node) {
return flag; return flag;
} }
int MinChunkSize() {
int size{0};
#ifdef PADDLE_WITH_CUDA
size = platform::GpuMinChunkSize();
#else
size = platform::CpuMinChunkSize();
#endif // PADDLE_WITH_CUDA
return size;
}
bool NodeCanReused(const VarDesc& node) { bool NodeCanReused(const VarDesc& node) {
auto type = node.GetType(); auto type = node.GetType();
// only these types holds bulk of gpu memory
if (!(type == proto::VarType::LOD_TENSOR || if (!(type == proto::VarType::LOD_TENSOR ||
type == proto::VarType::SELECTED_ROWS || type == proto::VarType::SELECTED_ROWS ||
type == proto::VarType::LOD_TENSOR_ARRAY)) { type == proto::VarType::LOD_TENSOR_ARRAY)) {
return false; return false;
} }
if (node.Persistable() || node.GetShape().empty()) { // persistable variable is parameter
if (node.Persistable()) {
return false;
}
// shape < min_chunk_size is meaningless.
// further more, fetched loss always has size = 1
// which should not be reused.
auto shape = node.GetShape();
int size = std::abs(
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()));
if (shape.empty() || size < MinChunkSize()) {
return false; return false;
} }
// vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
...@@ -461,7 +519,9 @@ ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name, ...@@ -461,7 +519,9 @@ ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name,
for (auto* node : ops_) { for (auto* node : ops_) {
if (node == op) break; if (node == op) break;
for (auto& output : node->outputs) { for (auto& output : node->outputs) {
if (output->Name() == name) { PADDLE_ENFORCE((output != nullptr && output->IsVar()),
"Output is empty!");
if (output->Var() && output->Name() == name) {
found_node = output; found_node = output;
} }
} }
......
...@@ -55,6 +55,7 @@ class OrderedSet { ...@@ -55,6 +55,7 @@ class OrderedSet {
void Insert(ir::Node* var); void Insert(ir::Node* var);
void Erase(ir::Node* var); void Erase(ir::Node* var);
void Erase(const std::string& var);
bool Has(ir::Node* var) const; bool Has(ir::Node* var) const;
void Clear() { void Clear() {
mark_table_.clear(); mark_table_.clear();
...@@ -62,6 +63,7 @@ class OrderedSet { ...@@ -62,6 +63,7 @@ class OrderedSet {
} }
// find the bestfit shape node block with var. // find the bestfit shape node block with var.
ir::Node* FindBestFitNode(ir::Node* var) const; ir::Node* FindBestFitNode(ir::Node* var) const;
ir::Node* FindNextBestFitNode(ir::Node* var, ir::Node* prev) const;
// map store non-const iterator, can not promise const // map store non-const iterator, can not promise const
int GetNodeIndexInPool(ir::Node* var); int GetNodeIndexInPool(ir::Node* var);
// pool all node to string // pool all node to string
......
...@@ -107,6 +107,52 @@ TEST(OrderedSet, Normal) { ...@@ -107,6 +107,52 @@ TEST(OrderedSet, Normal) {
ASSERT_EQ(pool.GetNodeIndexInPool(cache), 5); // match 4:[5,2] ASSERT_EQ(pool.GetNodeIndexInPool(cache), 5); // match 4:[5,2]
} }
} }
TEST(OrderedSet, FindBestFitNode) {
OrderedSet pool;
std::vector<std::unique_ptr<ir::Node>> nodes;
ProgramDesc prog;
BlockDesc* block_desc = prog.MutableBlock(0);
auto* op_desc = block_desc->AppendOp();
op_desc->SetType("dummy");
std::unique_ptr<ir::Node> op = ir::CreateNodeForTest(op_desc);
{
auto desc = block_desc->Var("a");
desc->SetShape({128, 128});
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
{
auto desc = block_desc->Var("b");
desc->SetShape({128, 129});
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
{
auto desc = block_desc->Var("c");
desc->SetShape({128, 128});
std::unique_ptr<ir::Node> node = ir::CreateNodeForTest(desc);
node->inputs.emplace_back(op.get());
nodes.emplace_back(std::move(node));
}
for (auto& node : nodes) {
pool.Insert(node.get());
}
// FindNextBestFitNode
auto* n = nodes[0].get();
auto* cache = pool.FindBestFitNode(n);
PADDLE_ENFORCE(cache->Name() == "a");
cache = pool.FindNextBestFitNode(n, cache);
PADDLE_ENFORCE(cache->Name() == "c");
cache = pool.FindNextBestFitNode(n, cache);
PADDLE_ENFORCE(cache->Name() == "b");
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -69,11 +69,20 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -69,11 +69,20 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
} }
for (auto& var : op->outputs) { for (auto& var : op->outputs) {
if (!NodeCanReused(var) || cfg_->Use(op).count(var->Name()) == 0 || if (var->IsVar() && !var->IsCtrlVar() && skip_set_.count(var->Name())) {
skip_set_.count(var->Name())) VLOG(3) << "Skip set contains variable of " << var->Name()
<< "disable reuse on it. skipped";
continue; continue;
}
if (NodeCanReused(var) && cfg_->Use(op).count(var->Name()) == 0) {
ir::Node* cache = pool_.FindBestFitNode(var); ir::Node* cache = pool_.FindBestFitNode(var);
while (cache != nullptr && var->Name() == cache->Name()) {
VLOG(3) << "The same cache variable is cascade reused. "
<< cache->Name() << " is re-filled to the pool after "
<< "the reused op is finished. Current op can not "
<< "replace it again. Skip this candidate.";
cache = pool_.FindNextBestFitNode(var, cache);
}
if (var->Name() == FLAGS_memory_optimize_debug) { if (var->Name() == FLAGS_memory_optimize_debug) {
VLOG(3) << "start match var " << DebugString(var) << " of op " VLOG(3) << "start match var " << DebugString(var) << " of op "
<< op->Name(); << op->Name();
...@@ -82,42 +91,37 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl( ...@@ -82,42 +91,37 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
<< ((cache == nullptr) ? "False" : "True"); << ((cache == nullptr) ? "False" : "True");
} }
if (cache == nullptr) continue; if (cache != nullptr) {
if (var->Name() == cache->Name()) {
VLOG(3) << "The same cache variable is cascade reused." << var->Name()
<< " is re-filled to the pool after"
<< "the reused op is finished. Current op can not "
<< "replace it again. Skip this candidate.";
continue;
int node_idx_in_pool = pool_.GetNodeIndexInPool(cache); int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
VLOG(3) << string::Sprintf( VLOG(3) << string::Sprintf(
"!!! %s, %s => %s, cache idx %d, pool size %d", "!!! %s, %s => %s, cache idx %d, pool size %d",
std::to_string(reuse_id++), DebugString(var), DebugString(cache), std::to_string(reuse_id++), DebugString(var), DebugString(cache),
node_idx_in_pool, static_cast<int>(pool_.size())); node_idx_in_pool, static_cast<int>(pool_.size()));
// NOTE(dzhwinter): update the ProgramDesc/IR Graph
// and the CFG Graph on the fly.
//
// IR Graph define the dependence relationship between nodes.
//
// ProgramDesc defines the input/output vars. Its used in
// CreateOp, CreateVar when running happens.
//
// CFG Graph store the liveness information, when reuse happens
// we also need to update the variable liveness.
const std::string var_name = var->Name();
const std::string cache_name = cache->Name();
// update CFG Graph on the fly. cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
// reused var maybe re-fill into the pool RenameVarInGraphDesc(var_name, cache_name, idx);
cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx); RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
// NOTE(dzhwinter): we need to both update the ProgramDesc pool_.Erase(cache_name);
// and IR Graph. because op_desc/var_desc is used in CreateOp, }
// CreateVar when running happens. But IR Graph }
// define the dependence relationship between nodes.
RenameVarInGraphDesc(var->Name(), cache->Name(), idx);
RenameVarInGraphNode(var->Name(), cache->Name(), idx, graph.get());
pool_.Erase(cache);
} }
// fill the pool // fill the pool
std::unordered_set<std::string> unlived_vars;
for (auto var : cfg_->LiveIn(op)) { for (auto var : cfg_->LiveIn(op)) {
if (cfg_->LiveOut(op).count(var) == 0) { if (cfg_->LiveOut(op).count(var) == 0) {
unlived_vars.emplace(var);
}
}
for (auto var : unlived_vars) {
ir::Node* var_node = cfg_->GetNodeByName(var, op); ir::Node* var_node = cfg_->GetNodeByName(var, op);
if (var_node == nullptr || var_node->IsCtrlVar()) continue;
if (NodeCanReused(var_node) && !pool_.Has(var_node)) { if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
pool_.Insert(var_node); pool_.Insert(var_node);
} }
...@@ -273,8 +277,7 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, ...@@ -273,8 +277,7 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
// redirect the input to the latest version of cache_var // redirect the input to the latest version of cache_var
for (auto* node : op->inputs) { for (auto* node : op->inputs) {
if (node->Name() == var) { if (node->Name() == var) {
ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); ir::Node* cache_node = var_nodes_[cache_var].back();
var_nodes_[cache_var].emplace_back(cache_node);
// swap node to cache_node // swap node to cache_node
cache_node->outputs.insert(cache_node->outputs.end(), cache_node->outputs.insert(cache_node->outputs.end(),
...@@ -283,11 +286,15 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, ...@@ -283,11 +286,15 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
auto* prev_op = node->inputs[0]; auto* prev_op = node->inputs[0];
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node, std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
cache_node); cache_node);
cache_node->inputs.emplace_back(prev_op);
for (auto* next_op : node->outputs) { for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node); cache_node);
} }
// erase unused node
auto& nodes = var_nodes_.at(var);
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
graph->RemoveNode(node);
} }
} }
...@@ -307,15 +314,14 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, ...@@ -307,15 +314,14 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
cache_node); cache_node);
} }
}
}
}
// release node of unused var in graph // erase unused node
for (auto* node : var_nodes_[var]) { auto& nodes = var_nodes_.at(var);
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
graph->RemoveNode(node); graph->RemoveNode(node);
} }
var_nodes_.at(var).clear(); }
}
} }
} // namespace details } // namespace details
......
...@@ -179,11 +179,11 @@ TEST(InferInplace, SingleOpInplaceInToOut) { ...@@ -179,11 +179,11 @@ TEST(InferInplace, SingleOpInplaceInToOut) {
op->SetOutput("Out", {"test2_out"}); op->SetOutput("Out", {"test2_out"});
prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 64}); prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 64, 128, 128});
prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_out"); prog.MutableBlock(0)->Var("test2_out");
prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16}); prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16, 128, 128});
auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
auto in_to_outs = infer_inplace(*op, op->Block()); auto in_to_outs = infer_inplace(*op, op->Block());
...@@ -201,11 +201,11 @@ TEST(InferInplace, SingleGradOpInplaceInToOut) { ...@@ -201,11 +201,11 @@ TEST(InferInplace, SingleGradOpInplaceInToOut) {
op->SetOutput(GradVarName("X"), {"test2_a", "test2_b", "test2_c"}); op->SetOutput(GradVarName("X"), {"test2_a", "test2_b", "test2_c"});
prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 16}); prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR); prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("test2_out"); prog.MutableBlock(0)->Var("test2_out");
prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16}); prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16, 1024, 1024});
auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
auto in_to_outs = infer_inplace(*op, op->Block()); auto in_to_outs = infer_inplace(*op, op->Block());
...@@ -233,12 +233,12 @@ TEST(InferInplace, MultiOutInplaceInToOut) { ...@@ -233,12 +233,12 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
prog.MutableBlock(0)->Var("o0"); prog.MutableBlock(0)->Var("o0");
prog.MutableBlock(0)->Var("y0"); prog.MutableBlock(0)->Var("y0");
prog.MutableBlock(0)->Var("z0"); prog.MutableBlock(0)->Var("z0");
prog.MutableBlock(0)->Var("a0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("a0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("b0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("b0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("c0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("c0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("o0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("o0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("y0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("y0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("z0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});
auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
auto in_to_outs = infer_inplace(*op, op->Block()); auto in_to_outs = infer_inplace(*op, op->Block());
...@@ -267,12 +267,12 @@ TEST(InferInplace, MultiGradInplaceInToOut) { ...@@ -267,12 +267,12 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
prog.MutableBlock(0)->Var("o0"); prog.MutableBlock(0)->Var("o0");
prog.MutableBlock(0)->Var("y0"); prog.MutableBlock(0)->Var("y0");
prog.MutableBlock(0)->Var("z0"); prog.MutableBlock(0)->Var("z0");
prog.MutableBlock(0)->Var("a0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("a0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("b0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("b0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("c0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("c0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("o0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("o0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("y0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("y0")->SetShape({32, 16, 1024, 1024});
prog.MutableBlock(0)->Var("z0")->SetShape({32, 16}); prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});
auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_;
auto in_to_outs = infer_inplace(*op, op->Block()); auto in_to_outs = infer_inplace(*op, op->Block());
......
...@@ -177,7 +177,10 @@ class CompiledProgram(object): ...@@ -177,7 +177,10 @@ class CompiledProgram(object):
# FIXME(dzhwinter): enable_inplace should be after memory_optimize # FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass. # if turn on python memory optimize, turn off the inplace_pass.
self._build_strategy.enable_inplace = False if self._program._is_mem_optimized else True if self._build_strategy.memory_optimize is None:
self._build_strategy.memory_optimize = False if main._is_mem_optimized else True
if self._build_strategy.enable_inplace is None:
self._build_strategy.enable_inplace = False if main._is_mem_optimized else True
if self._build_strategy.num_trainers > 1 and trainers_endpoints: if self._build_strategy.num_trainers > 1 and trainers_endpoints:
assert self._build_strategy.num_trainers == len( assert self._build_strategy.num_trainers == len(
......
...@@ -148,6 +148,8 @@ class ParallelExecutor(object): ...@@ -148,6 +148,8 @@ class ParallelExecutor(object):
else framework.default_main_program() else framework.default_main_program()
# FIXME(dzhwinter): enable_inplace should be after memory_optimize # FIXME(dzhwinter): enable_inplace should be after memory_optimize
# if turn on python memory optimize, turn off the inplace_pass. # if turn on python memory optimize, turn off the inplace_pass.
if build_strategy.memory_optimize is None:
build_strategy.memory_optimize = False if main._is_mem_optimized else True
if build_strategy.enable_inplace is None: if build_strategy.enable_inplace is None:
build_strategy.enable_inplace = False if main._is_mem_optimized else True build_strategy.enable_inplace = False if main._is_mem_optimized else True
scope = scope if scope is not None else executor.global_scope() scope = scope if scope is not None else executor.global_scope()
......
...@@ -77,6 +77,7 @@ list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op) ...@@ -77,6 +77,7 @@ list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op)
list(REMOVE_ITEM TEST_OPS test_nearest_interp_op) list(REMOVE_ITEM TEST_OPS test_nearest_interp_op)
list(REMOVE_ITEM TEST_OPS test_imperative_resnet) list(REMOVE_ITEM TEST_OPS test_imperative_resnet)
list(REMOVE_ITEM TEST_OPS test_imperative_optimizer) list(REMOVE_ITEM TEST_OPS test_imperative_optimizer)
list(REMOVE_ITEM TEST_OPS test_ir_memory_optimize_transformer)
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP) endforeach(TEST_OP)
...@@ -107,6 +108,9 @@ py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SE ...@@ -107,6 +108,9 @@ py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SE
py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 450) set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 450)
py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL) py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL)
if(NOT WIN32)
py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer SERIAL)
endif()
if(NOT APPLE) if(NOT APPLE)
py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL) py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL)
if(CMAKE_BUILD_TYPE STREQUAL "Debug") if(CMAKE_BUILD_TYPE STREQUAL "Debug")
......
...@@ -79,7 +79,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -79,7 +79,7 @@ class TestParallelExecutorBase(unittest.TestCase):
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.memory_optimize = use_ir_memory_optimize build_strategy.memory_optimize = False if memory_opt else use_ir_memory_optimize
# python memory optimization is conflict with inplace pass. # python memory optimization is conflict with inplace pass.
# Use ir graph memory optimization after inplace pass is the correct way. # Use ir graph memory optimization after inplace pass is the correct way.
build_strategy.enable_inplace = False if memory_opt else enable_inplace build_strategy.enable_inplace = False if memory_opt else enable_inplace
......
...@@ -121,6 +121,8 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -121,6 +121,8 @@ class TestMNIST(TestParallelExecutorBase):
regularization=fluid.regularizer.L2Decay(1e-6)) regularization=fluid.regularizer.L2Decay(1e-6))
return optimizer return optimizer
# NOTE(dzh):
# need to make it compatible with elewise fuse act
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence( not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
model, model,
feed_dict={"image": img, feed_dict={"image": img,
...@@ -128,6 +130,7 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -128,6 +130,7 @@ class TestMNIST(TestParallelExecutorBase):
use_cuda=use_cuda, use_cuda=use_cuda,
fuse_elewise_add_act_ops=False, fuse_elewise_add_act_ops=False,
memory_opt=False, memory_opt=False,
use_ir_memory_optimize=False,
optimizer=_optimizer) optimizer=_optimizer)
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence( fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
model, model,
...@@ -136,6 +139,7 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -136,6 +139,7 @@ class TestMNIST(TestParallelExecutorBase):
use_cuda=use_cuda, use_cuda=use_cuda,
fuse_elewise_add_act_ops=True, fuse_elewise_add_act_ops=True,
memory_opt=False, memory_opt=False,
use_ir_memory_optimize=False,
optimizer=_optimizer) optimizer=_optimizer)
for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss): for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
os.environ[
'RECORDIO_FILENAME'] = '/tmp/ir_memory_optimize_transformer.wmt16.recordio'
from test_parallel_executor_transformer import TestTransformer
from test_parallel_executor_transformer import transformer
# NOTE(dzhwinter): test diferent strategy colisions.
# open the eager delete tensor strategy by default.
class TestTransformerWithIR(TestTransformer):
def test_main(self):
if core.is_compiled_with_cuda():
# check python transpiler
self.check_network_convergence(
transformer,
use_cuda=True,
memory_opt=True,
use_ir_memory_optimize=False)
# check IR memory optimize
self.check_network_convergence(
transformer,
use_cuda=True,
memory_opt=False,
use_ir_memory_optimize=True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册