Unverified commit 14e1e165 authored by chengduo, committed by GitHub

update alloc_continuous_space_for_grad_pass (#18287)

test=develop
Parent 7e61baaa
@@ -16,6 +16,7 @@
 #include <algorithm>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/details/build_strategy.h"
@@ -84,16 +85,19 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
     }
     if (params_grads.size() == 0) {
-      LOG(WARNING) << "Doesn't find gradients";
+      LOG(INFO) << "Doesn't find gradients";
       return;
     }

-    std::unordered_map<std::string, ir::Node *> vars;
+    std::unordered_map<std::string, ir::Node *> var_name2node;
+    std::unordered_map<std::string, std::unordered_set<ir::Node *>>
+        var_name2node_set;
     for (ir::Node *node : result.Nodes()) {
       if (node->IsVar() && node->Var()) {
         // Note: The graph may have the same name node. For example, parameter
         // is the input of operator and it also is the output of optimizer;
-        vars.emplace(node->Var()->Name(), node);
+        var_name2node.emplace(node->Var()->Name(), node);
+        var_name2node_set[node->Var()->Name()].emplace(node);
       }
     }
@@ -101,7 +105,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
         result.Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
     // Note: the order of params_grads may be changed by SetGroupGradsAndParams.
-    SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
+    SetGroupGradsAndParams(var_name2node, params_grads, &group_grads_params);

     params_grads.clear();
     for (auto &group_p_g : group_grads_params) {
@@ -116,9 +120,16 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
     auto dtype = kDefaultDtype;
     for (auto &p_g : params_grads) {
       // Get gradient var
-      auto iter = vars.find(p_g.second);
-      PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
-      iter->second->Var()->SetPersistable(true);
+      auto iter = var_name2node.find(p_g.second);
+      PADDLE_ENFORCE(iter != var_name2node.end(), "%s is not found.",
+                     p_g.second);
+      // Set persistable
+      auto same_nodes = var_name2node_set.find(p_g.second);
+      PADDLE_ENFORCE(same_nodes != var_name2node_set.end(), "%s is not found.",
+                     p_g.second);
+      for (auto it : same_nodes->second) {
+        it->Var()->SetPersistable(true);
+      }

       PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
@@ -151,7 +162,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
                    "%s is duplicate in FusedVars.", fused_var_name);
     fused_var_set.insert(fused_var_name);

-    InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
+    InitFusedVarsAndAllocSpaceForVars(places, local_scopes, var_name2node,
                                       fused_var_name, params_grads);
   }
......
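The behavioral fix in the pass above: an ir::Graph can contain several nodes with the same variable name (a parameter can be both an operator's input and the optimizer's output), each holding its own VarDesc, so the old single iter->second->Var()->SetPersistable(true) left the other copies non-persistable. A minimal standalone sketch of the hazard; MiniVar and the sample data are hypothetical stand-ins for ir::Node/VarDesc, not Paddle types:

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Hypothetical stand-in for a graph node that owns its own desc copy.
struct MiniVar {
  std::string name;
  bool persistable = false;
};

int main() {
  // Two nodes share the name "w@GRAD" but own independent descs,
  // mirroring a graph where one var name appears on several nodes.
  auto a = std::make_unique<MiniVar>(MiniVar{"w@GRAD", false});
  auto b = std::make_unique<MiniVar>(MiniVar{"w@GRAD", false});

  std::unordered_map<std::string, MiniVar *> name2node;  // keeps one node
  std::unordered_map<std::string, std::unordered_set<MiniVar *>> name2set;
  for (MiniVar *v : {a.get(), b.get()}) {
    name2node.emplace(v->name, v);  // emplace: only the first insert survives
    name2set[v->name].insert(v);
  }

  // Old behavior: update only the node kept in name2node -> b stays stale.
  name2node.at("w@GRAD")->persistable = true;
  std::cout << a->persistable << " " << b->persistable << "\n";  // 1 0

  // New behavior: update every node registered under the name.
  for (MiniVar *v : name2set.at("w@GRAD")) v->persistable = true;
  std::cout << a->persistable << " " << b->persistable << "\n";  // 1 1
}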
@@ -103,6 +103,33 @@ bool HasCircle(const Graph &graph) {
   return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
 }

+bool VarDescIsConsistency(const Graph &graph) {
+  std::unordered_map<std::string, std::unordered_set<ir::Node *>>
+      var_name2node_set;
+  for (ir::Node *node : graph.Nodes()) {
+    if (node->IsVar() && node->Var()) {
+      // Note: The graph may have the same name node. For example, parameter
+      // is the input of operator and it also is the output of optimizer;
+      var_name2node_set[node->Var()->Name()].emplace(node);
+    }
+  }
+  for (auto &iter : var_name2node_set) {
+    auto &first_node = *iter.second.begin();
+    bool is_persistable = std::any_of(iter.second.begin(), iter.second.end(),
+                                      [&first_node](const ir::Node *node) {
+                                        return node->Var()->Persistable();
+                                      });
+    if (is_persistable) {
+      bool is_consistency =
+          std::all_of(iter.second.begin(), iter.second.end(),
+                      [&first_node](const ir::Node *node) {
+                        return *node->Var() == *first_node->Var();
+                      });
+      if (!is_consistency) return false;
+    }
+  }
+  return true;
+}
 bool FindCircleSubGraph(const Graph &graph,
                         std::vector<std::vector<ir::Node *>> *circles) {
   return HasCircleInternal(BuildOperationAdjList(graph), circles);
......
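VarDescIsConsistency groups variable nodes by name and, for any name where at least one copy is persistable, requires every copy's VarDesc to compare equal. A self-contained sketch of the same any_of/all_of pattern under simplified, hypothetical types (MiniDesc is not a Paddle type):

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for VarDesc with value equality.
struct MiniDesc {
  bool persistable = false;
  std::vector<long> shape;
  bool operator==(const MiniDesc &o) const {
    return persistable == o.persistable && shape == o.shape;
  }
};

// Same pattern as VarDescIsConsistency: only names with a persistable
// copy must have all copies equal.
bool DescsConsistent(
    const std::unordered_map<std::string, std::vector<MiniDesc>> &vars) {
  for (const auto &kv : vars) {
    const auto &copies = kv.second;
    const MiniDesc &first = copies.front();
    bool persistable =
        std::any_of(copies.begin(), copies.end(),
                    [](const MiniDesc &d) { return d.persistable; });
    if (persistable &&
        !std::all_of(copies.begin(), copies.end(),
                     [&first](const MiniDesc &d) { return d == first; }))
      return false;
  }
  return true;
}

int main() {
  std::unordered_map<std::string, std::vector<MiniDesc>> vars;
  vars["w"] = {{true, {784, 10}}, {true, {784, 10}}};  // consistent
  vars["tmp"] = {{false, {1}}, {false, {2}}};  // ignored: not persistable
  std::cout << DescsConsistent(vars) << "\n";  // 1
  vars["w"].push_back({true, {784, 11}});      // persistable, shape mismatch
  std::cout << DescsConsistent(vars) << "\n";  // 0
}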
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <map>
 #include <memory>
 #include <set>
+#include <string>
 #include <vector>

 #include "paddle/fluid/framework/ir/graph.h"
@@ -36,6 +37,9 @@ struct NodeComp {
 // Test if the graph contains circle.
 bool HasCircle(const Graph &graph);

+// Check if the var desc of node is consistency.
+bool VarDescIsConsistency(const Graph &graph);
+
 // Find All Circles for debugging,
 // store all subgraph in circles.
 bool FindCircleSubGraph(const Graph &graph,
......
@@ -38,6 +38,8 @@ Graph* Pass::Apply(Graph* graph) const {
   // TODO(panyx0718): Add more verifications.
   PADDLE_ENFORCE(!HasCircle(*graph),
                  "Illegal Pass. Generated graph shouldn't has cycle.");
+  PADDLE_ENFORCE(VarDescIsConsistency(*graph),
+                 "The VarDescs of persistable variable are not consistency.");
   PADDLE_ENFORCE(graph == native_graph,
                  "Pass::Apply() cannot delete the passed graph and shouldn't "
                  "return a new graph.(For the need of pybind11)");
......
@@ -320,12 +320,14 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
   }
 #endif

   if (!member_->use_all_reduce_) {
-    PADDLE_ENFORCE(places.size() > 1,
-                   "If you set build_strategy.reduce with 'Reduce',"
-                   "the number of places must be greater than 1.");
+    if (places.size() == 1) {
+      LOG(INFO) << "If you set build_strategy.reduce with 'Reduce',"
+                   "the number of places should be greater than 1.";
+      member_->use_all_reduce_ = true;
+    }
   }

-  LOG(WARNING) << string::Sprintf(
+  LOG(INFO) << string::Sprintf(
       "The number of %s, which is used in ParallelExecutor, is %lu. And "
       "the Program will be copied %lu copies",
       (member_->use_cuda_ ? "CUDAPlace" : "CPUPlace"), places.size(),
@@ -364,10 +366,11 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
   // choice the execution strategy.
   build_strategy.enable_parallel_graph_ =
       EnableParallelGraphExecution(*graph, exec_strategy, build_strategy);
-  if (build_strategy.enable_parallel_graph_)
-    VLOG(0) << "The Executor would execute the graph by ParallelGraph "
-               "Execution which can get better performance,"
-            << "you can force it off by env FLAGS_enable_parallel_graph=0";
+  if (build_strategy.enable_parallel_graph_) {
+    LOG(INFO) << "The Executor would execute the graph by ParallelGraph "
+                 "Execution which can get better performance,"
+              << "you can force it off by env FLAGS_enable_parallel_graph=0";
+  }

   if (member_->use_cuda_ && member_->nranks_ > 1) {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......
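Net effect of the two hunks above: requesting build_strategy.reduce with 'Reduce' on a single place no longer aborts through PADDLE_ENFORCE; ParallelExecutor logs an INFO message and falls back to all_reduce (member_->use_all_reduce_ = true), and the copy-count and ParallelGraph notices are demoted from WARNING/VLOG(0) to LOG(INFO).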
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include <google/protobuf/util/message_differencer.h>
+
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -264,5 +266,10 @@ std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
   }
 }

+bool operator==(const VarDesc &left, const VarDesc &right) {
+  return left.Proto()->SerializeAsString() ==
+         right.Proto()->SerializeAsString();
+}
+
 }  // namespace framework
 }  // namespace paddle
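The new operator== compares the two descs' serialized proto bytes. Byte comparison is cheap, though for general protobuf messages it can disagree with field-wise semantics (unknown fields, map ordering); the message_differencer.h header pulled in above provides the field-wise alternative. A hedged sketch contrasting the two approaches over any google::protobuf::Message (the function names here are illustrative, not Paddle APIs):

#include <google/protobuf/message.h>
#include <google/protobuf/util/message_differencer.h>

bool BytewiseEqual(const google::protobuf::Message &a,
                   const google::protobuf::Message &b) {
  // What the operator== above does: compare serialized bytes.
  return a.SerializeAsString() == b.SerializeAsString();
}

bool FieldwiseEqual(const google::protobuf::Message &a,
                    const google::protobuf::Message &b) {
  // Field-by-field semantic comparison from message_differencer.h.
  return google::protobuf::util::MessageDifferencer::Equals(a, b);
}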
@@ -67,6 +67,8 @@ class VarDesc {
   proto::VarDesc *Proto() { return &desc_; }

+  const proto::VarDesc *Proto() const { return &desc_; }
+
   std::string Name() const { return desc_.name(); }

   void SetName(std::string name) { desc_.set_name(name); }
@@ -116,5 +118,7 @@ class VarDesc {
   proto::VarDesc desc_;
 };

+bool operator==(const VarDesc &left, const VarDesc &right);
+
 }  // namespace framework
 }  // namespace paddle