提交 9f001c65 编写于 作者: D dzhwinter

skip dist. test=develop

上级 2561a6fc
...@@ -301,7 +301,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op, ...@@ -301,7 +301,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
// 3. if output has been memory optimize by python(fluid.memory_optmize()). // 3. if output has been memory optimize by python(fluid.memory_optmize()).
// this candidate can not be inplaced. Will be deprecated in the future. // this candidate can not be inplaced. Will be deprecated in the future.
if (view_.ReusedInPythonMemOpt(out_node->Name())) { if (view_.InSkipSet(out_node->Name())) {
VLOG(4) << string::Sprintf( VLOG(4) << string::Sprintf(
"Skiped %s => %s reused previous memory block in python memory " "Skiped %s => %s reused previous memory block in python memory "
"optmize," "optmize,"
...@@ -385,7 +385,7 @@ void GraphView::Build(ir::Graph* g) { ...@@ -385,7 +385,7 @@ void GraphView::Build(ir::Graph* g) {
// resolve data harzards depends on the var nodes in right order. // resolve data harzards depends on the var nodes in right order.
ops_ = SortOpLikeDescOrder(*g); ops_ = SortOpLikeDescOrder(*g);
// track the nodes which reused previous node in Python memory optimize. // 1. track the nodes which reused previous node in Python memory optimize.
// these node can not be inplaced, otherwise may generate a circle in graph. // these node can not be inplaced, otherwise may generate a circle in graph.
std::unordered_set<std::string> all_vars; std::unordered_set<std::string> all_vars;
for (auto& node : g->Nodes()) { for (auto& node : g->Nodes()) {
...@@ -399,11 +399,28 @@ void GraphView::Build(ir::Graph* g) { ...@@ -399,11 +399,28 @@ void GraphView::Build(ir::Graph* g) {
} }
} }
} }
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// pserver can not find each other name.
for (auto& node : g->Nodes()) {
if (!node->IsOp()) continue;
if (node->Name() == "send") {
for (auto& in : node->inputs) {
dup_nodes_.emplace(in->Name());
}
}
if (node->Name() == "recv") {
for (auto& out : node->outputs) {
dup_nodes_.emplace(out->Name());
}
}
}
} }
const std::vector<ir::Node*>& GraphView::AllOps() { return ops_; } const std::vector<ir::Node*>& GraphView::AllOps() { return ops_; }
bool GraphView::ReusedInPythonMemOpt(const std::string& var) const { bool GraphView::InSkipSet(const std::string& var) const {
return dup_nodes_.count(var); return dup_nodes_.count(var);
} }
......
...@@ -41,11 +41,14 @@ class GraphView { ...@@ -41,11 +41,14 @@ class GraphView {
std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var); std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);
// Will Deperated in the future. // Will Deperated in the future.
// NOTE(dzhwinter) : Python memory optimize will reuse // NOTE(dzhwinter) :
// 1. Python memory optimize will reuse
// memory based var name, so different op output may // memory based var name, so different op output may
// have the same variable name. enable inplace on such node // have the same variable name. enable inplace on such node
// will generate a circle in ssa graph. // will generate a circle in ssa graph.
bool ReusedInPythonMemOpt(const std::string& var) const; // 2. DistributeTranspiler will use unique name to
// map the parameter and gradient, must be skipped.
bool InSkipSet(const std::string& var) const;
private: private:
std::vector<ir::Node*> ops_; std::vector<ir::Node*> ops_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册