Commit 33f99d61 authored by Zhen Wang

add IrNode&IrVarNode&IrOpNode. test=develop

Parent d8128930
@@ -17,7 +17,9 @@ import numpy as np
 import six
 from ..... import compat as cpt
 from .... import core
+from .... import Executor
 from ....framework import IrGraph
+from ....framework import IrNode
 from ....framework import Program
 from ....initializer import Constant
 from .... import unique_name
@@ -31,7 +33,7 @@ __all__ = [
 class QuantizationTransformPass(object):
     def __init__(self,
                  scope=None,
-                 program_exe=None,
+                 place=None,
                  weight_bits=8,
                  activation_bits=8,
                  activation_quantize_type='abs_max',
@@ -45,7 +47,7 @@ class QuantizationTransformPass(object):
             scope(fluid.Scope): When the activation uses 'range_abs_max' as the quantize
                 type, this pass will create some new parameters. The scope is used to
                 initialize these new parameters.
-            program_exe(fluid.Executor): program_exe is used to initialize new
+            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
                 parameters described above.
             weight_bits (int): quantization bit number for weights,
                 the bias is not quantized.
@@ -71,13 +73,13 @@ class QuantizationTransformPass(object):
             from paddle.fluid import core
             graph = IrGraph(core.Graph(program.desc), for_test=False)
-            exe = fluid.Executor(fluid.CPUPlace())
+            place = fluid.CPUPlace()
             transform_pass = QuantizationTransformPass(fluid.global_scope(),
-                                                       exe)
+                                                       place)
             transform_pass.apply(graph)
         """
         self._scope = scope
-        self._program_exe = program_exe
+        self._place = place
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
@@ -118,7 +120,7 @@ class QuantizationTransformPass(object):
         self._is_test = graph.is_test()
         # mark the variables which have been dequantized
         dequantized_vars = collections.OrderedDict()
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]

         def _transform_forward(graph, op):
             for var_node in op.inputs:
@@ -149,7 +151,7 @@ class QuantizationTransformPass(object):
         if not self._is_test:
             self._create_global_step(graph)
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
        # _transform_forward and _transform_backward must run in two separate for loops.
         # The loop for transforming the forward graph:
         for op in ops:
@@ -163,8 +165,8 @@ class QuantizationTransformPass(object):
         if len(self._need_initialized) > 0:
             assert self._scope is not None, \
                 'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
-            assert self._program_exe is not None, \
-                'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.'
+            assert self._place is not None, \
+                'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
             init_program = Program()
             for var_desc, initializer in six.iteritems(self._need_initialized):
                 var = init_program.global_block().create_var(
@@ -175,7 +177,8 @@ class QuantizationTransformPass(object):
                     lod_level=var_desc.lod_level(),
                     persistable=var_desc.persistable())
                 initializer(var, init_program.global_block())
-            self._program_exe.run(program=init_program, scope=self._scope)
+            exe = Executor(self._place)
+            exe.run(program=init_program, scope=self._scope)

         return graph
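
For reference, a minimal sketch of the new calling convention after this change. The tiny network is illustrative, and the import path for QuantizationTransformPass assumes the contrib/slim layout of this branch:

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

    # Build a trivial program to quantize (illustrative only).
    main = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main, startup):
        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
        fc = fluid.layers.fc(input=img, size=10)

    place = fluid.CPUPlace()
    fluid.Executor(place).run(startup)

    graph = IrGraph(core.Graph(main.desc), for_test=False)
    # The pass now receives a place and builds its own Executor internally
    # to initialize the parameters it creates, instead of taking program_exe.
    transform_pass = QuantizationTransformPass(
        scope=fluid.global_scope(),
        place=place,
        activation_quantize_type='range_abs_max')
    transform_pass.apply(graph)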
@@ -183,11 +186,11 @@ class QuantizationTransformPass(object):
         if self._weight_quantize_type == 'range_abs_max' or \
                 self._activation_quantize_type == 'range_abs_max':
             counter_name = cpt.to_text('@STEP_COUNTER@')
-            for node in graph.all_vars():
+            for node in graph.all_var_nodes():
                 if node.name() == counter_name:
                     self._global_step = node
             if self._global_step is None:
-                global_step_in = graph.create_param_node(
+                global_step_in = graph.create_persistable_node(
                     name=counter_name,
                     var_type=core.VarDesc.VarType.LOD_TENSOR,
                     shape=[1],
@@ -262,7 +265,7 @@ class QuantizationTransformPass(object):
             shape=var_node.var().shape(),
             var_dtype=var_node.var().dtype())

-        scale_in_node = graph.create_param_node(
+        scale_in_node = graph.create_persistable_node(
             name=self._quantized_scale_name(var_node.name()),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
@@ -275,7 +278,7 @@ class QuantizationTransformPass(object):
         if not self._is_test:
             # The name of scales_var_node may be 'scales_0', 'scales_1', etc.
-            scales_node = graph.create_param_node(
+            scales_node = graph.create_persistable_node(
                 name=unique_name.generate('scales'),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[self._window_size],
@@ -400,8 +403,8 @@ class QuantizationFreezePass(object):
         Args:
             graph(IrGraph): the applied graph.
         """
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
-        ops = graph.all_ops()
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
+        ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._fake_quant_op_names:
@@ -425,13 +428,13 @@ class QuantizationFreezePass(object):
                         self._weight_bits)
                     self._restore_var(input_arg_name, quantized_param_v)

-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._fake_dequant_op_names:
                 self._remove_fake_quant_and_dequant_op(graph, op_node)

-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._quantizable_ops:
@@ -462,7 +465,7 @@ class QuantizationFreezePass(object):
     def _insert_post_dequant_op(self, graph, op_node):
         max_range = None
         scale_var_node = None
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
             if name in self._op_input_rename_map:
@@ -480,7 +483,7 @@ class QuantizationFreezePass(object):
                     original_var_name)
                 max_range = param_range * act_range / scale_v
             else:
-                assert isinstance(scale_v, core.Node)
+                assert isinstance(scale_v, IrNode)
                 scale_var_node = self._var_scale_map[original_var_name]

         if len(op_node.outputs) != 1:
@@ -517,14 +520,19 @@ class QuantizationFreezePass(object):
     def _remove_unused_var_nodes(self, graph):
         all_used_vars = set()
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             for input_node in op_node.inputs:
                 all_used_vars.add(input_node)
             for output_node in op_node.outputs:
                 all_used_vars.add(output_node)
-        all_unused_vars = graph.all_vars() - all_used_vars
+        all_used_vars = {n.node for n in all_used_vars}
+        all_unused_vars = {
+            n
+            for n in filter(lambda node: node.node not in all_used_vars,
+                            graph.all_var_nodes())
+        }
         graph.safe_remove_nodes(all_unused_vars)
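
The unwrap-to-`.node` step above is what makes the set arithmetic work: `op_node.inputs` constructs fresh IrNode wrappers on every access, so two wrappers around the same underlying node are distinct set members. A standalone sketch of the pitfall (plain Python, no Paddle required; `Wrapper` is a hypothetical stand-in for IrNode):

    class Wrapper(object):
        """Mimics IrNode: a thin shim holding an underlying node."""
        def __init__(self, node):
            self.node = node

    raw = object()                     # stands in for a core.Node
    a, b = Wrapper(raw), Wrapper(raw)  # two wrappers over the same node
    assert a != b                      # wrappers compare by identity
    assert a.node is b.node            # but the underlying node is shared
    # Hence _remove_unused_var_nodes dedupes on `.node` before deciding
    # which variable nodes are actually unused.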
     def _original_var_name(self, var_name):
@@ -583,8 +591,8 @@ class ConvertToInt8Pass(object):
         Args:
             graph(IrGraph): the applied graph.
         """
-        persistable_vars = [p.name() for p in graph.all_persistable_vars()]
-        ops = graph.all_ops()
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
+        ops = graph.all_op_nodes()
         input_map = {}
         for op_node in ops:
             op_name = op_node.name()
@@ -605,7 +613,7 @@ class ConvertToInt8Pass(object):
     def _convert_to_int8(self, graph, var_node):
         int8_var_node_name = var_node.name() + ".int8"
-        int8_var_node = graph.create_param_node(
+        int8_var_node = graph.create_persistable_node(
             name=cpt.to_text(int8_var_node_name),
             var_type=var_node.var().type(),
             shape=var_node.var().shape(),
@@ -624,14 +632,19 @@ class ConvertToInt8Pass(object):
     def _remove_unused_var_nodes(self, graph):
         all_used_vars = set()
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             for input_node in op_node.inputs:
                 all_used_vars.add(input_node)
             for output_node in op_node.outputs:
                 all_used_vars.add(output_node)
-        all_unused_vars = graph.all_vars() - all_used_vars
+        all_used_vars = {n.node for n in all_used_vars}
+        all_unused_vars = {
+            n
+            for n in filter(lambda node: node.node not in all_used_vars,
+                            graph.all_var_nodes())
+        }
         graph.safe_remove_nodes(all_unused_vars)
@@ -655,7 +668,7 @@ class TransformForMobilePass(object):
         Args:
             graph(IrGraph): the graph to be transformed.
         """
-        ops = graph.all_ops()
+        ops = graph.all_op_nodes()
         for op_node in ops:
             name = op_node.name()
             if name in self._fake_quant_op_names:
......
@@ -61,16 +61,16 @@ class TestGraph(unittest.TestCase):
             opt.minimize(loss)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
         marked_nodes = set()
-        for op in graph.all_ops():
+        for op in graph.all_op_nodes():
             if op.name().find('conv2d') > -1:
                 marked_nodes.add(op)
         graph.draw('.', 'residual', marked_nodes)
         self.assertFalse(graph.has_circle())
         self.assertEqual(graph.graph_num(), 1)
         nodes = graph.topology_sort()
-        self.assertEqual(len(nodes), len(graph.all_ops()))
+        self.assertEqual(len(nodes), len(graph.all_op_nodes()))
         nodes_map = graph.build_adjacency_list()
-        self.assertEqual(len(nodes_map), len(graph.all_ops()))
+        self.assertEqual(len(nodes_map), len(graph.all_op_nodes()))
         nodes_num = len(graph.all_nodes())
         graph.safe_remove_nodes(marked_nodes)
         self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
......
@@ -130,15 +130,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
             loss = linear_fc(3)
             opt = fluid.optimizer.Adam(learning_rate=0.001)
             opt.minimize(loss)
-        exe = fluid.Executor(fluid.CPUPlace())
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
         transform_pass = QuantizationTransformPass(
             scope=fluid.global_scope(),
-            program_exe=exe,
+            place=place,
             activation_quantize_type=quant_type)
         transform_pass.apply(graph)
         marked_nodes = set()
-        for op in graph.all_ops():
+        for op in graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
@@ -146,7 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
         self.check_program(transform_pass, program)
         val_graph = IrGraph(core.Graph(program.desc), for_test=False)
         val_marked_nodes = set()
-        for op in val_graph.all_ops():
+        for op in val_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 val_marked_nodes.add(op)
         val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
@@ -166,15 +167,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
             loss = residual_block(2)
             opt = fluid.optimizer.Adam(learning_rate=0.001)
             opt.minimize(loss)
-        exe = fluid.Executor(fluid.CPUPlace())
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
         transform_pass = QuantizationTransformPass(
             scope=fluid.global_scope(),
-            program_exe=exe,
+            place=place,
             activation_quantize_type=quant_type)
         transform_pass.apply(graph)
         marked_nodes = set()
-        for op in graph.all_ops():
+        for op in graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
@@ -182,7 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
         self.check_program(transform_pass, program)
         val_graph = IrGraph(core.Graph(program.desc), for_test=False)
         val_marked_nodes = set()
-        for op in val_graph.all_ops():
+        for op in val_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 val_marked_nodes.add(op)
         val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
@@ -231,17 +233,17 @@ class TestQuantizationFreezePass(unittest.TestCase):
         with fluid.scope_guard(scope):
             exe.run(startup)
         transform_pass = QuantizationTransformPass(
-            scope=scope, program_exe=exe, activation_quantize_type=quant_type)
+            scope=scope, place=place, activation_quantize_type=quant_type)
         transform_pass.apply(main_graph)
         transform_pass.apply(test_graph)
         dev_name = '_gpu_' if use_cuda else '_cpu_'
         marked_nodes = set()
-        for op in main_graph.all_ops():
+        for op in main_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes)
         marked_nodes = set()
-        for op in test_graph.all_ops():
+        for op in test_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes)
@@ -251,11 +253,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
         iters = 5
         batch_size = 8
-        #train_exe = fluid.ParallelExecutor(
-        #    main_program=quantized_main_program,
-        #    use_cuda=bool(use_cuda),
-        #    loss_name=loss.name,
-        #    scope=scope)
         train_reader = paddle.batch(
             paddle.reader.shuffle(
                 paddle.dataset.mnist.train(), buf_size=500),
@@ -269,9 +266,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
                 loss_v = exe.run(program=quantized_main_program,
                                  feed=feeder.feed(data),
                                  fetch_list=[loss])
-                #loss_v = train_exe.run(feed=feeder.feed(data),
-                #                       fetch_list=[loss.name])
-                #print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
+                print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
             test_data = next(test_reader())
             with fluid.program_guard(quantized_test_program):
@@ -287,7 +282,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
         freeze_pass = QuantizationFreezePass(scope=scope, place=place)
         freeze_pass.apply(test_graph)
         marked_nodes = set()
-        for op in test_graph.all_ops():
+        for op in test_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         test_graph.draw('.', 'test_freeze' + dev_name + quant_type,
@@ -299,21 +294,21 @@ class TestQuantizationFreezePass(unittest.TestCase):
                                  feed=feeder.feed(test_data),
                                  fetch_list=[loss])
         self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
-        #print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
-        #print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
+        print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
+        print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
         w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
         # This may fail due to limited calculation precision:
         # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
-        #print('{}: {}'.format('w_freeze' + dev_name + quant_type,
-        #                      np.sum(w_freeze)))
-        #print('{}: {}'.format('w_quant' + dev_name + quant_type,
-        #                      np.sum(w_quant)))
+        print('{}: {}'.format('w_freeze' + dev_name + quant_type,
+                              np.sum(w_freeze)))
+        print('{}: {}'.format('w_quant' + dev_name + quant_type,
+                              np.sum(w_quant)))

         # Convert parameter to 8-bit.
         convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
         convert_int8_pass.apply(test_graph)
         marked_nodes = set()
-        for op in test_graph.all_ops():
+        for op in test_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes)
@@ -330,14 +325,14 @@ class TestQuantizationFreezePass(unittest.TestCase):
         w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
         self.assertEqual(w_8bit.dtype, np.int8)
         self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
-        #print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
-        #print('{}: {}'.format('w_freeze' + dev_name + quant_type,
-        #                      np.sum(w_freeze)))
+        print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
+        print('{}: {}'.format('w_freeze' + dev_name + quant_type,
+                              np.sum(w_freeze)))

         mobile_pass = TransformForMobilePass()
         mobile_pass.apply(test_graph)
         marked_nodes = set()
-        for op in test_graph.all_ops():
+        for op in test_graph.all_op_nodes():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
         test_graph.draw('.', 'test_mobile' + dev_name + quant_type,
......
@@ -1538,10 +1538,297 @@ class Block(object):
         return ret_var

+
+class IrNode(object):
+    """
+    Python IrNode. Beneath it is a core.Node, which is used for Ir Pass.
+    """
+
+    def __init__(self, node):
+        """
+        Construct an IrNode using core.Node.
+
+        Args:
+            node(core.Node): C++ Node.
+        """
+        assert isinstance(node,
+                          core.Node), 'node must be the instance of core.Node.'
+        self.node = node
+
+    def name(self):
+        """
+        Return the node name.
+
+        Returns:
+            str: node name.
+        """
+        return self.node.name()
+
+    def node_type(self):
+        """
+        Return the node type.
+
+        Returns:
+            core.Node.Type: node type(core.Node.Type.Operation or core.Node.Type.Variable).
+        """
+        return self.node.node_type()
+
+    def var(self):
+        """
+        Return the node variable description.
+
+        Returns:
+            core.VarDesc: node variable description.
+        """
+        return self.node.var()
+
+    def op(self):
+        """
+        Return the node operator description.
+
+        Returns:
+            core.OpDesc: node operator description.
+        """
+        return self.node.op()
+
+    def id(self):
+        """
+        Return the node id.
+
+        Returns:
+            int: node id.
+        """
+        return self.node.id()
+
+    def is_op(self):
+        """
+        If the node is an operator, then return true.
+
+        Returns:
+            bool: indicate whether the node is an operator.
+        """
+        return self.node.is_op()
+
+    def is_var(self):
+        """
+        If the node is a variable, then return true.
+
+        Returns:
+            bool: indicate whether the node is a variable.
+        """
+        return self.node.is_var()
+
+    def is_ctrl_var(self):
+        """
+        If the node is a control dependence variable, then return true.
+
+        Returns:
+            bool: indicate whether the node is a control dependence variable.
+        """
+        return self.node.is_ctrl_var()
+
+    def clear_inputs(self):
+        """
+        Clear the node inputs. After executing the `clear_inputs` function,
+        the node inputs will be empty.
+        """
+        self.node.clear_inputs()
+
+    def inputs_remove_by_id(self, node_id):
+        """
+        Remove a node from inputs by the given node id.
+
+        Args:
+            node_id(int): the given node id.
+        """
+        self.node.inputs_remove(node_id)
+
+    def inputs_remove(self, ir_node):
+        """
+        Remove a node from inputs.
+
+        Args:
+            ir_node(IrNode): the node being removed.
+        """
+        self.node.inputs_remove(ir_node.node)
+
+    def inputs_append(self, ir_node):
+        """
+        Append a node in inputs.
+
+        Args:
+            ir_node(IrNode): the node being appended.
+        """
+        self.node.inputs_append(ir_node.node)
+
+    def clear_outputs(self):
+        """
+        Clear the node outputs. After executing the `clear_outputs` function,
+        the node outputs will be empty.
+        """
+        self.node.clear_outputs()
+
+    def outputs_remove_by_id(self, node_id):
+        """
+        Remove a node from outputs by the given node id.
+
+        Args:
+            node_id(int): the given node id.
+        """
+        self.node.outputs_remove(node_id)
+
+    def outputs_remove(self, ir_node):
+        """
+        Remove a node from outputs.
+
+        Args:
+            ir_node(IrNode): the node being removed.
+        """
+        self.node.outputs_remove(ir_node.node)
+
+    def outputs_append(self, ir_node):
+        """
+        Append a node in outputs.
+
+        Args:
+            ir_node(IrNode): the node being appended.
+        """
+        self.node.outputs_append(ir_node.node)
+
+    @property
+    def inputs(self):
+        """
+        Return the node inputs.
+
+        Returns:
+            list(IrNode): node inputs wrapped by IrNode.
+        """
+        return [IrNode(n) for n in self.node.inputs]
+
+    @property
+    def outputs(self):
+        """
+        Return the node outputs.
+
+        Returns:
+            list(IrNode): node outputs wrapped by IrNode.
+        """
+        return [IrNode(n) for n in self.node.outputs]
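
A quick sketch of what the wrapper layer buys you (assumes a `graph = IrGraph(...)` built as in the tests above):

    # Graph queries now return IrNode wrappers rather than raw core.Node
    # objects; the raw node stays reachable via the `.node` attribute.
    for node in graph.all_nodes():
        print(node.name(), node.node_type(), node.id())
        if node.is_op():
            in_names = [v.name() for v in node.inputs]
            out_names = [v.name() for v in node.outputs]
            print('  inputs: {}, outputs: {}'.format(in_names, out_names))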
+
+class IrVarNode(IrNode):
+    """
+    Python IrVarNode. Beneath it is a core.Node, and it inherits from IrNode.
+    """
+
+    def __init__(self, node):
+        """
+        Construct an IrVarNode using core.Node.
+
+        Args:
+            node(core.Node): C++ Node.
+        """
+        assert isinstance(node, core.Node) and node.is_var(), \
+            'node must be the instance of core.Node and it must be a variable node.'
+        super(IrVarNode, self).__init__(node)
+        self.node = node
+
+    def set_shape(self, shape):
+        """
+        Set the node variable shape.
+
+        Args:
+            shape(list): shape to be set.
+        """
+        assert self.node.var() is not None, \
+            "The node variable description cannot be None."
+        self.node.var().set_shape(shape)
+
+    def persistable(self):
+        """
+        If the variable node is a persistable variable, then return true.
+
+        Returns:
+            bool: indicate whether the variable is persistable.
+        """
+        assert self.node.var() is not None, \
+            "The node variable description cannot be None."
+        return self.node.var().persistable()
+
+    @property
+    def inputs(self):
+        """
+        Return the node inputs.
+
+        Returns:
+            list(IrOpNode): node inputs wrapped by IrOpNode.
+        """
+        return [IrOpNode(n) for n in self.node.inputs]
+
+    @property
+    def outputs(self):
+        """
+        Return the node outputs.
+
+        Returns:
+            list(IrOpNode): node outputs wrapped by IrOpNode.
+        """
+        return [IrOpNode(n) for n in self.node.outputs]
+
+
+class IrOpNode(IrNode):
+    """
+    Python IrOpNode. Beneath it is a core.Node, and it inherits from IrNode.
+    """
+
+    def __init__(self, node):
+        """
+        Construct an IrOpNode using core.Node.
+
+        Args:
+            node(core.Node): C++ Node.
+        """
+        assert isinstance(node, core.Node) and node.is_op(), \
+            'node must be the instance of core.Node and it must be an operator node.'
+        super(IrOpNode, self).__init__(node)
+        self.node = node
+
+    def rename_input(self, old_input_name, new_input_name):
+        """
+        Rename the input of this node.
+
+        Args:
+            old_input_name(str): the old input name.
+            new_input_name(str): the new input name.
+        """
+        assert self.node.op() is not None, \
+            "The node operator description cannot be None."
+        self.node.op()._rename_input(old_input_name, new_input_name)
+
+    @property
+    def inputs(self):
+        """
+        Return the node inputs.
+
+        Returns:
+            list(IrVarNode): node inputs wrapped by IrVarNode.
+        """
+        return [IrVarNode(n) for n in self.node.inputs]
+
+    @property
+    def outputs(self):
+        """
+        Return the node outputs.
+
+        Returns:
+            list(IrVarNode): node outputs wrapped by IrVarNode.
+        """
+        return [IrVarNode(n) for n in self.node.outputs]
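
Together the two subclasses give typed traversal: an IrOpNode's inputs and outputs come back as IrVarNode, and a variable node's neighbors come back as IrOpNode. A sketch, again assuming a built `graph`:

    for op in graph.all_op_nodes():        # set of IrOpNode
        for var in op.inputs:              # IrVarNode wrappers
            if var.persistable():
                print(var.name(), var.var().shape())
        # The consumers of this op's outputs are IrOpNode again.
        consumers = {c.name() for out in op.outputs for c in out.outputs}
        print(op.name(), '->', sorted(consumers))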
+
 class IrGraph(object):
     """
     Python IrGraph. Beneath it is a core.Graph, which is used for
-    create a c++ Ir Pass Graph. An IrGraph is just a graph view of
+    creating a c++ Ir Pass Graph. An IrGraph is just a graph view of
     a Program. In an IrGraph, both Variables and Operators are graph
     nodes.
     """
@@ -1569,15 +1856,15 @@ class IrGraph(object):
         """
         Return all nodes included in the graph as a set.
         """
-        return {node for node in self.graph.nodes()}
+        return {IrNode(node) for node in self.graph.nodes()}

-    def all_vars(self):
+    def all_var_nodes(self):
         """
         Return all variable nodes included in the graph as a set.
         """
-        return {node for node in self.graph.nodes() if node.is_var()}
+        return {IrVarNode(node) for node in self.graph.nodes() if node.is_var()}

-    def all_persistable_vars(self):
+    def all_persistable_nodes(self):
         """
         Return all persistable variable nodes included in the graph as a set.
         """
@@ -1586,13 +1873,13 @@ class IrGraph(object):
             if node.is_var() and node.var() is not None and node.var(
             ).persistable():
                 persistable_nodes.add(node)
-        return persistable_nodes
+        return {IrVarNode(p) for p in persistable_nodes}

-    def all_ops(self):
+    def all_op_nodes(self):
         """
         Return all operator nodes included in the graph as a set.
         """
-        return {node for node in self.graph.nodes() if node.is_op()}
+        return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
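
In short, the query API is renamed and now wraps its results: all_vars/all_ops/all_persistable_vars become all_var_nodes/all_op_nodes/all_persistable_nodes. A small usage sketch, reusing the `graph` from earlier examples:

    n_vars = len(graph.all_var_nodes())
    n_ops = len(graph.all_op_nodes())
    weight_names = {v.name() for v in graph.all_persistable_nodes()}
    print('{} vars, {} ops, {} persistable'.format(
        n_vars, n_ops, len(weight_names)))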
     def var_node(self, name):
         """
@@ -1606,14 +1893,14 @@ class IrGraph(object):
             doesn't have a variable with the given name.

         Returns:
-            core.Node: the variable node with the given name.
+            IrVarNode: the variable node with the given name.
         """
         if not isinstance(name, six.string_types):
             raise TypeError(
                 "var require string as parameter, but get %s instead." %
                 (type(name)))
         target_var_node = None
-        var_nodes = self.all_vars()
+        var_nodes = self.all_var_nodes()
         for var_node in var_nodes:
             if var_node.name() == name:
                 target_var_node = var_node
@@ -1621,7 +1908,7 @@ class IrGraph(object):
             raise ValueError("var_node %s not in this graph" % name)
         return target_var_node

-    def create_param_node(self, name, var_type, shape, var_dtype):
+    def create_persistable_node(self, name, var_type, shape, var_dtype):
         """
         Create a persistable variable node in the graph. In IrGraph,
         it cannot distinguish between persistable variables and parameters.
@@ -1633,14 +1920,14 @@ class IrGraph(object):
             var_dtype(core.VarDesc.VarType): the data type of the persistable variable node.

         Returns:
-            core.Node: the created persistable variable node.
+            IrVarNode: the created persistable variable node.
         """
         var_desc = core.VarDesc(name)
         var_desc.set_type(var_type)
         var_desc.set_shape(shape)
         var_desc.set_dtype(var_dtype)
         var_desc.set_persistable(True)
-        return self.graph.create_var_node(var_desc)
+        return IrVarNode(self.graph.create_var_node(var_desc))
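
The renamed create_persistable_node now returns an IrVarNode directly, so callers such as the quantization passes above can keep chaining wrapper methods. A sketch with a hypothetical node name:

    scale_node = graph.create_persistable_node(
        name='demo_scale',                         # hypothetical name
        var_type=core.VarDesc.VarType.LOD_TENSOR,
        shape=[1],
        var_dtype=core.VarDesc.VarType.FP32)
    assert scale_node.persistable()                # an IrVarNode method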
     def create_var_node(self, name, var_type, shape, var_dtype):
         """
@@ -1654,14 +1941,14 @@ class IrGraph(object):
             var_dtype(core.VarDesc.VarType): the data type of the variable node.

         Returns:
-            core.Node: the created variable node.
+            IrVarNode: the created variable node.
         """
         var_desc = core.VarDesc(name)
         var_desc.set_type(var_type)
         var_desc.set_shape(shape)
         var_desc.set_dtype(var_dtype)
-        return self.graph.create_var_node(var_desc)
+        return IrVarNode(self.graph.create_var_node(var_desc))

     def create_var_node_from_desc(self, var_desc):
         """
@@ -1672,9 +1959,9 @@ class IrGraph(object):
             var_desc(core.VarDesc): the given variable description.

         Returns:
-            core.Node: the created variable node.
+            IrVarNode: the created variable node.
         """
-        return self.graph.create_var_node(var_desc)
+        return IrVarNode(self.graph.create_var_node(var_desc))

     def create_op_node(self, op_type, attrs, inputs, outputs):
         """
@@ -1687,7 +1974,7 @@ class IrGraph(object):
             outputs(dict): the outputs of the operator node.

         Returns:
-            core.Node: the created operator node.
+            IrOpNode: the created operator node.
         """
         op_desc = core.OpDesc()
         op_desc.set_type(op_type)
@@ -1703,7 +1990,7 @@ class IrGraph(object):
                 var_nodes = [var_nodes]
             op_desc.set_output(output_name,
                                [var_node.name() for var_node in var_nodes])
-        return self.graph.create_op_node(op_desc)
+        return IrOpNode(self.graph.create_op_node(op_desc))

     def create_op_node_from_desc(self, op_desc):
         """
@@ -1713,37 +2000,37 @@ class IrGraph(object):
             op_desc(core.OpDesc): the given operator description.

         Returns:
-            core.Node: the created operator node.
+            IrOpNode: the created operator node.
         """
-        return self.graph.create_op_node(op_desc)
+        return IrOpNode(self.graph.create_op_node(op_desc))

     def update_input_link(self, old_input_node, new_input_node, op_node):
         """
         Update the input's link of an operator node.

         Args:
-            old_input_node(core.Node): the old input node of the given op_node.
-            new_input_node(core.Node): the new input node of the given op_node.
-            op_node(core.Node): the operator node that is needed to update input's link.
+            old_input_node(IrNode): the old input node of the given op_node.
+            new_input_node(IrNode): the new input node of the given op_node.
+            op_node(IrOpNode): the operator node that is needed to update input's link.
         """
-        assert old_input_node in self.graph.nodes() and new_input_node in \
-            self.graph.nodes() and op_node in self.graph.nodes(), \
+        assert old_input_node.node in self.graph.nodes() and new_input_node.node in \
+            self.graph.nodes() and op_node.node in self.graph.nodes(), \
             'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
         old_input_node.outputs_remove(op_node)
         op_node.inputs_remove(old_input_node)
         new_input_node.outputs_append(op_node)
         op_node.inputs_append(new_input_node)
-        op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
+        op_node.rename_input(old_input_node.name(), new_input_node.name())

     def link_to(self, node_in, node_out):
         """
         Connect two nodes.

         Args:
-            node_in(core.Node): the input node.
-            node_out(core.Node): the output node.
+            node_in(IrNode): the input node.
+            node_out(IrNode): the output node.
         """
-        assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
+        assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \
             'The two arguments(node_in&node_out) must be in the graph nodes.'
         node_in.outputs_append(node_out)
         node_out.inputs_append(node_in)
@@ -1761,7 +2048,8 @@ class IrGraph(object):
             remove_nodes = set(remove_nodes)
         else:
             remove_nodes = {remove_nodes}
-        core.graph_safe_remove_nodes(self.graph, remove_nodes)
+        original_nodes = {n.node for n in remove_nodes}
+        core.graph_safe_remove_nodes(self.graph, original_nodes)
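
Since safe_remove_nodes now unwraps the wrappers itself, the sets returned by the query APIs can be passed straight in, for example to drop control-dependency variables (a sketch):

    ctrl_vars = {v for v in graph.all_var_nodes() if v.is_ctrl_var()}
    graph.safe_remove_nodes(ctrl_vars)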
     def has_circle(self):
         """
@@ -1788,18 +2076,23 @@ class IrGraph(object):
         Notes: the `graph` cannot contain a circle.

         Returns:
-            set(core.Node): nodes in topology order.
+            set(IrNode): nodes in topology order.
         """
-        return core.topology_sort(self.graph)
+        ordered_nodes = core.topology_sort(self.graph)
+        return {IrNode(n) for n in ordered_nodes}

     def build_adjacency_list(self):
         """
         Build an adjacency list of operations for the `graph`.

         Returns:
-            dict{core.Node: set(core.Node)}: the adjacency list.
+            dict{IrNode: set(IrNode)}: the adjacency list.
         """
-        return core.build_adjacency_list(self.graph)
+        adj_list = core.build_adjacency_list(self.graph)
+        wrapped_adj_list = dict()
+        for k, v in six.iteritems(adj_list):
+            wrapped_adj_list[IrNode(k)] = {IrNode(n) for n in v}
+        return wrapped_adj_list
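
Both analysis helpers now hand back wrappers as well. A sketch:

    ordered = graph.topology_sort()       # set(IrNode), per the docstring
    adj = graph.build_adjacency_list()    # dict{IrNode: set(IrNode)}
    for node, successors in adj.items():
        print(node.name(), '->', sorted(s.name() for s in successors))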
     def draw(self, save_path, name, marked_nodes=None, remove_ctr_var=True):
         """
@@ -1809,7 +2102,7 @@ class IrGraph(object):
         Args:
             save_path(str): the save path of the drawn graph.
             name(str): the name of the drawn graph.
-            marked_nodes(set(core.Node)): nodes that are needed to be marked.
+            marked_nodes(set(IrNode)): nodes that are needed to be marked.
                 Default value is None.
             remove_ctr_var(bool): If it is set True, all control variable nodes
                 in the graph will be removed. Default value is True.
@@ -1824,20 +2117,22 @@ class IrGraph(object):
             print('The {} is saved as the dot filetype.'.format(
                 dot_file_path))

+        remove_ctr_vars = set()
         if remove_ctr_var:
-            remove_ctr_vars = set()
-            for node in self.graph.nodes():
+            for node in self.all_var_nodes():
                 if node.is_ctrl_var():
                     remove_ctr_vars.add(node)
             self.safe_remove_nodes(remove_ctr_vars)
-        ops_num = 0
-        for node in self.graph.nodes():
-            if node.is_op():
-                ops_num += 1
-        print('Total ops num = {}.'.format(ops_num))
+        print('Total ops num = {}.'.format(len(self.all_op_nodes())))
+
         if marked_nodes is not None:
             if not isinstance(marked_nodes, set):
-                marked_nodes = set(marked_nodes)
+                if isinstance(marked_nodes, Iterable):
+                    marked_nodes = set(marked_nodes)
+                else:
+                    marked_nodes = {marked_nodes}
+            marked_nodes = {n.node for n in marked_nodes}
+            remove_ctr_vars = {n.node for n in remove_ctr_vars}
             marked_nodes = marked_nodes - remove_ctr_vars
             if self.graph.has('__graphviz__marked_node__'):
                 self.graph.erase('__graphviz__marked_node__')
......