未验证 提交 0f2dc4ca 编写于 作者: C csy0225 提交者: GitHub

Fix delete_isolated_node_pass problem (#52856)

上级 ea1c9b89
......@@ -50,8 +50,8 @@ class DeleteIsolatedNodePass : public Pass {
std::unordered_set<std::string>* delete_node_names) const;
int UpdateControlFlowOp(
int current_graph_index,
Graph* graph,
const std::map<int, Graph*>& block_id_graph_map,
const std::unordered_set<std::string>& delete_node_names) const;
const std::map<std::string, std::string> control_flow_op_input_map_{
......@@ -86,20 +86,9 @@ void DeleteIsolatedNodePass::ApplyImpl(Graph* graph) const {
LOG(INFO) << "--- delete " << delete_counts << " isolated nodes";
}
std::map<int, Graph*> block_id_graph_map;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
auto* sub_graph = graph->GetSubGraph(i);
for (auto* node : sub_graph->Nodes()) {
if (node->IsVar()) {
block_id_graph_map[node->GetVarNodeBlockId()] = sub_graph;
break;
}
}
}
int update_counts = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
update_counts += UpdateControlFlowOp(
graph->GetSubGraph(i), block_id_graph_map, delete_node_names);
update_counts += UpdateControlFlowOp(i, graph, delete_node_names);
}
if (update_counts > 0) {
LOG(INFO) << "--- update " << update_counts << " control flow ops";
......@@ -129,6 +118,7 @@ int DeleteIsolatedNodePass::RemoveIsolatedNodes(
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
block = node->Op()->Block();
break;
}
}
Scope& scope = graph->Get<framework::Scope>("__param_scope__");
......@@ -160,11 +150,12 @@ int DeleteIsolatedNodePass::RemoveIsolatedNodes(
}
int DeleteIsolatedNodePass::UpdateControlFlowOp(
int current_graph_index,
Graph* graph,
const std::map<int, Graph*>& block_id_graph_map,
const std::unordered_set<std::string>& delete_node_names) const {
auto* cur_graph = graph->GetSubGraph(current_graph_index);
int update_counts = 0;
for (auto* node : graph->Nodes()) {
for (auto* node : cur_graph->Nodes()) {
if (!node->IsOp()) continue;
auto op_type = node->Op()->Type();
if (control_flow_op_input_map_.count(op_type) == 0) continue;
......@@ -181,7 +172,7 @@ int DeleteIsolatedNodePass::UpdateControlFlowOp(
auto* sub_block = PADDLE_GET_CONST(framework::BlockDesc*,
node->Op()->GetAttr("sub_block"));
auto* sub_graph = block_id_graph_map.at(sub_block->ID());
auto* sub_graph = graph->GetSubGraph(sub_block->ID());
std::unordered_set<std::string> sub_persistable_node_names;
CollectReservedPersistableNodeNames(sub_graph, &sub_persistable_node_names);
for (auto sub_name : sub_persistable_node_names) {
......
......@@ -70,6 +70,9 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper):
self.epsilon = 1e-05
self.no_grad_set = None
self.set_attrs()
self.atol = 1e-5
if self.dtype == np.float16:
self.atol = 1e-2
np.random.seed(12345)
epsilon = self.epsilon
......@@ -109,7 +112,7 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper):
pass
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0))
self.check_output_with_place(paddle.XPUPlace(0), atol=self.atol)
def test_check_grad(self):
self.check_grad_with_place(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册