未验证 提交 0f2dc4ca 编写于 作者: C csy0225 提交者: GitHub

Fix delete_isolated_node_pass problem (#52856)

上级 ea1c9b89
...@@ -50,8 +50,8 @@ class DeleteIsolatedNodePass : public Pass { ...@@ -50,8 +50,8 @@ class DeleteIsolatedNodePass : public Pass {
std::unordered_set<std::string>* delete_node_names) const; std::unordered_set<std::string>* delete_node_names) const;
int UpdateControlFlowOp( int UpdateControlFlowOp(
int current_graph_index,
Graph* graph, Graph* graph,
const std::map<int, Graph*>& block_id_graph_map,
const std::unordered_set<std::string>& delete_node_names) const; const std::unordered_set<std::string>& delete_node_names) const;
const std::map<std::string, std::string> control_flow_op_input_map_{ const std::map<std::string, std::string> control_flow_op_input_map_{
...@@ -86,20 +86,9 @@ void DeleteIsolatedNodePass::ApplyImpl(Graph* graph) const { ...@@ -86,20 +86,9 @@ void DeleteIsolatedNodePass::ApplyImpl(Graph* graph) const {
LOG(INFO) << "--- delete " << delete_counts << " isolated nodes"; LOG(INFO) << "--- delete " << delete_counts << " isolated nodes";
} }
std::map<int, Graph*> block_id_graph_map;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
auto* sub_graph = graph->GetSubGraph(i);
for (auto* node : sub_graph->Nodes()) {
if (node->IsVar()) {
block_id_graph_map[node->GetVarNodeBlockId()] = sub_graph;
break;
}
}
}
int update_counts = 0; int update_counts = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) { for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
update_counts += UpdateControlFlowOp( update_counts += UpdateControlFlowOp(i, graph, delete_node_names);
graph->GetSubGraph(i), block_id_graph_map, delete_node_names);
} }
if (update_counts > 0) { if (update_counts > 0) {
LOG(INFO) << "--- update " << update_counts << " control flow ops"; LOG(INFO) << "--- update " << update_counts << " control flow ops";
...@@ -129,6 +118,7 @@ int DeleteIsolatedNodePass::RemoveIsolatedNodes( ...@@ -129,6 +118,7 @@ int DeleteIsolatedNodePass::RemoveIsolatedNodes(
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (node->IsOp()) { if (node->IsOp()) {
block = node->Op()->Block(); block = node->Op()->Block();
break;
} }
} }
Scope& scope = graph->Get<framework::Scope>("__param_scope__"); Scope& scope = graph->Get<framework::Scope>("__param_scope__");
...@@ -160,11 +150,12 @@ int DeleteIsolatedNodePass::RemoveIsolatedNodes( ...@@ -160,11 +150,12 @@ int DeleteIsolatedNodePass::RemoveIsolatedNodes(
} }
int DeleteIsolatedNodePass::UpdateControlFlowOp( int DeleteIsolatedNodePass::UpdateControlFlowOp(
int current_graph_index,
Graph* graph, Graph* graph,
const std::map<int, Graph*>& block_id_graph_map,
const std::unordered_set<std::string>& delete_node_names) const { const std::unordered_set<std::string>& delete_node_names) const {
auto* cur_graph = graph->GetSubGraph(current_graph_index);
int update_counts = 0; int update_counts = 0;
for (auto* node : graph->Nodes()) { for (auto* node : cur_graph->Nodes()) {
if (!node->IsOp()) continue; if (!node->IsOp()) continue;
auto op_type = node->Op()->Type(); auto op_type = node->Op()->Type();
if (control_flow_op_input_map_.count(op_type) == 0) continue; if (control_flow_op_input_map_.count(op_type) == 0) continue;
...@@ -181,7 +172,7 @@ int DeleteIsolatedNodePass::UpdateControlFlowOp( ...@@ -181,7 +172,7 @@ int DeleteIsolatedNodePass::UpdateControlFlowOp(
auto* sub_block = PADDLE_GET_CONST(framework::BlockDesc*, auto* sub_block = PADDLE_GET_CONST(framework::BlockDesc*,
node->Op()->GetAttr("sub_block")); node->Op()->GetAttr("sub_block"));
auto* sub_graph = block_id_graph_map.at(sub_block->ID()); auto* sub_graph = graph->GetSubGraph(sub_block->ID());
std::unordered_set<std::string> sub_persistable_node_names; std::unordered_set<std::string> sub_persistable_node_names;
CollectReservedPersistableNodeNames(sub_graph, &sub_persistable_node_names); CollectReservedPersistableNodeNames(sub_graph, &sub_persistable_node_names);
for (auto sub_name : sub_persistable_node_names) { for (auto sub_name : sub_persistable_node_names) {
......
...@@ -70,6 +70,9 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper): ...@@ -70,6 +70,9 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper):
self.epsilon = 1e-05 self.epsilon = 1e-05
self.no_grad_set = None self.no_grad_set = None
self.set_attrs() self.set_attrs()
self.atol = 1e-5
if self.dtype == np.float16:
self.atol = 1e-2
np.random.seed(12345) np.random.seed(12345)
epsilon = self.epsilon epsilon = self.epsilon
...@@ -109,7 +112,7 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper): ...@@ -109,7 +112,7 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper):
pass pass
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0)) self.check_output_with_place(paddle.XPUPlace(0), atol=self.atol)
def test_check_grad(self): def test_check_grad(self):
self.check_grad_with_place( self.check_grad_with_place(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册