未验证 提交 e00c7a2e 编写于 作者: Z Zhen Wang 提交者: GitHub

Merge pull request #15830 from wzzju/add_ir_node_encapsulation

add IrNode&IrVarNode&IrOpNode. test=develop
......@@ -14,6 +14,7 @@
#include "paddle/fluid/pybind/ir.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -116,7 +117,7 @@ void BindNode(py::module *m) {
.def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar)
.def("clear_inputs", [](Node &self) { self.inputs.clear(); })
.def("inputs_remove",
.def("remove_input",
[](Node &self, int node_id) {
auto pos = std::find_if(
self.inputs.begin(), self.inputs.end(),
......@@ -125,7 +126,7 @@ void BindNode(py::module *m) {
self.inputs.erase(pos);
}
})
.def("inputs_remove",
.def("remove_input",
[](Node &self, Node &node) {
auto pos =
std::find(self.inputs.begin(), self.inputs.end(), &node);
......@@ -133,10 +134,10 @@ void BindNode(py::module *m) {
self.inputs.erase(pos);
}
})
.def("inputs_append",
.def("append_input",
[](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("clear_outputs", [](Node &self) { self.outputs.clear(); })
.def("outputs_remove",
.def("remove_output",
[](Node &self, int node_id) {
auto pos = std::find_if(
self.outputs.begin(), self.outputs.end(),
......@@ -145,7 +146,7 @@ void BindNode(py::module *m) {
self.outputs.erase(pos);
}
})
.def("outputs_remove",
.def("remove_output",
[](Node &self, Node &node) {
auto pos =
std::find(self.outputs.begin(), self.outputs.end(), &node);
......@@ -153,7 +154,7 @@ void BindNode(py::module *m) {
self.outputs.erase(pos);
}
})
.def("outputs_append",
.def("append_output",
[](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs);
......
......@@ -17,7 +17,9 @@ import numpy as np
import six
from ..... import compat as cpt
from .... import core
from .... import Executor
from ....framework import IrGraph
from ....framework import IrNode
from ....framework import Program
from ....initializer import Constant
from .... import unique_name
......@@ -31,7 +33,7 @@ __all__ = [
class QuantizationTransformPass(object):
def __init__(self,
scope=None,
program_exe=None,
place=None,
weight_bits=8,
activation_bits=8,
activation_quantize_type='abs_max',
......@@ -45,7 +47,7 @@ class QuantizationTransformPass(object):
scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
type, this pass will create some new parameters. The scope is used to
initialize these new parameters.
program_exe(fluid.Executor): program_exe is used to initialize new
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
parameters described above.
weight_bits (int): quantization bit number for weights,
the bias is not quantized.
......@@ -71,13 +73,13 @@ class QuantizationTransformPass(object):
from paddle.fluid import core
graph = IrGraph(core.Graph(program.desc), for_test=False)
exe = fluid.Executor(fluid.CPUPlace())
place = fluid.CPUPlace()
transform_pass = QuantizationTransformPass(fluid.global_scope(),
exe)
place)
transform_pass.apply(graph)
"""
self._scope = scope
self._program_exe = program_exe
self._place = place
self._weight_bits = weight_bits
self._activation_bits = activation_bits
......@@ -118,7 +120,7 @@ class QuantizationTransformPass(object):
self._is_test = graph.is_test()
# marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict()
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
def _transform_forward(graph, op):
for var_node in op.inputs:
......@@ -149,7 +151,7 @@ class QuantizationTransformPass(object):
if not self._is_test:
self._create_global_step(graph)
ops = graph.all_ops()
ops = graph.all_op_nodes()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
for op in ops:
......@@ -163,8 +165,8 @@ class QuantizationTransformPass(object):
if len(self._need_initialized) > 0:
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._program_exe is not None, \
'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
init_program = Program()
for var_desc, initializer in six.iteritems(self._need_initialized):
var = init_program.global_block().create_var(
......@@ -175,7 +177,8 @@ class QuantizationTransformPass(object):
lod_level=var_desc.lod_level(),
persistable=var_desc.persistable())
initializer(var, init_program.global_block())
self._program_exe.run(program=init_program, scope=self._scope)
exe = Executor(self._place)
exe.run(program=init_program, scope=self._scope)
return graph
......@@ -183,11 +186,11 @@ class QuantizationTransformPass(object):
if self._weight_quantize_type == 'range_abs_max' or \
self._activation_quantize_type == 'range_abs_max':
counter_name = cpt.to_text('@STEP_COUNTER@')
for node in graph.all_vars():
for node in graph.all_var_nodes():
if node.name() == counter_name:
self._global_step = node
if self._global_step is None:
global_step_in = graph.create_param_node(
global_step_in = graph.create_persistable_node(
name=counter_name,
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
......@@ -228,14 +231,14 @@ class QuantizationTransformPass(object):
quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
scale_var_node = graph.create_var_node(
name=self._quantized_scale_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max',
attrs={
......@@ -258,15 +261,15 @@ class QuantizationTransformPass(object):
quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
scale_in_node = graph.create_param_node(
scale_in_node = graph.create_persistable_node(
name=self._quantized_scale_name(var_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.var().dtype())
var_dtype=var_node.dtype())
self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
......@@ -275,11 +278,11 @@ class QuantizationTransformPass(object):
if not self._is_test:
# The name of scales_var_node maybe 'scales_0', 'scales_1', etc.
scales_node = graph.create_param_node(
scales_node = graph.create_persistable_node(
name=unique_name.generate('scales'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[self._window_size],
var_dtype=var_node.var().dtype())
var_dtype=var_node.dtype())
self._need_initialized[scales_node.var()] = Constant(value=0)
inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node
......@@ -314,9 +317,9 @@ class QuantizationTransformPass(object):
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
......@@ -400,22 +403,22 @@ class QuantizationFreezePass(object):
Args:
graph(IrGraph): the applied graph.
"""
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
ops = graph.all_ops()
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_quant_op_names:
input_arg_name = op_node.op().input('X')[0]
input_arg_name = op_node.input('X')[0]
if input_arg_name in persistable_vars:
if self._weight_quantize_type == 'abs_max':
param = self._load_var(input_arg_name)
scale_v = np.max(np.abs(param))
else:
scale_v = self._load_var(op_node.op().output('OutScale')
[0])[0]
scale_v = self._load_var(
op_node.output('OutScale')[0])[0]
self._var_scale_map[input_arg_name] = scale_v
else:
scale_v = graph.var_node(op_node.op().output('OutScale')[0])
scale_v = graph.var_node(op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
if input_arg_name in persistable_vars:
self._remove_fake_quant_and_dequant_op(graph, op_node)
......@@ -425,13 +428,13 @@ class QuantizationFreezePass(object):
self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v)
ops = graph.all_ops()
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_dequant_op_names:
self._remove_fake_quant_and_dequant_op(graph, op_node)
ops = graph.all_ops()
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
......@@ -451,8 +454,8 @@ class QuantizationFreezePass(object):
return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = op_node.op().output('Out')[0]
v = op_node.op().input('X')[0]
k = op_node.output('Out')[0]
v = op_node.input('X')[0]
if v not in self._op_input_rename_map:
self._op_input_rename_map[k] = v
else:
......@@ -462,7 +465,7 @@ class QuantizationFreezePass(object):
def _insert_post_dequant_op(self, graph, op_node):
max_range = None
scale_var_node = None
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_input_rename_map:
......@@ -480,7 +483,7 @@ class QuantizationFreezePass(object):
original_var_name)
max_range = param_range * act_range / scale_v
else:
assert isinstance(scale_v, core.Node)
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]
if len(op_node.outputs) != 1:
......@@ -490,9 +493,9 @@ class QuantizationFreezePass(object):
output_var_node = op_node.outputs[0]
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.var().type(),
shape=output_var_node.var().shape(),
var_dtype=output_var_node.var().dtype())
var_type=output_var_node.type(),
shape=output_var_node.shape(),
var_dtype=output_var_node.dtype())
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
attrs={
......@@ -517,14 +520,19 @@ class QuantizationFreezePass(object):
def _remove_unused_var_nodes(self, graph):
all_used_vars = set()
ops = graph.all_ops()
ops = graph.all_op_nodes()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_unused_vars = graph.all_vars() - all_used_vars
all_used_vars = {n.node for n in all_used_vars}
all_unused_vars = {
n
for n in filter(lambda node: node.node not in all_used_vars,
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars)
def _original_var_name(self, var_name):
......@@ -583,8 +591,8 @@ class ConvertToInt8Pass(object):
Args:
graph(IrGraph): the applied graph.
"""
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
ops = graph.all_ops()
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_op_nodes()
input_map = {}
for op_node in ops:
op_name = op_node.name()
......@@ -605,10 +613,10 @@ class ConvertToInt8Pass(object):
def _convert_to_int8(self, graph, var_node):
int8_var_node_name = var_node.name() + ".int8"
int8_var_node = graph.create_param_node(
int8_var_node = graph.create_persistable_node(
name=cpt.to_text(int8_var_node_name),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=core.VarDesc.VarType.INT8)
array = self._load_var(var_node.name())
self._scope.var(int8_var_node_name)
......@@ -624,14 +632,19 @@ class ConvertToInt8Pass(object):
def _remove_unused_var_nodes(self, graph):
all_used_vars = set()
ops = graph.all_ops()
ops = graph.all_op_nodes()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_unused_vars = graph.all_vars() - all_used_vars
all_used_vars = {n.node for n in all_used_vars}
all_unused_vars = {
n
for n in filter(lambda node: node.node not in all_used_vars,
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars)
......@@ -655,11 +668,11 @@ class TransformForMobilePass(object):
Args:
graph(IrGraph): the graph will be transformed.
"""
ops = graph.all_ops()
ops = graph.all_op_nodes()
for op_node in ops:
name = op_node.name()
if name in self._fake_quant_op_names:
op_node.op().set_type('quantize')
op_node.set_type('quantize')
quant_node = graph.create_op_node_from_desc(op_node.op())
for input_node in op_node.inputs:
graph.link_to(input_node, quant_node)
......@@ -667,7 +680,7 @@ class TransformForMobilePass(object):
graph.link_to(quant_node, output_node)
graph.safe_remove_nodes(op_node)
if name in self._fake_dequant_op_names:
op_node.op().set_type('dequantize')
op_node.set_type('dequantize')
dequant_node = graph.create_op_node_from_desc(op_node.op())
for input_node in op_node.inputs:
graph.link_to(input_node, dequant_node)
......
......@@ -61,16 +61,16 @@ class TestGraph(unittest.TestCase):
opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False)
marked_nodes = set()
for op in graph.all_ops():
for op in graph.all_op_nodes():
if op.name().find('conv2d') > -1:
marked_nodes.add(op)
graph.draw('.', 'residual', marked_nodes)
self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort()
self.assertEqual(len(nodes), len(graph.all_ops()))
self.assertEqual(len(nodes), len(graph.all_op_nodes()))
nodes_map = graph.build_adjacency_list()
self.assertEqual(len(nodes_map), len(graph.all_ops()))
self.assertEqual(len(nodes_map), len(graph.all_op_nodes()))
nodes_num = len(graph.all_nodes())
graph.safe_remove_nodes(marked_nodes)
self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
......
......@@ -130,15 +130,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
place = fluid.CPUPlace()
exe = fluid.Executor(place)
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
place=place,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
for op in graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
......@@ -146,7 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
......@@ -166,15 +167,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
place = fluid.CPUPlace()
exe = fluid.Executor(place)
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
place=place,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
for op in graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
......@@ -182,7 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
......@@ -231,17 +233,17 @@ class TestQuantizationFreezePass(unittest.TestCase):
with fluid.scope_guard(scope):
exe.run(startup)
transform_pass = QuantizationTransformPass(
scope=scope, program_exe=exe, activation_quantize_type=quant_type)
scope=scope, place=place, activation_quantize_type=quant_type)
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
marked_nodes = set()
for op in main_graph.all_ops():
for op in main_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes)
marked_nodes = set()
for op in test_graph.all_ops():
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes)
......@@ -251,11 +253,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
iters = 5
batch_size = 8
#train_exe = fluid.ParallelExecutor(
# main_program=quantized_main_program,
# use_cuda=bool(use_cuda),
# loss_name=loss.name,
# scope=scope)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
......@@ -269,9 +266,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
loss_v = exe.run(program=quantized_main_program,
feed=feeder.feed(data),
fetch_list=[loss])
#loss_v = train_exe.run(feed=feeder.feed(data),
# fetch_list=[loss.name])
#print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
test_data = next(test_reader())
with fluid.program_guard(quantized_test_program):
......@@ -287,7 +282,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_freeze' + dev_name + quant_type,
......@@ -299,21 +294,21 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed=feeder.feed(test_data),
fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
#print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
#print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
#print('{}: {}'.format('w_freeze' + dev_name + quant_type,
# np.sum(w_freeze)))
#print('{}: {}'.format('w_quant' + dev_name + quant_type,
# np.sum(w_quant)))
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze)))
print('{}: {}'.format('w_quant' + dev_name + quant_type,
np.sum(w_quant)))
# Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
convert_int8_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes)
......@@ -330,14 +325,14 @@ class TestQuantizationFreezePass(unittest.TestCase):
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
#print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
#print('{}: {}'.format('w_freeze' + dev_name + quant_type,
# np.sum(w_freeze)))
print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze)))
mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_mobile' + dev_name + quant_type,
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册