Unverified · Commit 29e540af, authored by huangxu96, committed by GitHub

add python interface of sub_graph (#36120) (#38235)

Add python interface of subgraph: 1. all_sub_graphs() 2. get_sub_graph(idx)
Parent: d70a06cc
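For context, a minimal usage sketch of the interface this commit adds. It assumes a `main_program` (a `Program` containing a control-flow op such as `layers.cond`, so its desc carries sub-blocks); the variable names are illustrative, not part of the commit:

```python
import paddle
from paddle.fluid import core
from paddle.fluid.framework import IrGraph

paddle.enable_static()

# Assumption: `main_program` is a Program whose desc contains sub-blocks,
# e.g. because it uses layers.cond (see the test added at the end of this
# commit).
graph = IrGraph(core.Graph(main_program.desc), for_test=True)

# New in this commit: fetch sub-graphs from Python.
sub_graphs = graph.all_sub_graphs(for_test=True)  # all sub-graphs, as a list
first_sub_graph = graph.get_sub_graph(0, for_test=True)  # one by index
```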
@@ -125,7 +125,15 @@ void BindGraph(py::module *m) {
            return_value_policy::reference)
       .def("resolve_hazard", &Graph::ResolveHazard)
       .def("origin_program_desc", &Graph::OriginProgram,
-           return_value_policy::reference);
+           return_value_policy::reference)
+      .def("sub_graph_size", &Graph::SubGraphsSize)
+      .def("get_sub_graph", [](Graph &self, int i) {
+        /* Here we use a lambda function as an empty deleter to avoid a
+           double free of the smart pointer.
+           Otherwise, this shared pointer would be freed in both the Python
+           and C++ scopes, which would lead to a core dump. */
+        return std::shared_ptr<Graph>(self.GetSubGraph(i), [](Graph *) {});
+      });
 }
 
 void BindNode(py::module *m) {
...
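The no-op deleter above matters to Python users: a sub-graph borrows the parent `Graph`'s memory instead of owning it, so the parent must outlive every sub-graph taken from it. A hedged sketch of the safe pattern, reusing the names from the sketch above:

```python
# Safe: keep the parent graph referenced while its sub-graphs are in use;
# the shared_ptr returned by get_sub_graph never frees the underlying Graph.
graph = IrGraph(core.Graph(main_program.desc), for_test=True)
sub_graphs = graph.all_sub_graphs(for_test=True)
for sub_graph in sub_graphs:
    print(len(sub_graph.all_op_nodes()))

# Risky: returning only the sub-graphs and dropping `graph` destructs the
# parent and leaves the sub-graphs empty (see the comment in
# build_graph_with_sub_graph in the test below).
```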
@@ -3960,6 +3960,23 @@ class IrGraph(object):
         """
         return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
 
+    def all_sub_graphs(self, for_test=False):
+        """
+        Return all sub_graphs included in the main graph as a list.
+        """
+        return [
+            IrGraph(
+                self.graph.get_sub_graph(i), for_test=for_test)
+            for i in range(self.graph.sub_graph_size())
+        ]
+
+    def get_sub_graph(self, i, for_test=False):
+        """
+        Return the i-th sub_graph of the main graph.
+        """
+        return IrGraph(self.graph.get_sub_graph(i), for_test=for_test)
+
     def create_persistable_node(self, name, var_type, shape, var_dtype):
         """
         Create a persistable variable node in the graph. In IrGraph,
@@ -4106,8 +4123,10 @@ class IrGraph(object):
             node_in(IrNode): the input node.
             node_out(IrNode): the output node.
         """
-        assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \
-            'The two arguments(node_in&node_out) must be in the graph nodes.'
+        assert node_in.node in self.graph.nodes(), (
+            'node_in(%s) must be in the graph nodes.' % node_in.node.name())
+        assert node_out.node in self.graph.nodes(), (
+            'node_out(%s) must be in the graph nodes.' % node_out.node.name())
         node_in.append_output(node_out)
         node_out.append_input(node_in)
@@ -4269,7 +4288,8 @@ class IrGraph(object):
         for n in nodes:
             if n.name() == node_name:
                 target_node = n
-        assert target_node is not None, "Cannot find the target node in the giving set."
+        assert target_node is not None, (
+            "Cannot find the target node (%s) in the given set." % node_name)
         return target_node
     def _update_desc_attr(self, desc, name, val):
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.fluid as fluid
import six
from paddle.fluid.framework import IrGraph
from paddle.fluid.framework import IrNode
from paddle.fluid.tests.unittests.op_test import OpTestTool
from paddle.fluid import core
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard, default_startup_program
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
paddle.enable_static()

class TestQuantizationSubGraph(unittest.TestCase):
    def build_graph_with_sub_graph(self):
        def linear_fc(num):
            data = fluid.layers.data(
                name='image', shape=[1, 32, 32], dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            hidden = data
            for _ in six.moves.xrange(num):
                hidden = fluid.layers.fc(hidden, size=128, act='relu')
            loss = fluid.layers.cross_entropy(input=hidden, label=label)
            loss = fluid.layers.mean(loss)
            return loss

        main_program = Program()
        startup_program = Program()

        def true_func():
            return linear_fc(3)

        def false_func():
            return linear_fc(5)

        with program_guard(main_program, startup_program):
            x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
            y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
            pred = layers.less_than(y, x)
            out = layers.cond(pred, true_func, false_func)

        core_graph = core.Graph(main_program.desc)
        # The graph must be created with for_test=True; otherwise it throws
        # an error that the node "STEP_COUNTER" cannot be found.
        graph = IrGraph(core_graph, for_test=True)
        sub_graph = graph.get_sub_graph(0)
        all_sub_graphs = graph.all_sub_graphs(for_test=True)  # same reason
        # Return the graph together with the sub_graphs. If only the
        # sub_graphs were returned, the graph would be destructed and the
        # sub_graphs would be empty.
        return graph, all_sub_graphs

    def check_quant_sub_graphs(self, use_cuda=False):
        graph, sub_graphs = self.build_graph_with_sub_graph()
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            place=place,
            activation_quantize_type='abs_max',
            weight_quantize_type='range_abs_max')
        found_inserted_quant_op = False
        for sub_graph in sub_graphs:
            transform_pass.apply(sub_graph)
            for op in sub_graph.all_op_nodes():
                if 'quantize' in op.name():
                    found_inserted_quant_op = True
        self.assertTrue(found_inserted_quant_op)

    def test_quant_sub_graphs_cpu(self):
        self.check_quant_sub_graphs(use_cuda=False)

    @OpTestTool.skip_if(not paddle.is_compiled_with_cuda(),
                        "Not GPU version paddle")
    def test_quant_sub_graphs_gpu(self):
        self.check_quant_sub_graphs(use_cuda=True)


if __name__ == '__main__':
    unittest.main()