未验证 提交 e00c7a2e 编写于 作者: Z Zhen Wang 提交者: GitHub

Merge pull request #15830 from wzzju/add_ir_node_encapsulation

add IrNode&IrVarNode&IrOpNode. test=develop
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/pybind/ir.h" #include "paddle/fluid/pybind/ir.h"
#include <algorithm> #include <algorithm>
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -116,7 +117,7 @@ void BindNode(py::module *m) { ...@@ -116,7 +117,7 @@ void BindNode(py::module *m) {
.def("is_var", &Node::IsVar) .def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar) .def("is_ctrl_var", &Node::IsCtrlVar)
.def("clear_inputs", [](Node &self) { self.inputs.clear(); }) .def("clear_inputs", [](Node &self) { self.inputs.clear(); })
.def("inputs_remove", .def("remove_input",
[](Node &self, int node_id) { [](Node &self, int node_id) {
auto pos = std::find_if( auto pos = std::find_if(
self.inputs.begin(), self.inputs.end(), self.inputs.begin(), self.inputs.end(),
...@@ -125,7 +126,7 @@ void BindNode(py::module *m) { ...@@ -125,7 +126,7 @@ void BindNode(py::module *m) {
self.inputs.erase(pos); self.inputs.erase(pos);
} }
}) })
.def("inputs_remove", .def("remove_input",
[](Node &self, Node &node) { [](Node &self, Node &node) {
auto pos = auto pos =
std::find(self.inputs.begin(), self.inputs.end(), &node); std::find(self.inputs.begin(), self.inputs.end(), &node);
...@@ -133,10 +134,10 @@ void BindNode(py::module *m) { ...@@ -133,10 +134,10 @@ void BindNode(py::module *m) {
self.inputs.erase(pos); self.inputs.erase(pos);
} }
}) })
.def("inputs_append", .def("append_input",
[](Node &self, Node &node) { self.inputs.push_back(&node); }) [](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("clear_outputs", [](Node &self) { self.outputs.clear(); }) .def("clear_outputs", [](Node &self) { self.outputs.clear(); })
.def("outputs_remove", .def("remove_output",
[](Node &self, int node_id) { [](Node &self, int node_id) {
auto pos = std::find_if( auto pos = std::find_if(
self.outputs.begin(), self.outputs.end(), self.outputs.begin(), self.outputs.end(),
...@@ -145,7 +146,7 @@ void BindNode(py::module *m) { ...@@ -145,7 +146,7 @@ void BindNode(py::module *m) {
self.outputs.erase(pos); self.outputs.erase(pos);
} }
}) })
.def("outputs_remove", .def("remove_output",
[](Node &self, Node &node) { [](Node &self, Node &node) {
auto pos = auto pos =
std::find(self.outputs.begin(), self.outputs.end(), &node); std::find(self.outputs.begin(), self.outputs.end(), &node);
...@@ -153,7 +154,7 @@ void BindNode(py::module *m) { ...@@ -153,7 +154,7 @@ void BindNode(py::module *m) {
self.outputs.erase(pos); self.outputs.erase(pos);
} }
}) })
.def("outputs_append", .def("append_output",
[](Node &self, Node &node) { self.outputs.push_back(&node); }) [](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readwrite("inputs", &Node::inputs) .def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs); .def_readwrite("outputs", &Node::outputs);
......
...@@ -17,7 +17,9 @@ import numpy as np ...@@ -17,7 +17,9 @@ import numpy as np
import six import six
from ..... import compat as cpt from ..... import compat as cpt
from .... import core from .... import core
from .... import Executor
from ....framework import IrGraph from ....framework import IrGraph
from ....framework import IrNode
from ....framework import Program from ....framework import Program
from ....initializer import Constant from ....initializer import Constant
from .... import unique_name from .... import unique_name
...@@ -31,7 +33,7 @@ __all__ = [ ...@@ -31,7 +33,7 @@ __all__ = [
class QuantizationTransformPass(object): class QuantizationTransformPass(object):
def __init__(self, def __init__(self,
scope=None, scope=None,
program_exe=None, place=None,
weight_bits=8, weight_bits=8,
activation_bits=8, activation_bits=8,
activation_quantize_type='abs_max', activation_quantize_type='abs_max',
...@@ -45,7 +47,7 @@ class QuantizationTransformPass(object): ...@@ -45,7 +47,7 @@ class QuantizationTransformPass(object):
scope(fluid.Scope): When activation uses 'range_abs_max' as the quantize scope(fluid.Scope): When activation uses 'range_abs_max' as the quantize
type, this pass will create some new parameters. The scope is used to type, this pass will create some new parameters. The scope is used to
initialize these new parameters. initialize these new parameters.
program_exe(fluid.Executor): program_exe is used to initialize new place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
parameters described above. parameters described above.
weight_bits (int): quantization bit number for weights, weight_bits (int): quantization bit number for weights,
the bias is not quantized. the bias is not quantized.
...@@ -71,13 +73,13 @@ class QuantizationTransformPass(object): ...@@ -71,13 +73,13 @@ class QuantizationTransformPass(object):
from paddle.fluid import core from paddle.fluid import core
graph = IrGraph(core.Graph(program.desc), for_test=False) graph = IrGraph(core.Graph(program.desc), for_test=False)
exe = fluid.Executor(fluid.CPUPlace()) place = fluid.CPUPlace()
transform_pass = QuantizationTransformPass(fluid.global_scope(), transform_pass = QuantizationTransformPass(fluid.global_scope(),
exe) place)
transform_pass.apply(graph) transform_pass.apply(graph)
""" """
self._scope = scope self._scope = scope
self._program_exe = program_exe self._place = place
self._weight_bits = weight_bits self._weight_bits = weight_bits
self._activation_bits = activation_bits self._activation_bits = activation_bits
...@@ -118,7 +120,7 @@ class QuantizationTransformPass(object): ...@@ -118,7 +120,7 @@ class QuantizationTransformPass(object):
self._is_test = graph.is_test() self._is_test = graph.is_test()
# Tracks the variables which have been dequantized. # Tracks the variables which have been dequantized.
dequantized_vars = collections.OrderedDict() dequantized_vars = collections.OrderedDict()
persistable_vars = [p.name() for p in graph.all_persistable_vars()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
def _transform_forward(graph, op): def _transform_forward(graph, op):
for var_node in op.inputs: for var_node in op.inputs:
...@@ -149,7 +151,7 @@ class QuantizationTransformPass(object): ...@@ -149,7 +151,7 @@ class QuantizationTransformPass(object):
if not self._is_test: if not self._is_test:
self._create_global_step(graph) self._create_global_step(graph)
ops = graph.all_ops() ops = graph.all_op_nodes()
# _transform_forward and _transform_backward must be applied in two separate for loops. # _transform_forward and _transform_backward must be applied in two separate for loops.
# The loop for transforming the forward graph: # The loop for transforming the forward graph:
for op in ops: for op in ops:
...@@ -163,8 +165,8 @@ class QuantizationTransformPass(object): ...@@ -163,8 +165,8 @@ class QuantizationTransformPass(object):
if len(self._need_initialized) > 0: if len(self._need_initialized) > 0:
assert self._scope is not None, \ assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.' 'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._program_exe is not None, \ assert self._place is not None, \
'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.' 'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
init_program = Program() init_program = Program()
for var_desc, initializer in six.iteritems(self._need_initialized): for var_desc, initializer in six.iteritems(self._need_initialized):
var = init_program.global_block().create_var( var = init_program.global_block().create_var(
...@@ -175,7 +177,8 @@ class QuantizationTransformPass(object): ...@@ -175,7 +177,8 @@ class QuantizationTransformPass(object):
lod_level=var_desc.lod_level(), lod_level=var_desc.lod_level(),
persistable=var_desc.persistable()) persistable=var_desc.persistable())
initializer(var, init_program.global_block()) initializer(var, init_program.global_block())
self._program_exe.run(program=init_program, scope=self._scope) exe = Executor(self._place)
exe.run(program=init_program, scope=self._scope)
return graph return graph
...@@ -183,11 +186,11 @@ class QuantizationTransformPass(object): ...@@ -183,11 +186,11 @@ class QuantizationTransformPass(object):
if self._weight_quantize_type == 'range_abs_max' or \ if self._weight_quantize_type == 'range_abs_max' or \
self._activation_quantize_type == 'range_abs_max': self._activation_quantize_type == 'range_abs_max':
counter_name = cpt.to_text('@STEP_COUNTER@') counter_name = cpt.to_text('@STEP_COUNTER@')
for node in graph.all_vars(): for node in graph.all_var_nodes():
if node.name() == counter_name: if node.name() == counter_name:
self._global_step = node self._global_step = node
if self._global_step is None: if self._global_step is None:
global_step_in = graph.create_param_node( global_step_in = graph.create_persistable_node(
name=counter_name, name=counter_name,
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1], shape=[1],
...@@ -228,14 +231,14 @@ class QuantizationTransformPass(object): ...@@ -228,14 +231,14 @@ class QuantizationTransformPass(object):
quant_var_node = graph.create_var_node( quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()), name=self._quantized_var_name(var_node.name()),
var_type=var_node.var().type(), var_type=var_node.type(),
shape=var_node.var().shape(), shape=var_node.shape(),
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
scale_var_node = graph.create_var_node( scale_var_node = graph.create_var_node(
name=self._quantized_scale_name(var_node.name()), name=self._quantized_scale_name(var_node.name()),
var_type=var_node.var().type(), var_type=var_node.type(),
shape=var_node.var().shape(), shape=var_node.shape(),
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max', op_type='fake_quantize_abs_max',
attrs={ attrs={
...@@ -258,15 +261,15 @@ class QuantizationTransformPass(object): ...@@ -258,15 +261,15 @@ class QuantizationTransformPass(object):
quant_var_node = graph.create_var_node( quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()), name=self._quantized_var_name(var_node.name()),
var_type=var_node.var().type(), var_type=var_node.type(),
shape=var_node.var().shape(), shape=var_node.shape(),
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
scale_in_node = graph.create_param_node( scale_in_node = graph.create_persistable_node(
name=self._quantized_scale_name(var_node.name()), name=self._quantized_scale_name(var_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1], shape=[1],
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
self._need_initialized[scale_in_node.var()] = Constant(value=0.001) self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
...@@ -275,11 +278,11 @@ class QuantizationTransformPass(object): ...@@ -275,11 +278,11 @@ class QuantizationTransformPass(object):
if not self._is_test: if not self._is_test:
# The name of scales_var_node may be 'scales_0', 'scales_1', etc. # The name of scales_var_node may be 'scales_0', 'scales_1', etc.
scales_node = graph.create_param_node( scales_node = graph.create_persistable_node(
name=unique_name.generate('scales'), name=unique_name.generate('scales'),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[self._window_size], shape=[self._window_size],
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
self._need_initialized[scales_node.var()] = Constant(value=0) self._need_initialized[scales_node.var()] = Constant(value=0)
inputs['Iter'] = self._global_step inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node outputs['OutScales'] = scales_node
...@@ -314,9 +317,9 @@ class QuantizationTransformPass(object): ...@@ -314,9 +317,9 @@ class QuantizationTransformPass(object):
dequant_var_node = graph.create_var_node( dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(var_node.name()), name=self._dequantized_var_name(var_node.name()),
var_type=var_node.var().type(), var_type=var_node.type(),
shape=var_node.var().shape(), shape=var_node.shape(),
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
max_range = (1 << (quant_bits - 1)) - 1 max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node( dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs', op_type='fake_dequantize_max_abs',
...@@ -400,22 +403,22 @@ class QuantizationFreezePass(object): ...@@ -400,22 +403,22 @@ class QuantizationFreezePass(object):
Args: Args:
graph(IrGraph): the applied graph. graph(IrGraph): the applied graph.
""" """
persistable_vars = [p.name() for p in graph.all_persistable_vars()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_ops() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
op_name = op_node.name() op_name = op_node.name()
if op_name in self._fake_quant_op_names: if op_name in self._fake_quant_op_names:
input_arg_name = op_node.op().input('X')[0] input_arg_name = op_node.input('X')[0]
if input_arg_name in persistable_vars: if input_arg_name in persistable_vars:
if self._weight_quantize_type == 'abs_max': if self._weight_quantize_type == 'abs_max':
param = self._load_var(input_arg_name) param = self._load_var(input_arg_name)
scale_v = np.max(np.abs(param)) scale_v = np.max(np.abs(param))
else: else:
scale_v = self._load_var(op_node.op().output('OutScale') scale_v = self._load_var(
[0])[0] op_node.output('OutScale')[0])[0]
self._var_scale_map[input_arg_name] = scale_v self._var_scale_map[input_arg_name] = scale_v
else: else:
scale_v = graph.var_node(op_node.op().output('OutScale')[0]) scale_v = graph.var_node(op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v self._var_scale_map[input_arg_name] = scale_v
if input_arg_name in persistable_vars: if input_arg_name in persistable_vars:
self._remove_fake_quant_and_dequant_op(graph, op_node) self._remove_fake_quant_and_dequant_op(graph, op_node)
...@@ -425,13 +428,13 @@ class QuantizationFreezePass(object): ...@@ -425,13 +428,13 @@ class QuantizationFreezePass(object):
self._weight_bits) self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v) self._restore_var(input_arg_name, quantized_param_v)
ops = graph.all_ops() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
op_name = op_node.name() op_name = op_node.name()
if op_name in self._fake_dequant_op_names: if op_name in self._fake_dequant_op_names:
self._remove_fake_quant_and_dequant_op(graph, op_node) self._remove_fake_quant_and_dequant_op(graph, op_node)
ops = graph.all_ops() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
op_name = op_node.name() op_name = op_node.name()
if op_name in self._quantizable_ops: if op_name in self._quantizable_ops:
...@@ -451,8 +454,8 @@ class QuantizationFreezePass(object): ...@@ -451,8 +454,8 @@ class QuantizationFreezePass(object):
return graph return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node): def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = op_node.op().output('Out')[0] k = op_node.output('Out')[0]
v = op_node.op().input('X')[0] v = op_node.input('X')[0]
if v not in self._op_input_rename_map: if v not in self._op_input_rename_map:
self._op_input_rename_map[k] = v self._op_input_rename_map[k] = v
else: else:
...@@ -462,7 +465,7 @@ class QuantizationFreezePass(object): ...@@ -462,7 +465,7 @@ class QuantizationFreezePass(object):
def _insert_post_dequant_op(self, graph, op_node): def _insert_post_dequant_op(self, graph, op_node):
max_range = None max_range = None
scale_var_node = None scale_var_node = None
persistable_vars = [p.name() for p in graph.all_persistable_vars()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs: for var_node in op_node.inputs:
name = var_node.name() name = var_node.name()
if name in self._op_input_rename_map: if name in self._op_input_rename_map:
...@@ -480,7 +483,7 @@ class QuantizationFreezePass(object): ...@@ -480,7 +483,7 @@ class QuantizationFreezePass(object):
original_var_name) original_var_name)
max_range = param_range * act_range / scale_v max_range = param_range * act_range / scale_v
else: else:
assert isinstance(scale_v, core.Node) assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name] scale_var_node = self._var_scale_map[original_var_name]
if len(op_node.outputs) != 1: if len(op_node.outputs) != 1:
...@@ -490,9 +493,9 @@ class QuantizationFreezePass(object): ...@@ -490,9 +493,9 @@ class QuantizationFreezePass(object):
output_var_node = op_node.outputs[0] output_var_node = op_node.outputs[0]
dequant_var_node = graph.create_var_node( dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()), name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.var().type(), var_type=output_var_node.type(),
shape=output_var_node.var().shape(), shape=output_var_node.shape(),
var_dtype=output_var_node.var().dtype()) var_dtype=output_var_node.dtype())
dequant_op_node = graph.create_op_node( dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs', op_type='fake_dequantize_max_abs',
attrs={ attrs={
...@@ -517,14 +520,19 @@ class QuantizationFreezePass(object): ...@@ -517,14 +520,19 @@ class QuantizationFreezePass(object):
def _remove_unused_var_nodes(self, graph): def _remove_unused_var_nodes(self, graph):
all_used_vars = set() all_used_vars = set()
ops = graph.all_ops() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
for input_node in op_node.inputs: for input_node in op_node.inputs:
all_used_vars.add(input_node) all_used_vars.add(input_node)
for output_node in op_node.outputs: for output_node in op_node.outputs:
all_used_vars.add(output_node) all_used_vars.add(output_node)
all_unused_vars = graph.all_vars() - all_used_vars all_used_vars = {n.node for n in all_used_vars}
all_unused_vars = {
n
for n in filter(lambda node: node.node not in all_used_vars,
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars) graph.safe_remove_nodes(all_unused_vars)
def _original_var_name(self, var_name): def _original_var_name(self, var_name):
...@@ -583,8 +591,8 @@ class ConvertToInt8Pass(object): ...@@ -583,8 +591,8 @@ class ConvertToInt8Pass(object):
Args: Args:
graph(IrGraph): the applied graph. graph(IrGraph): the applied graph.
""" """
persistable_vars = [p.name() for p in graph.all_persistable_vars()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_ops() ops = graph.all_op_nodes()
input_map = {} input_map = {}
for op_node in ops: for op_node in ops:
op_name = op_node.name() op_name = op_node.name()
...@@ -605,10 +613,10 @@ class ConvertToInt8Pass(object): ...@@ -605,10 +613,10 @@ class ConvertToInt8Pass(object):
def _convert_to_int8(self, graph, var_node): def _convert_to_int8(self, graph, var_node):
int8_var_node_name = var_node.name() + ".int8" int8_var_node_name = var_node.name() + ".int8"
int8_var_node = graph.create_param_node( int8_var_node = graph.create_persistable_node(
name=cpt.to_text(int8_var_node_name), name=cpt.to_text(int8_var_node_name),
var_type=var_node.var().type(), var_type=var_node.type(),
shape=var_node.var().shape(), shape=var_node.shape(),
var_dtype=core.VarDesc.VarType.INT8) var_dtype=core.VarDesc.VarType.INT8)
array = self._load_var(var_node.name()) array = self._load_var(var_node.name())
self._scope.var(int8_var_node_name) self._scope.var(int8_var_node_name)
...@@ -624,14 +632,19 @@ class ConvertToInt8Pass(object): ...@@ -624,14 +632,19 @@ class ConvertToInt8Pass(object):
def _remove_unused_var_nodes(self, graph): def _remove_unused_var_nodes(self, graph):
all_used_vars = set() all_used_vars = set()
ops = graph.all_ops() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
for input_node in op_node.inputs: for input_node in op_node.inputs:
all_used_vars.add(input_node) all_used_vars.add(input_node)
for output_node in op_node.outputs: for output_node in op_node.outputs:
all_used_vars.add(output_node) all_used_vars.add(output_node)
all_unused_vars = graph.all_vars() - all_used_vars all_used_vars = {n.node for n in all_used_vars}
all_unused_vars = {
n
for n in filter(lambda node: node.node not in all_used_vars,
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars) graph.safe_remove_nodes(all_unused_vars)
...@@ -655,11 +668,11 @@ class TransformForMobilePass(object): ...@@ -655,11 +668,11 @@ class TransformForMobilePass(object):
Args: Args:
graph(IrGraph): the graph will be transformed. graph(IrGraph): the graph will be transformed.
""" """
ops = graph.all_ops() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
name = op_node.name() name = op_node.name()
if name in self._fake_quant_op_names: if name in self._fake_quant_op_names:
op_node.op().set_type('quantize') op_node.set_type('quantize')
quant_node = graph.create_op_node_from_desc(op_node.op()) quant_node = graph.create_op_node_from_desc(op_node.op())
for input_node in op_node.inputs: for input_node in op_node.inputs:
graph.link_to(input_node, quant_node) graph.link_to(input_node, quant_node)
...@@ -667,7 +680,7 @@ class TransformForMobilePass(object): ...@@ -667,7 +680,7 @@ class TransformForMobilePass(object):
graph.link_to(quant_node, output_node) graph.link_to(quant_node, output_node)
graph.safe_remove_nodes(op_node) graph.safe_remove_nodes(op_node)
if name in self._fake_dequant_op_names: if name in self._fake_dequant_op_names:
op_node.op().set_type('dequantize') op_node.set_type('dequantize')
dequant_node = graph.create_op_node_from_desc(op_node.op()) dequant_node = graph.create_op_node_from_desc(op_node.op())
for input_node in op_node.inputs: for input_node in op_node.inputs:
graph.link_to(input_node, dequant_node) graph.link_to(input_node, dequant_node)
......
...@@ -61,16 +61,16 @@ class TestGraph(unittest.TestCase): ...@@ -61,16 +61,16 @@ class TestGraph(unittest.TestCase):
opt.minimize(loss) opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False) graph = IrGraph(core.Graph(main.desc), for_test=False)
marked_nodes = set() marked_nodes = set()
for op in graph.all_ops(): for op in graph.all_op_nodes():
if op.name().find('conv2d') > -1: if op.name().find('conv2d') > -1:
marked_nodes.add(op) marked_nodes.add(op)
graph.draw('.', 'residual', marked_nodes) graph.draw('.', 'residual', marked_nodes)
self.assertFalse(graph.has_circle()) self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1) self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort() nodes = graph.topology_sort()
self.assertEqual(len(nodes), len(graph.all_ops())) self.assertEqual(len(nodes), len(graph.all_op_nodes()))
nodes_map = graph.build_adjacency_list() nodes_map = graph.build_adjacency_list()
self.assertEqual(len(nodes_map), len(graph.all_ops())) self.assertEqual(len(nodes_map), len(graph.all_op_nodes()))
nodes_num = len(graph.all_nodes()) nodes_num = len(graph.all_nodes())
graph.safe_remove_nodes(marked_nodes) graph.safe_remove_nodes(marked_nodes)
self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes)) self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
......
...@@ -130,15 +130,16 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -130,15 +130,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
loss = linear_fc(3) loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.001) opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss) opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace()) place = fluid.CPUPlace()
exe = fluid.Executor(place)
graph = IrGraph(core.Graph(main.desc), for_test=False) graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass( transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(), scope=fluid.global_scope(),
program_exe=exe, place=place,
activation_quantize_type=quant_type) activation_quantize_type=quant_type)
transform_pass.apply(graph) transform_pass.apply(graph)
marked_nodes = set() marked_nodes = set()
for op in graph.all_ops(): for op in graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
marked_nodes.add(op) marked_nodes.add(op)
graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes) graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
...@@ -146,7 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -146,7 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
self.check_program(transform_pass, program) self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set() val_marked_nodes = set()
for op in val_graph.all_ops(): for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
val_marked_nodes.add(op) val_marked_nodes.add(op)
val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes) val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
...@@ -166,15 +167,16 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -166,15 +167,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
loss = residual_block(2) loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001) opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss) opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace()) place = fluid.CPUPlace()
exe = fluid.Executor(place)
graph = IrGraph(core.Graph(main.desc), for_test=False) graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass( transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(), scope=fluid.global_scope(),
program_exe=exe, place=place,
activation_quantize_type=quant_type) activation_quantize_type=quant_type)
transform_pass.apply(graph) transform_pass.apply(graph)
marked_nodes = set() marked_nodes = set()
for op in graph.all_ops(): for op in graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
marked_nodes.add(op) marked_nodes.add(op)
graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes) graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
...@@ -182,7 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase): ...@@ -182,7 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
self.check_program(transform_pass, program) self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set() val_marked_nodes = set()
for op in val_graph.all_ops(): for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
val_marked_nodes.add(op) val_marked_nodes.add(op)
val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes) val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
...@@ -231,17 +233,17 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -231,17 +233,17 @@ class TestQuantizationFreezePass(unittest.TestCase):
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
exe.run(startup) exe.run(startup)
transform_pass = QuantizationTransformPass( transform_pass = QuantizationTransformPass(
scope=scope, program_exe=exe, activation_quantize_type=quant_type) scope=scope, place=place, activation_quantize_type=quant_type)
transform_pass.apply(main_graph) transform_pass.apply(main_graph)
transform_pass.apply(test_graph) transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_' dev_name = '_gpu_' if use_cuda else '_cpu_'
marked_nodes = set() marked_nodes = set()
for op in main_graph.all_ops(): for op in main_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
marked_nodes.add(op) marked_nodes.add(op)
main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes) main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes)
marked_nodes = set() marked_nodes = set()
for op in test_graph.all_ops(): for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
marked_nodes.add(op) marked_nodes.add(op)
test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes) test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes)
...@@ -251,11 +253,6 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -251,11 +253,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
iters = 5 iters = 5
batch_size = 8 batch_size = 8
#train_exe = fluid.ParallelExecutor(
# main_program=quantized_main_program,
# use_cuda=bool(use_cuda),
# loss_name=loss.name,
# scope=scope)
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500), paddle.dataset.mnist.train(), buf_size=500),
...@@ -269,9 +266,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -269,9 +266,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
loss_v = exe.run(program=quantized_main_program, loss_v = exe.run(program=quantized_main_program,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[loss]) fetch_list=[loss])
#loss_v = train_exe.run(feed=feeder.feed(data), print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
# fetch_list=[loss.name])
#print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
test_data = next(test_reader()) test_data = next(test_reader())
with fluid.program_guard(quantized_test_program): with fluid.program_guard(quantized_test_program):
...@@ -287,7 +282,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -287,7 +282,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
freeze_pass = QuantizationFreezePass(scope=scope, place=place) freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph) freeze_pass.apply(test_graph)
marked_nodes = set() marked_nodes = set()
for op in test_graph.all_ops(): for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
marked_nodes.add(op) marked_nodes.add(op)
test_graph.draw('.', 'test_freeze' + dev_name + quant_type, test_graph.draw('.', 'test_freeze' + dev_name + quant_type,
...@@ -299,21 +294,21 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -299,21 +294,21 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed=feeder.feed(test_data), feed=feeder.feed(test_data),
fetch_list=[loss]) fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3) self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
#print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1)) print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
#print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2)) print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor()) w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# This assertion may fail due to calculation precision. # This assertion may fail due to calculation precision.
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant)) # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
#print('{}: {}'.format('w_freeze' + dev_name + quant_type, print('{}: {}'.format('w_freeze' + dev_name + quant_type,
# np.sum(w_freeze))) np.sum(w_freeze)))
#print('{}: {}'.format('w_quant' + dev_name + quant_type, print('{}: {}'.format('w_quant' + dev_name + quant_type,
# np.sum(w_quant))) np.sum(w_quant)))
# Convert parameter to 8-bit. # Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place) convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
convert_int8_pass.apply(test_graph) convert_int8_pass.apply(test_graph)
marked_nodes = set() marked_nodes = set()
for op in test_graph.all_ops(): for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
marked_nodes.add(op) marked_nodes.add(op)
test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes) test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes)
...@@ -330,14 +325,14 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -330,14 +325,14 @@ class TestQuantizationFreezePass(unittest.TestCase):
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor()) w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8) self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze)) self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
#print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit))) print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
#print('{}: {}'.format('w_freeze' + dev_name + quant_type, print('{}: {}'.format('w_freeze' + dev_name + quant_type,
# np.sum(w_freeze))) np.sum(w_freeze)))
mobile_pass = TransformForMobilePass() mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph) mobile_pass.apply(test_graph)
marked_nodes = set() marked_nodes = set()
for op in test_graph.all_ops(): for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1: if op.name().find('quantize') > -1:
marked_nodes.add(op) marked_nodes.add(op)
test_graph.draw('.', 'test_mobile' + dev_name + quant_type, test_graph.draw('.', 'test_mobile' + dev_name + quant_type,
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册