Unverified commit cb9c59bd, authored by liuwei1031, committed by GitHub

cherry-pick PR 16547,16736,16739 test=release/1.4 (#16748)

* fix the bug of reusing different types of variables in memory_optimiz… (#16547)

* fix the bug of reusing different types of variables in memory_optimize_pass, test=develop

* disable SELECTED_ROWS AND LOD_TENSOR_ARRAY reusage, test=develop

* only use the latest version variable for inplace strategy (#16736)

* bug-fix, test=develop

* tweak code, test=develop

* cherry-pick PR 16547,16736,16739 test=release/1.4
Parent 44f50cf4
@@ -305,6 +305,12 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
     VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name;
+    if (var_nodes_[in_var_name].back() != in_node) {
+      VLOG(4) << "SKIP since " << in_var_name
+              << " is also used as output by other ops";
+      continue;
+    }
     bool can_replace = true;
     if (in_var_name == out_var_name) {
       can_replace = false;
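The guard added here relies on the pass's internal var_nodes_ bookkeeping, which, as the surrounding code suggests, maps a variable name to the ordered list of IR nodes for its successive versions; only the newest version is a safe inplace target, because an older version's name is written again by a later op. A minimal self-contained sketch of that check, using simplified stand-in types rather than Paddle's ir::Node:

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-in for ir::Node: just enough to model versions.
struct Node {
  std::string name;
  int version;
};

// Assumed role of var_nodes_: each variable name maps to its version
// nodes in creation order (SSA-style).
using VarNodes = std::unordered_map<std::string, std::vector<Node*>>;

// Only the latest version of a variable is an inplace candidate;
// earlier versions are still produced/consumed by later ops.
bool IsLatestVersion(const VarNodes& vars, const std::string& name,
                     const Node* candidate) {
  auto it = vars.find(name);
  return it != vars.end() && !it->second.empty() &&
         it->second.back() == candidate;
}

int main() {
  Node v0{"x", 0}, v1{"x", 1};
  VarNodes vars{{"x", {&v0, &v1}}};
  assert(!IsLatestVersion(vars, "x", &v0));  // skipped, like the patch does
  assert(IsLatestVersion(vars, "x", &v1));   // eligible for inplace
}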
@@ -527,6 +533,9 @@ void GraphView::Build(ir::Graph* g) {
   };
   for (auto& node : g->Nodes()) {
     if (!node->IsOp()) continue;
+    // avoid optimize the variable used in sub-blocks
+    if (OpHasSubBlock(node->Op())) update_skip_set(node);
     if (node->Name() == "send") update_skip_set(node);
     if (node->Name() == "recv") update_skip_set(node);
     if (node->Name() == "prefetch") update_skip_set(node);
...
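OpHasSubBlock marks ops that own nested blocks (while, conditional_block, and similar control-flow ops); reads and writes inside a sub-block are invisible to this graph-level pass, so its variables must not be reused. A hedged sketch of the idea with simplified types (the real helper presumably detects block-typed attributes on the OpDesc):

#include <iostream>
#include <string>
#include <vector>

// Simplified op description; the actual code inspects OpDesc attributes
// for BlockDesc references (assumption about the mechanism).
struct Op {
  std::string type;
  std::vector<int> sub_block_ids;  // non-empty for control-flow ops
};

// An op "has a sub-block" when it carries nested blocks whose variable
// accesses this pass cannot analyze, so they go into the skip set.
bool OpHasSubBlock(const Op& op) { return !op.sub_block_ids.empty(); }

int main() {
  Op mul{"mul", {}};
  Op loop{"while", {1}};
  std::cout << OpHasSubBlock(mul) << " " << OpHasSubBlock(loop) << "\n";  // 0 1
}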
@@ -131,16 +131,7 @@ size_t NodeSize(const VarDesc& node) {
   return type_size * std::abs(size);
 }
 
-size_t NodeSize(ir::Node* n) {
-  VarDesc* desc = nullptr;
-  // some op do not have block pointer
-  if (n->inputs[0]->Op() != nullptr) {
-    desc = FindVarDescInBlock(n);
-  } else {
-    desc = n->Var();
-  }
-  return NodeSize(*desc);
-}
+size_t NodeSize(ir::Node* n) { return NodeSize(*(n->Var())); }
 
 std::string DebugStringImpl(VarDesc* var) {
   std::stringstream ss;
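The one-liner works because each var node now carries its own up-to-date VarDesc, so no block lookup is needed. For reference, a self-contained sketch of the byte-size computation the surviving NodeSize(const VarDesc&) overload performs, matching the type_size * std::abs(size) line above (shape handling simplified):

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <numeric>
#include <vector>

// Approximate reusable size of a variable: element size times |numel|.
// A dynamic (-1) batch dimension makes the product negative, and the
// absolute value mirrors the `type_size * std::abs(size)` line above.
size_t ApproxNodeBytes(const std::vector<int64_t>& shape, size_t type_size) {
  int64_t numel = std::accumulate(shape.begin(), shape.end(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>());
  return type_size * static_cast<size_t>(std::abs(numel));
}

int main() {
  // A float32 tensor of shape [-1, 128]: 4 * 128 bytes per batch row.
  assert(ApproxNodeBytes({-1, 128}, sizeof(float)) == 512);
}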
@@ -163,24 +154,22 @@ std::string DebugStringImpl(VarDesc* var) {
 }
 
 std::string DebugString(ir::Node* var) {
-  return DebugStringImpl(FindVarDescInBlock(var));
+  return DebugStringImpl(GetVarDesc(var));
 }
 
 // NOTE(dzh): based ir node, if a large node has been reused
 // by a small size node, then next time it appear in pool, it will
 // have the small size. Find the original node shap from blockdesc.
-VarDesc* FindVarDescInBlock(ir::Node* n) {
+VarDesc* GetVarDesc(ir::Node* n) {
   PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
-  BlockDesc* block = n->inputs[0]->Op()->Block();
-  PADDLE_ENFORCE(block->HasVar(n->Name()),
-                 string::Sprintf("Block do not has var %s", n->Name()));
-  return block->FindVar(n->Name());
+  return n->Var();
 }
 
 struct NodeComparator {
   bool operator()(ir::Node* lhs, ir::Node* rhs) const {
-    auto* lhs_desc = FindVarDescInBlock(lhs);
-    auto* rhs_desc = FindVarDescInBlock(rhs);
+    if (lhs->Var()->GetType() != rhs->Var()->GetType()) return false;
+    auto* lhs_desc = GetVarDesc(lhs);
+    auto* rhs_desc = GetVarDesc(rhs);
     // match data type
     if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) {
       return false;
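The added first check is the core of PR 16547: two nodes are incomparable unless their variable types are identical, so for example a SELECTED_ROWS output can no longer be paired with a cached LOD_TENSOR of a matching byte size. A compressed, hedged restatement of the comparator's gating order, using simplified descriptors instead of Paddle's VarDesc:

#include <cassert>
#include <cstdint>
#include <vector>

enum class VarType { LOD_TENSOR, LOD_TENSOR_ARRAY, SELECTED_ROWS };
enum class DataType { FP32, FP64, INT64 };

struct Desc {
  VarType type;
  DataType dtype;
  std::vector<int64_t> shape;
};

// Mirrors the comparator's order of checks after the patch:
// 1) variable type must match exactly (the newly added guard),
// 2) data type must match,
// 3) only then are shapes compared (the real code also scores size fit).
bool Comparable(const Desc& lhs, const Desc& rhs) {
  if (lhs.type != rhs.type) return false;       // new first check
  if (lhs.dtype != rhs.dtype) return false;     // match data type
  return lhs.shape.size() == rhs.shape.size();  // simplified shape check
}

int main() {
  Desc t{VarType::LOD_TENSOR, DataType::FP32, {32, 128}};
  Desc s{VarType::SELECTED_ROWS, DataType::FP32, {32, 128}};
  assert(!Comparable(t, s));  // rejected purely on variable type
}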
@@ -204,7 +193,7 @@ void OrderedSet::Insert(ir::Node* var) {
     return;
   }
 
-  auto* var_desc = FindVarDescInBlock(var);
+  auto* var_desc = var->Var();
   auto var_shape = var_desc->GetShape();
   int batch_size = static_cast<int>(var_shape[0]);
 
@@ -212,7 +201,7 @@
   Iter it = nodes_.begin();
   while (it != nodes_.end()) {
     auto& prev = it->front();
-    auto* cache_desc = FindVarDescInBlock(prev);
+    auto* cache_desc = GetVarDesc(prev);
     int cache_batch_size = cache_desc->GetShape()[0];
     if ((cache_batch_size == -1 && batch_size == -1) ||
         (cache_batch_size != -1 && batch_size != -1)) {
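The condition kept here groups pool entries by batch kind: a variable with a dynamic batch dimension (shape[0] == -1) may only share a bucket with other dynamic-batch variables, and fully static ones only with static ones, since mixing the two would make size comparison meaningless. A small self-contained illustration of that predicate:

#include <cassert>
#include <cstdint>

// The condition in the loop above: both batch dims dynamic (-1) or both
// static. Only then do two cached variables belong to the same bucket.
bool SameBatchKind(int64_t cache_batch, int64_t batch) {
  return (cache_batch == -1 && batch == -1) ||
         (cache_batch != -1 && batch != -1);
}

int main() {
  assert(SameBatchKind(-1, -1));   // both dynamic: comparable
  assert(SameBatchKind(32, 64));   // both static: comparable
  assert(!SameBatchKind(-1, 64));  // mixed: different buckets
}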
@@ -336,10 +325,16 @@ int MinChunkSize() {
 bool NodeCanReused(const VarDesc& node) {
   auto type = node.GetType();
   // only these types holds bulk of gpu memory
-  if (!(type == proto::VarType::LOD_TENSOR ||
-        type == proto::VarType::LOD_TENSOR_ARRAY)) {
-    return false;
-  }
+  // FIXME(liuwei1031) did not find good ways to test SELECTED_ROWS and
+  // LOD_TENSOR_ARRAY re-use logic,
+  // disable them in version 1.4
+  // if (!(type == proto::VarType::LOD_TENSOR ||
+  //       type == proto::VarType::SELECTED_ROWS ||
+  //       type == proto::VarType::LOD_TENSOR_ARRAY)) {
+  //   return false;
+  // }
+  if (type != proto::VarType::LOD_TENSOR) return false;
   // persistable variable is parameter
   if (node.Persistable()) {
     return false;
...
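After this hunk the reuse whitelist collapses to a single type: only plain LOD_TENSOR variables enter the pool in release/1.4, and persistable variables (parameters) remain excluded by the next check. A hedged distillation of the resulting predicate:

#include <cassert>

enum class VarType { LOD_TENSOR, LOD_TENSOR_ARRAY, SELECTED_ROWS };

struct VarInfo {
  VarType type;
  bool persistable;  // parameters are persistable and are never reused
};

// Release-1.4 behavior distilled from the hunk above: SELECTED_ROWS and
// LOD_TENSOR_ARRAY reuse is switched off wholesale, pending better tests.
bool NodeCanReusedV14(const VarInfo& v) {
  if (v.type != VarType::LOD_TENSOR) return false;
  return !v.persistable;
}

int main() {
  assert(NodeCanReusedV14({VarType::LOD_TENSOR, false}));
  assert(!NodeCanReusedV14({VarType::SELECTED_ROWS, false}));
  assert(!NodeCanReusedV14({VarType::LOD_TENSOR, true}));  // parameter
}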
@@ -20,6 +20,7 @@
 #include <map>
 #include <set>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/data_type.h"
@@ -140,11 +141,7 @@ size_t NodeSize(const VarDesc&);
 std::string DebugString(ir::Node* var);
 
-// NOTE(dzhwinter)
-// after node reuse, the replaced node shape is
-// different with its VarDesc. So need to find the
-// correct VarDesc in Block.
-VarDesc* FindVarDescInBlock(ir::Node* n);
+VarDesc* GetVarDesc(ir::Node* n);
 
 static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
   return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
...