提交 07a8ec78 编写于 作者: L Li Xinqi 提交者: GitHub

Bugfix no tick diff (#1614)

* group by has_diff

* rm unnecessary identity


Former-commit-id: 63f2cfd61337c821e0fb6215b592231ccee584d4
上级 31e6d157
......@@ -153,14 +153,17 @@ void OpGraph::InitEdges() {
}
void OpGraph::UpdateOpNodeHasInDiff() {
auto HasIndiff = [&](const OpNode* op_node) -> bool {
TopoForEachNode([&](OpNode* op_node) {
bool has_diff = false;
for (OpEdge* edge : op_node->in_edges()) {
if (edge->src_node()->has_in_diff()) { return true; }
if (edge->src_node()->has_model_diff()) { return true; }
if (edge->src_node()->has_in_diff() || edge->src_node()->has_model_diff()) {
edge->set_has_diff(true);
has_diff = true;
break;
}
}
return false;
};
TopoForEachNode([&](OpNode* op_node) { op_node->set_has_in_diff(HasIndiff(op_node)); });
op_node->set_has_in_diff(has_diff);
});
}
void OpGraph::InferNodeBlobDesc() const {
......
......@@ -54,16 +54,20 @@ class OpEdge final : public Edge<OpNode, OpEdge> {
OF_DISALLOW_COPY_AND_MOVE(OpEdge);
explicit OpEdge(const std::vector<LogicalBlobId>& lbis,
const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns)
: lbis_(lbis), lbi2ibns_(lbi2ibns) {}
: lbis_(lbis), lbi2ibns_(lbi2ibns), has_diff_(false) {}
~OpEdge() = default;
const std::vector<LogicalBlobId>& lbis() const { return lbis_; }
const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns() const { return lbi2ibns_; }
bool has_diff() const { return has_diff_; }
std::string VisualStr() const override;
void set_has_diff(bool val) { has_diff_ = val; }
private:
std::vector<LogicalBlobId> lbis_;
HashMap<LogicalBlobId, std::vector<std::string>> lbi2ibns_;
bool has_diff_;
};
class OpGraph final : public Graph<OpNode, OpEdge> {
......
......@@ -443,9 +443,21 @@ void JobDesc::ConvertPseudoChainToChain() {
if (chain_nodes.size() - source_nodes.size() <= 2) { return; }
const OpNode* first_node = *source_nodes.begin();
if (first_node->parallel_desc().device_type() == DeviceType::kCPU) { return; }
AddIdentityOpAndReconnect("pseudo_chain_header_", &job_conf_, source_edges,
MutOperatorConf4OpName,
*ParallelConf4OpName(first_node->op().op_name()));
HashMap<bool, std::vector<OpEdge*>> has_diff2source_edges;
for (OpEdge* edge : source_edges) { has_diff2source_edges[edge->has_diff()].push_back(edge); }
for (const auto& pair : has_diff2source_edges) {
HashSet<OpNode*> src_nodes;
HashSet<OpNode*> dst_nodes;
for (OpEdge* edge : pair.second) {
src_nodes.emplace(edge->src_node());
dst_nodes.emplace(edge->dst_node());
}
if (src_nodes.size() > 1 && dst_nodes.size() > 1) {
AddIdentityOpAndReconnect("pseudo_chain_header_", &job_conf_, pair.second,
MutOperatorConf4OpName,
*ParallelConf4OpName(first_node->op().op_name()));
}
}
});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册