Unverified commit cac315f9, authored by chengduo, committed by GitHub

update alloc_continuous_space_for_grad_pass (#18288)

test=release/1.5
Parent 618c2c75
File: alloc_continuous_space_for_grad_pass.cc

@@ -16,6 +16,7 @@
 #include <algorithm>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/details/build_strategy.h"
@@ -84,16 +85,19 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
     }
     if (params_grads.size() == 0) {
-      LOG(WARNING) << "Doesn't find gradients";
+      LOG(INFO) << "Doesn't find gradients";
       return;
     }
-    std::unordered_map<std::string, ir::Node *> vars;
+    std::unordered_map<std::string, ir::Node *> var_name2node;
+    std::unordered_map<std::string, std::unordered_set<ir::Node *>>
+        var_name2node_set;
     for (ir::Node *node : result.Nodes()) {
       if (node->IsVar() && node->Var()) {
         // Note: The graph may have the same name node. For example, parameter
         // is the input of operator and it also is the output of optimizer;
-        vars.emplace(node->Var()->Name(), node);
+        var_name2node.emplace(node->Var()->Name(), node);
+        var_name2node_set[node->Var()->Name()].emplace(node);
       }
     }
@@ -101,7 +105,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
         result.Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
     // Note: the order of params_grads may be changed by SetGroupGradsAndParams.
-    SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
+    SetGroupGradsAndParams(var_name2node, params_grads, &group_grads_params);
     params_grads.clear();
     for (auto &group_p_g : group_grads_params) {
@@ -116,9 +120,16 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
     auto dtype = kDefaultDtype;
     for (auto &p_g : params_grads) {
       // Get gradient var
-      auto iter = vars.find(p_g.second);
-      PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
-      iter->second->Var()->SetPersistable(true);
+      auto iter = var_name2node.find(p_g.second);
+      PADDLE_ENFORCE(iter != var_name2node.end(), "%s is not found.",
+                     p_g.second);
+      // Set persistable
+      auto same_nodes = var_name2node_set.find(p_g.second);
+      PADDLE_ENFORCE(same_nodes != var_name2node_set.end(), "%s is not found.",
+                     p_g.second);
+      for (auto it : same_nodes->second) {
+        it->Var()->SetPersistable(true);
+      }
       PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
@@ -151,7 +162,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
                    "%s is duplicate in FusedVars.", fused_var_name);
     fused_var_set.insert(fused_var_name);
-    InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
+    InitFusedVarsAndAllocSpaceForVars(places, local_scopes, var_name2node,
                                       fused_var_name, params_grads);
   }
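The substance of this change: a graph can contain several ir::Node objects that share one variable name (a parameter is both an operator input and an optimizer output), and the old code kept only one node per name in `vars`, so `SetPersistable(true)` never reached the duplicates, leaving persistable and non-persistable copies of the same variable. A minimal standalone sketch of the pitfall and the fix; the `Node` type here is a hypothetical stand-in, not Paddle's ir::Node:

    #include <cassert>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    // Hypothetical stand-in for ir::Node: several nodes may share one name.
    struct Node {
      std::string name;
      bool persistable = false;
    };

    int main() {
      std::vector<Node> nodes{{"fc_w"}, {"fc_w"}, {"fc_b"}};

      std::unordered_map<std::string, Node *> name2node;
      std::unordered_map<std::string, std::unordered_set<Node *>> name2node_set;
      for (auto &n : nodes) {
        name2node.emplace(n.name, &n);      // keeps only the FIRST "fc_w" node
        name2node_set[n.name].emplace(&n);  // keeps ALL nodes named "fc_w"
      }
      assert(name2node_set["fc_w"].size() == 2);

      // Old behavior: only the mapped node is flipped; its twin stays false.
      name2node["fc_w"]->persistable = true;
      // New behavior: flip every node that shares the name.
      for (Node *n : name2node_set["fc_w"]) n->persistable = true;

      for (auto &n : nodes) {
        if (n.name == "fc_w") assert(n.persistable);
      }
      return 0;
    }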
File: graph_helper.cc

@@ -103,6 +103,33 @@ bool HasCircle(const Graph &graph) {
   return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
 }
+
+bool VarDescIsConsistency(const Graph &graph) {
+  std::unordered_map<std::string, std::unordered_set<ir::Node *>>
+      var_name2node_set;
+  for (ir::Node *node : graph.Nodes()) {
+    if (node->IsVar() && node->Var()) {
+      // Note: The graph may have the same name node. For example, parameter
+      // is the input of operator and it also is the output of optimizer;
+      var_name2node_set[node->Var()->Name()].emplace(node);
+    }
+  }
+  for (auto &iter : var_name2node_set) {
+    auto &first_node = *iter.second.begin();
+    bool is_persistable = std::any_of(iter.second.begin(), iter.second.end(),
+                                      [&first_node](const ir::Node *node) {
+                                        return node->Var()->Persistable();
+                                      });
+    if (is_persistable) {
+      bool is_consistency =
+          std::all_of(iter.second.begin(), iter.second.end(),
+                      [&first_node](const ir::Node *node) {
+                        return *node->Var() == *first_node->Var();
+                      });
+      if (!is_consistency) return false;
+    }
+  }
+  return true;
+}
+
 bool FindCircleSubGraph(const Graph &graph,
                         std::vector<std::vector<ir::Node *>> *circles) {
   return HasCircleInternal(BuildOperationAdjList(graph), circles);
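VarDescIsConsistency boils down to a standard any_of/all_of pass over each name's node set: if any copy of a variable is persistable, then all of its VarDescs must compare equal. A self-contained sketch of the same predicate on a toy `Desc` type (names and fields invented for illustration):

    #include <algorithm>
    #include <cassert>
    #include <string>
    #include <vector>

    // Toy stand-in for VarDesc: equality covers every field, as the
    // serialized-proto comparison in operator==(VarDesc, VarDesc) does.
    struct Desc {
      std::string name;
      int shape;
      bool persistable;
      bool operator==(const Desc &o) const {
        return name == o.name && shape == o.shape &&
               persistable == o.persistable;
      }
    };

    // Same shape as VarDescIsConsistency: if any copy of a variable is
    // persistable, all copies of its desc must compare equal.
    bool Consistent(const std::vector<Desc> &copies) {
      const Desc &first = copies.front();
      bool persistable =
          std::any_of(copies.begin(), copies.end(),
                      [](const Desc &d) { return d.persistable; });
      if (!persistable) return true;
      return std::all_of(copies.begin(), copies.end(),
                         [&first](const Desc &d) { return d == first; });
    }

    int main() {
      // Every copy identical: consistent.
      assert(Consistent({{"w", 4, true}, {"w", 4, true}}));
      // One copy marked persistable while its twin is not: exactly the
      // state the pass fix above prevents.
      assert(!Consistent({{"w", 4, true}, {"w", 4, false}}));
      // No persistable copy: the descs are not required to match.
      assert(Consistent({{"w", 4, false}, {"w", 8, false}}));
      return 0;
    }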
File: graph_helper.h

@@ -17,6 +17,7 @@ limitations under the License. */
 #include <map>
 #include <memory>
 #include <set>
+#include <string>
 #include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
@@ -36,6 +37,9 @@ struct NodeComp {
 // Test if the graph contains circle.
 bool HasCircle(const Graph &graph);

+// Check if the var desc of node is consistency.
+bool VarDescIsConsistency(const Graph &graph);
+
 // Find All Circles for debugging,
 // store all subgraph in circles.
 bool FindCircleSubGraph(const Graph &graph,
File: pass.cc

@@ -38,6 +38,8 @@ Graph* Pass::Apply(Graph* graph) const {
   // TODO(panyx0718): Add more verifications.
   PADDLE_ENFORCE(!HasCircle(*graph),
                  "Illegal Pass. Generated graph shouldn't has cycle.");
+  PADDLE_ENFORCE(VarDescIsConsistency(*graph),
+                 "The VarDescs of persistable variable are not consistency.");
   PADDLE_ENFORCE(graph == native_graph,
                  "Pass::Apply() cannot delete the passed graph and shouldn't "
                  "return a new graph.(For the need of pybind11)");
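With this, Pass::Apply validates the graph after every pass, so a pass that leaves same-name variables with diverging descriptions fails at the point of corruption rather than at some later, harder-to-debug stage. A toy sketch of the validate-after-apply pattern, with invented Graph/Pass types:

    #include <cassert>
    #include <functional>
    #include <iostream>
    #include <vector>

    // Invented Graph/Pass types; the point is the shape of Pass::Apply above:
    // every pass result is checked against global invariants before it is
    // returned, so a buggy pass fails loudly where the damage happened.
    struct Graph {
      std::vector<int> nodes;
    };
    using Pass = std::function<void(Graph *)>;

    bool Consistent(const Graph &g) { return !g.nodes.empty(); }

    Graph *Apply(const Pass &pass, Graph *g) {
      pass(g);
      assert(Consistent(*g) && "pass broke a graph invariant");
      return g;
    }

    int main() {
      Graph g{{1, 2, 3}};
      Apply([](Graph *graph) { graph->nodes.push_back(4); }, &g);
      std::cout << g.nodes.size() << " nodes after pass\n";
      return 0;
    }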
File: parallel_executor.cc

@@ -320,12 +320,14 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
   }
 #endif
   if (!member_->use_all_reduce_) {
-    PADDLE_ENFORCE(places.size() > 1,
-                   "If you set build_strategy.reduce with 'Reduce',"
-                   "the number of places must be greater than 1.");
+    if (places.size() == 1) {
+      LOG(INFO) << "If you set build_strategy.reduce with 'Reduce',"
+                   "the number of places should be greater than 1.";
+      member_->use_all_reduce_ = true;
+    }
   }

-  LOG(WARNING) << string::Sprintf(
+  LOG(INFO) << string::Sprintf(
       "The number of %s, which is used in ParallelExecutor, is %lu. And "
       "the Program will be copied %lu copies",
       (member_->use_cuda_ ? "CUDAPlace" : "CPUPlace"), places.size(),
@@ -364,10 +366,11 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
   // choice the execution strategy.
   build_strategy.enable_parallel_graph_ =
       EnableParallelGraphExecution(*graph, exec_strategy, build_strategy);
-  if (build_strategy.enable_parallel_graph_)
-    VLOG(0) << "The Executor would execute the graph by ParallelGraph "
-               "Execution which can get better performance,"
-            << "you can force it off by env FLAGS_enable_parallel_graph=0";
+  if (build_strategy.enable_parallel_graph_) {
+    LOG(INFO) << "The Executor would execute the graph by ParallelGraph "
+                 "Execution which can get better performance,"
+              << "you can force it off by env FLAGS_enable_parallel_graph=0";
+  }

   if (member_->use_cuda_ && member_->nranks_ > 1) {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
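Behavioral note: running with `build_strategy.reduce` set to 'Reduce' on a single place no longer aborts through PADDLE_ENFORCE; ParallelExecutor now logs the problem and falls back to all-reduce. A small sketch of that normalize-and-fallback pattern; the `Normalize` helper and the enum are illustrative, not Paddle API:

    #include <cstddef>
    #include <iostream>

    enum class ReduceStrategy { kAllReduce, kReduce };

    // Sketch of the new fallback: with a single place, Reduce mode is
    // meaningless, so log and switch to AllReduce instead of aborting
    // through PADDLE_ENFORCE as the old code did.
    ReduceStrategy Normalize(ReduceStrategy s, std::size_t num_places) {
      if (s == ReduceStrategy::kReduce && num_places == 1) {
        std::cout << "If you set build_strategy.reduce with 'Reduce', "
                     "the number of places should be greater than 1.\n";
        return ReduceStrategy::kAllReduce;
      }
      return s;
    }

    int main() {
      ReduceStrategy s = Normalize(ReduceStrategy::kReduce, 1);
      std::cout << (s == ReduceStrategy::kAllReduce ? "fell back to AllReduce\n"
                                                    : "kept Reduce\n");
      return 0;
    }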
File: var_desc.cc

@@ -264,5 +264,10 @@ std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
   }
 }

+bool operator==(const VarDesc &left, const VarDesc &right) {
+  return left.Proto()->SerializeAsString() ==
+         right.Proto()->SerializeAsString();
+}
+
 }  // namespace framework
 }  // namespace paddle
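The new `operator==` defines VarDesc equality as byte equality of the serialized proto, which is cheap to write and covers every field at once. A toy stand-in for the pattern; `FakeProto` and `FakeVarDesc` are invented for illustration, and the approach assumes serialization is deterministic (true for plain messages, not guaranteed once protobuf map fields are involved):

    #include <cassert>
    #include <string>

    // Invented stand-in for a protobuf message: equality is defined as
    // "serializes to the same bytes", mirroring the operator== added above.
    struct FakeProto {
      std::string name;
      int shape = 0;
      std::string SerializeAsString() const {
        return name + "|" + std::to_string(shape);
      }
    };

    struct FakeVarDesc {
      FakeProto desc;
      const FakeProto *Proto() const { return &desc; }
    };

    bool operator==(const FakeVarDesc &left, const FakeVarDesc &right) {
      return left.Proto()->SerializeAsString() ==
             right.Proto()->SerializeAsString();
    }

    int main() {
      FakeVarDesc a{{"w", 128}}, b{{"w", 128}}, c{{"w", 256}};
      assert(a == b);     // identical descs serialize identically
      assert(!(a == c));  // any differing field changes the bytes
      return 0;
    }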
File: var_desc.h

@@ -67,6 +67,8 @@ class VarDesc {
   proto::VarDesc *Proto() { return &desc_; }

+  const proto::VarDesc *Proto() const { return &desc_; }
+
   std::string Name() const { return desc_.name(); }

   void SetName(std::string name) { desc_.set_name(name); }
@@ -116,5 +118,7 @@ class VarDesc {
   proto::VarDesc desc_;
 };

+bool operator==(const VarDesc &left, const VarDesc &right);
+
 }  // namespace framework
 }  // namespace paddle
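A side note on the header change: `operator==` takes `const VarDesc &` arguments, so the const-qualified `Proto()` overload is what makes `left.Proto()` compile at all; the pre-existing accessor is non-const. A minimal illustration of the accessor pair, with invented types:

    #include <string>

    // Minimal illustration of the const/non-const accessor pair; Desc and
    // Holder are invented stand-ins for proto::VarDesc and VarDesc.
    struct Desc {
      std::string name;
    };

    class Holder {
     public:
      Desc *Proto() { return &desc_; }              // for mutation
      const Desc *Proto() const { return &desc_; }  // callable via const refs
     private:
      Desc desc_;
    };

    bool operator==(const Holder &left, const Holder &right) {
      return left.Proto()->name == right.Proto()->name;  // needs const overload
    }

    int main() {
      Holder a, b;
      a.Proto()->name = "w";
      b.Proto()->name = "w";
      return a == b ? 0 : 1;
    }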