提交 9f001c65 编写于 作者: D dzhwinter

skip dist. test=develop

上级 2561a6fc
......@@ -301,7 +301,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
// 3. if output has been memory optimize by python(fluid.memory_optmize()).
// this candidate can not be inplaced. Will be deprecated in the future.
if (view_.ReusedInPythonMemOpt(out_node->Name())) {
if (view_.InSkipSet(out_node->Name())) {
VLOG(4) << string::Sprintf(
"Skiped %s => %s reused previous memory block in python memory "
"optmize,"
......@@ -385,7 +385,7 @@ void GraphView::Build(ir::Graph* g) {
// resolve data harzards depends on the var nodes in right order.
ops_ = SortOpLikeDescOrder(*g);
// track the nodes which reused previous node in Python memory optimize.
// 1. track the nodes which reused previous node in Python memory optimize.
// these node can not be inplaced, otherwise may generate a circle in graph.
std::unordered_set<std::string> all_vars;
for (auto& node : g->Nodes()) {
......@@ -399,11 +399,28 @@ void GraphView::Build(ir::Graph* g) {
}
}
}
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// pserver can not find each other name.
for (auto& node : g->Nodes()) {
if (!node->IsOp()) continue;
if (node->Name() == "send") {
for (auto& in : node->inputs) {
dup_nodes_.emplace(in->Name());
}
}
if (node->Name() == "recv") {
for (auto& out : node->outputs) {
dup_nodes_.emplace(out->Name());
}
}
}
}
const std::vector<ir::Node*>& GraphView::AllOps() { return ops_; }
bool GraphView::ReusedInPythonMemOpt(const std::string& var) const {
bool GraphView::InSkipSet(const std::string& var) const {
return dup_nodes_.count(var);
}
......
......@@ -41,11 +41,14 @@ class GraphView {
std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);
// Will Deperated in the future.
// NOTE(dzhwinter) : Python memory optimize will reuse
// NOTE(dzhwinter) :
// 1. Python memory optimize will reuse
// memory based var name, so different op output may
// have the same variable name. enable inplace on such node
// will generate a circle in ssa graph.
bool ReusedInPythonMemOpt(const std::string& var) const;
// 2. DistributeTranspiler will use unique name to
// map the parameter and gradient, must be skipped.
bool InSkipSet(const std::string& var) const;
private:
std::vector<ir::Node*> ops_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册