diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 24059140ab20e24917b93a5f60936b1087797ff9..1cd1be8e8d9da8c6a82ceefc3284084bfeda0252 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -13,10 +13,12 @@ // limitations under the License. #include "paddle/fluid/pybind/ir.h" +#include #include #include #include #include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/op_desc.h" @@ -27,6 +29,10 @@ namespace py = pybind11; using paddle::framework::ir::Graph; using paddle::framework::ir::Node; using paddle::framework::ir::GraphSafeRemoveNodes; +using paddle::framework::ir::HasCircle; +using paddle::framework::ir::GraphNum; +using paddle::framework::ir::TopologySortOperations; +using paddle::framework::ir::BuildOperationAdjList; using paddle::framework::OpDesc; using paddle::framework::ProgramDesc; using paddle::framework::VarDesc; @@ -36,6 +42,12 @@ namespace paddle { namespace pybind { void BindGraph(py::module *m) { m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes); + m->def("has_circle", HasCircle); + m->def("graph_num", GraphNum); + m->def("topology_sort", TopologySortOperations, + return_value_policy::reference); + m->def("build_adjacency_list", BuildOperationAdjList, + return_value_policy::reference); py::class_>( *m, "Graph", "The graph is a Directed Acyclic Single Static Assignment Graph, see " @@ -46,7 +58,6 @@ void BindGraph(py::module *m) { .def("get_float", &Graph::Get) .def("get_double", &Graph::Get) .def("get_string", &Graph::Get) - .def("get_program", &Graph::Get) .def("get_marked_nodes", &Graph::Get>) .def("set", [](Graph &self, const std::string &attr_name, int attr) { return self.Set(attr_name, new int(attr)); }) @@ -63,11 +74,6 @@ void BindGraph(py::module *m) { [](Graph &self, const std::string &attr_name, double attr) { 
return self.Set(attr_name, new double(attr)); }) - .def("set", - [](Graph &self, const std::string &attr_name, - const ProgramDesc &attr) { - return self.Set(attr_name, new ProgramDesc(attr)); - }) .def("set", [](Graph &self, const std::string &attr_name, const std::unordered_set &attr) { @@ -108,42 +114,42 @@ void BindNode(py::module *m) { .def("is_op", &Node::IsOp) .def("is_var", &Node::IsVar) .def("is_ctrl_var", &Node::IsCtrlVar) + .def("clear_inputs", [](Node &self) { self.inputs.clear(); }) .def("inputs_remove", [](Node &self, int node_id) { - for (auto it = self.inputs.begin(); it != self.inputs.end(); - it++) { - if ((*it)->id() == node_id) { - self.inputs.erase(it); - } + auto pos = std::find_if( + self.inputs.begin(), self.inputs.end(), + [&node_id](const Node *n) { return n->id() == node_id; }); + if (pos != self.inputs.end()) { + self.inputs.erase(pos); } }) .def("inputs_remove", [](Node &self, Node &node) { - for (auto it = self.inputs.begin(); it != self.inputs.end(); - it++) { - if (*it == &node) { - self.inputs.erase(it); - } + auto pos = + std::find(self.inputs.begin(), self.inputs.end(), &node); + if (pos != self.inputs.end()) { + self.inputs.erase(pos); } }) .def("inputs_append", [](Node &self, Node &node) { self.inputs.push_back(&node); }) + .def("clear_outputs", [](Node &self) { self.outputs.clear(); }) .def("outputs_remove", [](Node &self, int node_id) { - for (auto it = self.outputs.begin(); it != self.outputs.end(); - it++) { - if ((*it)->id() == node_id) { - self.outputs.erase(it); - } + auto pos = std::find_if( + self.outputs.begin(), self.outputs.end(), + [&node_id](const Node *n) { return n->id() == node_id; }); + if (pos != self.outputs.end()) { + self.outputs.erase(pos); } }) .def("outputs_remove", [](Node &self, Node &node) { - for (auto it = self.outputs.begin(); it != self.outputs.end(); - it++) { - if (*it == &node) { - self.outputs.erase(it); - } + auto pos = + std::find(self.outputs.begin(), self.outputs.end(), &node); + if (pos 
!= self.outputs.end()) { + self.outputs.erase(pos); } }) .def("outputs_append", diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 351513712cc4297bf7fbe67878aeba162ef66e4d..a4a01ad647b038bd2bfea00fefa30abb19f58b66 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -829,8 +829,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("disable_profiler", platform::DisableProfiler); m.def("is_profiler_enabled", platform::IsProfileEnabled); m.def("reset_profiler", platform::ResetProfiler); - m.def("get_pass", [](const py::bytes &binary_str) { - std::string pass_type(binary_str); + m.def("get_pass", [](const std::string &pass_type) { auto pass = framework::ir::PassRegistry::Instance().Get(pass_type); return std::shared_ptr(std::move(pass)); }); @@ -838,10 +837,9 @@ All parameter, weight, gradient are variables in Paddle. py::class_> pass(m, "Pass"); pass.def(py::init()) .def("has", &ir::Pass::Has) - .def("set", - [](ir::Pass &self, const std::string &attr_name, - const ProgramDesc &attr) { - return self.Set(attr_name, new ProgramDesc(attr)); + .def("set_not_owned", + [](ir::Pass &self, const std::string &attr_name, ProgramDesc &attr) { + self.SetNotOwned(attr_name, &attr); }) .def( "set", @@ -850,7 +848,6 @@ All parameter, weight, gradient are variables in Paddle. 
}) .def("set", [](ir::Pass &self, const std::string &name, int val) { self.Set(name, new int(val)); }) - .def("get_program", &ir::Pass::Get) .def("type", &ir::Pass::Type) .def("apply", [](ir::Pass &self, std::shared_ptr graph) { std::unique_ptr origin_graph(graph.get()); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 90b8fd1a0aab159eb1a829d67485c845182d295b..bcc997ff4511db45d2a775092c0798d7c1e9be06 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -64,6 +64,7 @@ if (WITH_TESTING) add_subdirectory(paddle/dataset/tests) add_subdirectory(paddle/fluid/tests) add_subdirectory(paddle/fluid/contrib/tests) + add_subdirectory(paddle/fluid/contrib/slim/tests) endif() install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR} DESTINATION opt/paddle/share/wheels diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 266a106bc507104c0a8db1c882b55ac59e88195e..18b58e6f388bbe9495333b12f32d63b74fddcb3a 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -13,14 +13,19 @@ # limitations under the License. import collections +import numpy as np +import six +from ..... import compat as cpt from .... import core from ....framework import IrGraph from ....framework import Program -from ....framework import Variable from ....initializer import Constant from .... import unique_name -__all__ = ['QuantizationTransformPass'] +__all__ = [ + 'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass', + 'TransformForMobilePass' +] class QuantizationTransformPass(object): @@ -35,7 +40,13 @@ class QuantizationTransformPass(object): """ Convert and rewrite the IrGraph according to weight and activation quantization type. + Args: + scope(fluid.Scope): When activation use 'range_abs_max' as the quantize + type, this pass will create some new parameters. 
The scope is used to + initialize these new parameters. + program_exe(fluid.Executor): program_exe is used to initialize new + parameters described above. weight_bits (int): quantization bit number for weights, the bias is not quantized. activation_bits (int): quantization bit number for activation. @@ -49,6 +60,7 @@ class QuantizationTransformPass(object): support 'abs_max'. The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained. window_size (int): the window size for 'range_abs_max' quantization. + Examples: .. code-block:: python # The original graph will be rewrite. @@ -88,31 +100,35 @@ class QuantizationTransformPass(object): self._quantizable_grad_ops = [ '%s_grad' % (op) for op in self._quantizable_ops ] - self._fake_quant_op_types = [ - 'fake_quantize_abs_max', 'fake_quantize_range_abs_max' - ] - self._fake_dequant_op_types = ['fake_dequantize_max_abs'] self._is_test = None self._global_step = None def apply(self, graph): + """ + Quantize the graph for training process. According to weight and + activation quantization type, the graph will be added some fake + quantize operators and fake dequantize operators. + + Args: + graph(IrGraph): the applied graph. + """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' self._need_initialized.clear() self._is_test = graph.is_test() # marked the variable which has been dequantized. 
dequantized_vars = collections.OrderedDict() - params = [p.name() for p in graph.all_parameters()] + persistable_vars = [p.name() for p in graph.all_persistable_vars()] def _transform_forward(graph, op): for var_node in op.inputs: if var_node.name() in dequantized_vars: dequant_var_node = dequantized_vars[var_node.name()] else: - quant_bits = self._weight_bits if var_node.name() in params \ + quant_bits = self._weight_bits if var_node.name() in persistable_vars \ else self._activation_bits quant_type = self._weight_quantize_type if var_node.name() \ - in params else self._activation_quantize_type + in persistable_vars else self._activation_quantize_type quant_var_node, scale_var_node = self._insert_quant_op( graph, var_node, quant_bits, quant_type) dequant_var_node = self._insert_dequant_op( @@ -150,9 +166,14 @@ class QuantizationTransformPass(object): assert self._program_exe is not None, \ 'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.' init_program = Program() - for var_desc, initializer in self._need_initialized.iteritems(): - var = Variable(init_program.global_block()) - var._set_desc(var_desc) + for var_desc, initializer in six.iteritems(self._need_initialized): + var = init_program.global_block().create_var( + name=var_desc.name(), + shape=var_desc.shape(), + dtype=var_desc.dtype(), + type=var_desc.type(), + lod_level=var_desc.lod_level(), + persistable=var_desc.persistable()) initializer(var, init_program.global_block()) self._program_exe.run(program=init_program, scope=self._scope) @@ -161,7 +182,7 @@ class QuantizationTransformPass(object): def _create_global_step(self, graph): if self._weight_quantize_type == 'range_abs_max' or \ self._activation_quantize_type == 'range_abs_max': - counter_name = '@STEP_COUNTER@' + counter_name = cpt.to_text('@STEP_COUNTER@') for node in graph.all_vars(): if node.name() == counter_name: self._global_step = node @@ -175,9 +196,14 @@ class QuantizationTransformPass(object): 
Constant(value=0, force_cpu=True) global_step_out = graph.create_var_node_from_desc( global_step_in.var()) + # The attribute of `op_role` is needed by ParallelExecutor. increment_op = graph.create_op_node( op_type='increment', - attrs={'step': 1.0}, + attrs={ + 'step': 1.0, + 'op_role': + core.op_proto_and_checker_maker.OpRole.Forward + }, inputs={'X': global_step_in}, outputs={'Out': global_step_out}) graph.link_to(global_step_in, increment_op) @@ -212,7 +238,10 @@ class QuantizationTransformPass(object): var_dtype=var_node.var().dtype()) quant_op_node = graph.create_op_node( op_type='fake_quantize_abs_max', - attrs={'bit_length': quant_bits}, + attrs={ + 'bit_length': quant_bits, + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + }, inputs={'X': var_node}, outputs={'Out': quant_var_node, 'OutScale': scale_var_node}) @@ -257,7 +286,8 @@ class QuantizationTransformPass(object): attrs = { 'window_size': self._window_size, 'bit_length': quant_bits, - 'is_test': self._is_test + 'is_test': self._is_test, + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward } quant_op_node = graph.create_op_node( op_type='fake_quantize_range_abs_max', @@ -290,7 +320,10 @@ class QuantizationTransformPass(object): max_range = (1 << (quant_bits - 1)) - 1 dequant_op_node = graph.create_op_node( op_type='fake_dequantize_max_abs', - attrs={'max_range': float(max_range)}, + attrs={ + 'max_range': float(max_range), + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + }, inputs={'X': var_node, 'Scale': scale_var_node}, outputs={'Out': dequant_var_node}) @@ -316,3 +349,330 @@ class QuantizationTransformPass(object): Return the scale name of quantized variable for the input `var_name`.
""" return "%s.scale" % (var_name) + + +class QuantizationFreezePass(object): + """ + The freeze pass is used to adjust the quantize operator order, for example: + 1) `activation -> quant -> dequant -> conv2d` will be frozen into + `activation -> quant -> conv2d -> dequant` + 2) `weight -> quant -> dequant -> conv2d` will be frozen into `weight -> conv2d`, + and weight will be scaled offline. + + Args: + scope(fluid.Scope): scope is used to get the weight tensor values. + place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors. + weight_bits (int): quantization bit number for weights. + activation_bits (int): quantization bit number for activation. + weight_quantize_type (str): quantization type for weights, support 'abs_max'. + The 'range_abs_max' usually is not used for weight, since weights are fixed once the + model is well trained. + """ + + def __init__(self, + scope, + place, + weight_bits=8, + activation_bits=8, + weight_quantize_type='abs_max'): + assert scope is not None, \ + 'The scope cannot be set None.' + assert place is not None, \ + 'The place cannot be set None.' + self._scope = scope + self._place = place + self._weight_bits = weight_bits + self._activation_bits = activation_bits + self._weight_quantize_type = weight_quantize_type + self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] + self._fake_quant_op_names = [ + 'fake_quantize_abs_max', 'fake_quantize_range_abs_max' + ] + self._fake_dequant_op_names = ['fake_dequantize_max_abs'] + self._op_input_rename_map = collections.OrderedDict() + self._op_output_rename_map = collections.OrderedDict() + self._var_scale_map = collections.OrderedDict() + + def apply(self, graph): + """ + Adjust quantize/dequantize operators order for the inference process. + + Args: + graph(IrGraph): the applied graph.
+ """ + persistable_vars = [p.name() for p in graph.all_persistable_vars()] + ops = graph.all_ops() + for op_node in ops: + op_name = op_node.name() + if op_name in self._fake_quant_op_names: + input_arg_name = op_node.op().input('X')[0] + if input_arg_name in persistable_vars: + if self._weight_quantize_type == 'abs_max': + param = self._load_var(input_arg_name) + scale_v = np.max(np.abs(param)) + else: + scale_v = self._load_var(op_node.op().output('OutScale') + [0])[0] + self._var_scale_map[input_arg_name] = scale_v + else: + scale_v = graph.var_node(op_node.op().output('OutScale')[0]) + self._var_scale_map[input_arg_name] = scale_v + if input_arg_name in persistable_vars: + self._remove_fake_quant_and_dequant_op(graph, op_node) + # quantize weight and restore + param_v = self._load_var(input_arg_name) + quantized_param_v = self._quant(param_v, scale_v, + self._weight_bits) + self._restore_var(input_arg_name, quantized_param_v) + + ops = graph.all_ops() + for op_node in ops: + op_name = op_node.name() + if op_name in self._fake_dequant_op_names: + self._remove_fake_quant_and_dequant_op(graph, op_node) + + ops = graph.all_ops() + for op_node in ops: + op_name = op_node.name() + if op_name in self._quantizable_ops: + self._insert_post_dequant_op(graph, op_node) + + for op_node in ops: + # insert dequant_op after fc/conv, need to rename inputs of the followed ops + for var_node in op_node.inputs: + name = var_node.name() + if name in self._op_output_rename_map: + old_in = graph.var_node(name) + new_in = self._op_output_rename_map[name] + graph.update_input_link(old_in, new_in, op_node) + + # remove the unused var node in the graph + self._remove_unused_var_nodes(graph) + return graph + + def _remove_fake_quant_and_dequant_op(self, graph, op_node): + k = op_node.op().output('Out')[0] + v = op_node.op().input('X')[0] + if v not in self._op_input_rename_map: + self._op_input_rename_map[k] = v + else: + self._op_input_rename_map[k] = self._op_input_rename_map[v] + 
graph.safe_remove_nodes(op_node) + + def _insert_post_dequant_op(self, graph, op_node): + max_range = None + scale_var_node = None + persistable_vars = [p.name() for p in graph.all_persistable_vars()] + for var_node in op_node.inputs: + name = var_node.name() + if name in self._op_input_rename_map: + old_in = graph.var_node(name) + new_in = graph.var_node(self._op_input_rename_map[name]) + new_in.clear_outputs() + graph.update_input_link(old_in, new_in, op_node) + original_var_name = self._original_var_name(name) + scale_v = self._var_scale_map[original_var_name] + if original_var_name in persistable_vars: + param_range = (1 << (self._weight_bits - 1)) - 1 + act_range = (1 << (self._activation_bits - 1)) - 1 + assert self._is_float( + scale_v), 'The scale of parameter %s is not a float.' % ( + original_var_name) + max_range = param_range * act_range / scale_v + else: + assert isinstance(scale_v, core.Node) + scale_var_node = self._var_scale_map[original_var_name] + + if len(op_node.outputs) != 1: + raise ValueError("Only support one output, but op %s has" + " more than one output." 
% (op_node.name())) + + output_var_node = op_node.outputs[0] + dequant_var_node = graph.create_var_node( + name=self._dequantized_var_name(output_var_node.name()), + var_type=output_var_node.var().type(), + shape=output_var_node.var().shape(), + var_dtype=output_var_node.var().dtype()) + dequant_op_node = graph.create_op_node( + op_type='fake_dequantize_max_abs', + attrs={ + 'max_range': float(max_range), + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + }, + inputs={'X': output_var_node, + 'Scale': scale_var_node}, + outputs={'Out': dequant_var_node}) + graph.link_to(output_var_node, dequant_op_node) + graph.link_to(scale_var_node, dequant_op_node) + graph.link_to(dequant_op_node, dequant_var_node) + self._op_output_rename_map[output_var_node.name()] = dequant_var_node + return dequant_var_node + + def _load_var(self, name): + return np.array(self._scope.find_var(name).get_tensor()) + + def _restore_var(self, name, array): + tensor = self._scope.find_var(name).get_tensor() + tensor.set(array, self._place) + + def _remove_unused_var_nodes(self, graph): + all_used_vars = set() + ops = graph.all_ops() + for op_node in ops: + for input_node in op_node.inputs: + all_used_vars.add(input_node) + for output_node in op_node.outputs: + all_used_vars.add(output_node) + + all_unused_vars = graph.all_vars() - all_used_vars + graph.safe_remove_nodes(all_unused_vars) + + def _original_var_name(self, var_name): + """ + Return the original variable name. + """ + if var_name.endswith('.quantized.dequantized'): + return var_name[:-len('.quantized.dequantized')] + if var_name.endswith('.quantized'): + return var_name[:-len('.quantized')] + if var_name.endswith('.dequantized'): + return var_name[:-len('.dequantized')] + if var_name.endswith('.scale'): + return var_name[:-len('.scale')] + else: + return var_name + + def _dequantized_var_name(self, var_name): + """ + Return dequantized variable name for the input `var_name`. 
+ """ + return "%s.dequantized" % (var_name) + + def _is_float(self, v): + return isinstance(v, float) or isinstance(v, np.float32) \ + or isinstance(v, np.float64) + + def _quant(self, x, scale, num_bits): + return np.round(x / scale * ((1 << (num_bits - 1)) - 1)) + + +class ConvertToInt8Pass(object): + """ + Convert the weights into int8_t type. + + Args: + scope(fluid.Scope): scope is used to get the weight tensor values. + place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the + 8bits weight tensors. + """ + + def __init__(self, scope, place): + assert scope is not None, \ + 'The scope cannot be set None.' + assert place is not None, \ + 'The place cannot be set None.' + self._scope = scope + self._place = place + self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] + + def apply(self, graph): + """ + Convert weights' type of the graph. After that, the data type of the + graph weights is int8_t. + + Args: + graph(IrGraph): the applied graph. + """ + persistable_vars = [p.name() for p in graph.all_persistable_vars()] + ops = graph.all_ops() + input_map = {} + for op_node in ops: + op_name = op_node.name() + if op_name in self._quantizable_ops: + for var_node in op_node.inputs: + name = var_node.name() + if name in persistable_vars: + if name not in input_map: + int8_var_node = self._convert_to_int8(graph, + var_node) + input_map[name] = int8_var_node + graph.update_input_link(var_node, input_map[name], + op_node) + + # remove the unused var node in the graph + self._remove_unused_var_nodes(graph) + return graph + + def _convert_to_int8(self, graph, var_node): + int8_var_node_name = var_node.name() + ".int8" + int8_var_node = graph.create_param_node( + name=cpt.to_text(int8_var_node_name), + var_type=var_node.var().type(), + shape=var_node.var().shape(), + var_dtype=core.VarDesc.VarType.INT8) + array = self._load_var(var_node.name()) + self._scope.var(int8_var_node_name) + self._store_var(int8_var_node_name, array, np.int8) + return
int8_var_node + + def _load_var(self, name): + return np.array(self._scope.find_var(name).get_tensor()) + + def _store_var(self, name, array, dtype): + tensor = self._scope.find_var(name).get_tensor() + tensor.set(array.astype(dtype), self._place) + + def _remove_unused_var_nodes(self, graph): + all_used_vars = set() + ops = graph.all_ops() + for op_node in ops: + for input_node in op_node.inputs: + all_used_vars.add(input_node) + for output_node in op_node.outputs: + all_used_vars.add(output_node) + + all_unused_vars = graph.all_vars() - all_used_vars + graph.safe_remove_nodes(all_unused_vars) + + +class TransformForMobilePass(object): + """ + This pass is used to convert the frozen graph for paddle-mobile execution. + """ + + def __init__(self): + self._fake_quant_op_names = [ + 'fake_quantize_abs_max', 'fake_quantize_range_abs_max' + ] + self._fake_dequant_op_names = ['fake_dequantize_max_abs'] + + def apply(self, graph): + """ + Because paddle-mobile uses `quantize` and `dequantize` as the names of + quantize operator and dequantize operator, the `apply` function just + implements this renaming. + + Args: + graph(IrGraph): the graph will be transformed.
+ """ + ops = graph.all_ops() + for op_node in ops: + name = op_node.name() + if name in self._fake_quant_op_names: + op_node.op().set_type('quantize') + quant_node = graph.create_op_node_from_desc(op_node.op()) + for input_node in op_node.inputs: + graph.link_to(input_node, quant_node) + for output_node in op_node.outputs: + graph.link_to(quant_node, output_node) + graph.safe_remove_nodes(op_node) + if name in self._fake_dequant_op_names: + op_node.op().set_type('dequantize') + dequant_node = graph.create_op_node_from_desc(op_node.op()) + for input_node in op_node.inputs: + graph.link_to(input_node, dequant_node) + for output_node in op_node.outputs: + graph.link_to(dequant_node, output_node) + graph.safe_remove_nodes(op_node) + + return graph diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..79bec8c4ad34d682895250bc29b1fddb3a569bd4 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(src ${TEST_OPS}) + py_test(${src} SRCS ${src}.py) +endforeach() diff --git a/python/paddle/fluid/contrib/slim/unitest/__init__.py b/python/paddle/fluid/contrib/slim/tests/__init__.py similarity index 100% rename from python/paddle/fluid/contrib/slim/unitest/__init__.py rename to python/paddle/fluid/contrib/slim/tests/__init__.py diff --git a/python/paddle/fluid/contrib/slim/unitest/configs/config.yaml b/python/paddle/fluid/contrib/slim/tests/configs/config.yaml similarity index 88% rename from python/paddle/fluid/contrib/slim/unitest/configs/config.yaml rename to python/paddle/fluid/contrib/slim/tests/configs/config.yaml index db488b96330210df15b02b19d90abd5c9101f844..d9b49029d3e34d487ad65fe0f7e54e2cee1d5838 100644 --- 
a/python/paddle/fluid/contrib/slim/unitest/configs/config.yaml +++ b/python/paddle/fluid/contrib/slim/tests/configs/config.yaml @@ -1,5 +1,5 @@ version: 1.0 -include: ["./unitest/configs/pruners.yaml", "./unitest/configs/pruners_0.yaml"] +include: ["./configs/pruners.yaml", "./configs/pruners_0.yaml"] pruners: pruner_1: class: 'RatioPruner' diff --git a/python/paddle/fluid/contrib/slim/unitest/configs/pruners.yaml b/python/paddle/fluid/contrib/slim/tests/configs/pruners.yaml similarity index 100% rename from python/paddle/fluid/contrib/slim/unitest/configs/pruners.yaml rename to python/paddle/fluid/contrib/slim/tests/configs/pruners.yaml diff --git a/python/paddle/fluid/contrib/slim/unitest/configs/pruners_0.yaml b/python/paddle/fluid/contrib/slim/tests/configs/pruners_0.yaml similarity index 100% rename from python/paddle/fluid/contrib/slim/unitest/configs/pruners_0.yaml rename to python/paddle/fluid/contrib/slim/tests/configs/pruners_0.yaml diff --git a/python/paddle/fluid/contrib/slim/unitest/test_factory.py b/python/paddle/fluid/contrib/slim/tests/test_factory.py similarity index 95% rename from python/paddle/fluid/contrib/slim/unitest/test_factory.py rename to python/paddle/fluid/contrib/slim/tests/test_factory.py index 07f28aac905d1a2813dbde6143235c7916fd9278..2fc72b6475e6bdd977dafb57696046a1100d0087 100644 --- a/python/paddle/fluid/contrib/slim/unitest/test_factory.py +++ b/python/paddle/fluid/contrib/slim/tests/test_factory.py @@ -18,7 +18,7 @@ import unittest class TestFactory(unittest.TestCase): def test_parse(self): - factory = ConfigFactory('./unitest/configs/config.yaml') + factory = ConfigFactory('./configs/config.yaml') pruner = factory.instance('pruner_1') self.assertEquals(pruner.ratios['conv1_1.w'], 0.3) diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph.py b/python/paddle/fluid/contrib/slim/tests/test_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..75e0c95b5c3cc06d66eab9de0b85e5d7ed110837 --- /dev/null 
+++ b/python/paddle/fluid/contrib/slim/tests/test_graph.py @@ -0,0 +1,80 @@ +# copyright (c) 2018 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +from __future__ import print_function +import unittest +import paddle.fluid as fluid +import six +from paddle.fluid.framework import IrGraph +from paddle.fluid import core + + +def residual_block(num): + def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + act='relu', + bias_attr=False): + tmp = fluid.layers.conv2d( + input=input, + filter_size=filter_size, + num_filters=ch_out, + stride=stride, + padding=padding, + act=None, + bias_attr=bias_attr) + return fluid.layers.batch_norm(input=tmp, act=act) + + data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + hidden = data + for _ in six.moves.xrange(num): + conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True) + short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) + hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') + fc = fluid.layers.fc(input=hidden, size=10) + loss = fluid.layers.cross_entropy(input=fc, label=label) + loss = fluid.layers.mean(loss) + return loss + + +class TestGraph(unittest.TestCase): + def test_graph_functions(self): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = residual_block(2) + opt = 
fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + graph = IrGraph(core.Graph(main.desc), for_test=False) + marked_nodes = set() + for op in graph.all_ops(): + if op.name().find('conv2d') > -1: + marked_nodes.add(op) + graph.draw('.', 'residual', marked_nodes) + self.assertFalse(graph.has_circle()) + self.assertEqual(graph.graph_num(), 1) + nodes = graph.topology_sort() + self.assertEqual(len(nodes), len(graph.all_ops())) + nodes_map = graph.build_adjacency_list() + self.assertEqual(len(nodes_map), len(graph.all_ops())) + nodes_num = len(graph.all_nodes()) + graph.safe_remove_nodes(marked_nodes) + self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..2f291132f3049af21420f863972792c1a862b9ad --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py @@ -0,0 +1,372 @@ +# copyright (c) 2018 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. 
+ +import unittest +import random +import numpy as np +import paddle.fluid as fluid +import six +import paddle +from paddle.fluid.framework import IrGraph +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass +from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass +from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass +from paddle.fluid.contrib.slim.quantization import TransformForMobilePass +from paddle.fluid import core + + +def linear_fc(num): + data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + hidden = data + for _ in six.moves.xrange(num): + hidden = fluid.layers.fc(hidden, size=128, act='relu') + loss = fluid.layers.cross_entropy(input=hidden, label=label) + loss = fluid.layers.mean(loss) + return loss + + +def residual_block(num): + def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + act='relu', + bias_attr=False): + tmp = fluid.layers.conv2d( + input=input, + filter_size=filter_size, + num_filters=ch_out, + stride=stride, + padding=padding, + act=None, + bias_attr=bias_attr) + return fluid.layers.batch_norm(input=tmp, act=act) + + data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + hidden = data + for _ in six.moves.xrange(num): + conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True) + short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) + hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') + fc = fluid.layers.fc(input=hidden, size=10) + loss = fluid.layers.cross_entropy(input=fc, label=label) + loss = fluid.layers.mean(loss) + return loss + + +def conv_net(img, label): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + act="relu") + conv_pool_1 = fluid.layers.batch_norm(conv_pool_1) + 
conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=50, + pool_size=2, + pool_stride=2, + act="relu") + prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + avg_loss = fluid.layers.mean(loss) + return avg_loss + + +class TestQuantizationTransformPass(unittest.TestCase): + def setUp(self): + self.quantizable_op_and_inputs = { + 'conv2d': ['Input', 'Filter'], + 'depthwise_conv2d': ['Input', 'Filter'], + 'mul': ['X', 'Y'] + } + self.quantizable_grad_op_inputs = { + 'conv2d_grad': ['Input', 'Filter'], + 'depthwise_conv2d_grad': ['Input', 'Filter'], + 'mul_grad': ['X', 'Y'] + } + + def check_program(self, transform_pass, program): + quantized_ops = set() + for block in program.blocks: + for op in block.ops: + # check forward + if op.type in self.quantizable_op_and_inputs: + for arg_name in op.input_arg_names: + self.assertTrue( + arg_name.endswith('.quantized.dequantized')) + quantized_ops.add(arg_name) + + for op in block.ops: + # check backward + if op.type in self.quantizable_grad_op_inputs: + for pname in self.quantizable_grad_op_inputs[op.type]: + arg_name = op.input(pname)[0] + self.assertTrue( + arg_name.endswith('.quantized.dequantized')) + self.assertTrue(arg_name in quantized_ops) + + def linear_fc_quant(self, quant_type): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = linear_fc(3) + opt = fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + exe = fluid.Executor(fluid.CPUPlace()) + graph = IrGraph(core.Graph(main.desc), for_test=False) + transform_pass = QuantizationTransformPass( + scope=fluid.global_scope(), + program_exe=exe, + activation_quantize_type=quant_type) + transform_pass.apply(graph) + marked_nodes = set() + for op in graph.all_ops(): + if op.name().find('quantize') > -1: + marked_nodes.add(op) + graph.draw('.', 'quantize_fc_' + quant_type, 
marked_nodes) + program = graph.to_program() + self.check_program(transform_pass, program) + val_graph = IrGraph(core.Graph(program.desc), for_test=False) + val_marked_nodes = set() + for op in val_graph.all_ops(): + if op.name().find('quantize') > -1: + val_marked_nodes.add(op) + val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes) + + def test_linear_fc_quant_abs_max(self): + self.act_quant_op_type = 'fake_quantize_abs_max' + self.linear_fc_quant('abs_max') + + def test_linear_fc_quant_range_abs_max(self): + self.act_quant_op_type = 'fake_quantize_range_abs_max' + self.linear_fc_quant('range_abs_max') + + def residual_block_quant(self, quant_type): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = residual_block(2) + opt = fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + exe = fluid.Executor(fluid.CPUPlace()) + graph = IrGraph(core.Graph(main.desc), for_test=False) + transform_pass = QuantizationTransformPass( + scope=fluid.global_scope(), + program_exe=exe, + activation_quantize_type=quant_type) + transform_pass.apply(graph) + marked_nodes = set() + for op in graph.all_ops(): + if op.name().find('quantize') > -1: + marked_nodes.add(op) + graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes) + program = graph.to_program() + self.check_program(transform_pass, program) + val_graph = IrGraph(core.Graph(program.desc), for_test=False) + val_marked_nodes = set() + for op in val_graph.all_ops(): + if op.name().find('quantize') > -1: + val_marked_nodes.add(op) + val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes) + + def test_residual_block_abs_max(self): + self.act_quant_op_type = 'fake_quantize_abs_max' + self.residual_block_quant('abs_max') + + def test_residual_block_range_abs_max(self): + self.act_quant_op_type = 'fake_quantize_range_abs_max' + self.residual_block_quant('range_abs_max') + + +class TestQuantizationFreezePass(unittest.TestCase): + def 
freeze_graph(self, use_cuda, seed, quant_type): + def build_program(main, startup, is_test): + main.random_seed = seed + startup.random_seed = seed + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + img = fluid.layers.data( + name='image', shape=[1, 28, 28], dtype='float32') + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + loss = conv_net(img, label) + if not is_test: + opt = fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + return [img, label], loss + + random.seed(0) + np.random.seed(0) + + main = fluid.Program() + startup = fluid.Program() + test_program = fluid.Program() + feeds, loss = build_program(main, startup, False) + build_program(test_program, startup, True) + test_program = test_program.clone(for_test=True) + main_graph = IrGraph(core.Graph(main.desc), for_test=False) + test_graph = IrGraph(core.Graph(test_program.desc), for_test=True) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.Scope() + with fluid.scope_guard(scope): + exe.run(startup) + transform_pass = QuantizationTransformPass( + scope=scope, program_exe=exe, activation_quantize_type=quant_type) + transform_pass.apply(main_graph) + transform_pass.apply(test_graph) + dev_name = '_gpu_' if use_cuda else '_cpu_' + marked_nodes = set() + for op in main_graph.all_ops(): + if op.name().find('quantize') > -1: + marked_nodes.add(op) + main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes) + marked_nodes = set() + for op in test_graph.all_ops(): + if op.name().find('quantize') > -1: + marked_nodes.add(op) + test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes) + + quantized_main_program = main_graph.to_program() + quantized_test_program = test_graph.to_program() + iters = 5 + batch_size = 8 + + #train_exe = fluid.ParallelExecutor( + # main_program=quantized_main_program, + # use_cuda=bool(use_cuda), + # loss_name=loss.name, + # scope=scope) + 
train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + feeder = fluid.DataFeeder(feed_list=feeds, place=place) + with fluid.scope_guard(scope): + for _ in range(iters): + data = next(train_reader()) + loss_v = exe.run(program=quantized_main_program, + feed=feeder.feed(data), + fetch_list=[loss]) + #loss_v = train_exe.run(feed=feeder.feed(data), + # fetch_list=[loss.name]) + #print('{}: {}'.format('loss' + dev_name + quant_type, loss_v)) + + test_data = next(test_reader()) + with fluid.program_guard(quantized_test_program): + w_var = fluid.framework._get_var('conv2d_1.w_0.quantized', + quantized_test_program) + # Testing + with fluid.scope_guard(scope): + test_loss1, w_quant = exe.run(program=quantized_test_program, + feed=feeder.feed(test_data), + fetch_list=[loss, w_var]) + + # Freeze graph for inference, but the weight of fc/conv is still float type. 
+ freeze_pass = QuantizationFreezePass(scope=scope, place=place) + freeze_pass.apply(test_graph) + marked_nodes = set() + for op in test_graph.all_ops(): + if op.name().find('quantize') > -1: + marked_nodes.add(op) + test_graph.draw('.', 'test_freeze' + dev_name + quant_type, + marked_nodes) + + server_program = test_graph.to_program() + with fluid.scope_guard(scope): + test_loss2, = exe.run(program=server_program, + feed=feeder.feed(test_data), + fetch_list=[loss]) + self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3) + #print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1)) + #print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2)) + w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor()) + # Maybe failed, this is due to the calculation precision + # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant)) + #print('{}: {}'.format('w_freeze' + dev_name + quant_type, + # np.sum(w_freeze))) + #print('{}: {}'.format('w_quant' + dev_name + quant_type, + # np.sum(w_quant))) + + # Convert parameter to 8-bit. + convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place) + convert_int8_pass.apply(test_graph) + marked_nodes = set() + for op in test_graph.all_ops(): + if op.name().find('quantize') > -1: + marked_nodes.add(op) + test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes) + server_program_int8 = test_graph.to_program() + # Save the 8-bit parameter and model file. + with fluid.scope_guard(scope): + fluid.io.save_inference_model('server_int8' + dev_name + quant_type, + ['image', 'label'], [loss], exe, + server_program_int8) + # Test whether the 8-bit parameter and model file can be loaded successfully. + [infer, feed, fetch] = fluid.io.load_inference_model( + 'server_int8' + dev_name + quant_type, exe) + # Check the loaded 8-bit weight. 
+ w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor()) + self.assertEqual(w_8bit.dtype, np.int8) + self.assertEqual(np.sum(w_8bit), np.sum(w_freeze)) + #print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit))) + #print('{}: {}'.format('w_freeze' + dev_name + quant_type, + # np.sum(w_freeze))) + + mobile_pass = TransformForMobilePass() + mobile_pass.apply(test_graph) + marked_nodes = set() + for op in test_graph.all_ops(): + if op.name().find('quantize') > -1: + marked_nodes.add(op) + test_graph.draw('.', 'test_mobile' + dev_name + quant_type, + marked_nodes) + + mobile_program = test_graph.to_program() + with fluid.scope_guard(scope): + fluid.io.save_inference_model('mobile_int8' + dev_name + quant_type, + ['image', 'label'], [loss], exe, + mobile_program) + + def test_freeze_graph_cuda_dynamic(self): + if fluid.core.is_compiled_with_cuda(): + with fluid.unique_name.guard(): + self.freeze_graph(True, seed=1, quant_type='abs_max') + + def test_freeze_graph_cpu_dynamic(self): + with fluid.unique_name.guard(): + self.freeze_graph(False, seed=2, quant_type='abs_max') + + def test_freeze_graph_cuda_static(self): + if fluid.core.is_compiled_with_cuda(): + with fluid.unique_name.guard(): + self.freeze_graph(True, seed=1, quant_type='range_abs_max') + + def test_freeze_graph_cpu_static(self): + with fluid.unique_name.guard(): + self.freeze_graph(False, seed=2, quant_type='range_abs_max') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py deleted file mode 100644 index 1bd4b95d6b90b7f16d507061190f0b463f6c4cc5..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py +++ /dev/null @@ -1,175 +0,0 @@ -# copyright (c) 2018 paddlepaddle authors. all rights reserved. 
-# -# licensed under the apache license, version 2.0 (the "license"); -# you may not use this file except in compliance with the license. -# you may obtain a copy of the license at -# -# http://www.apache.org/licenses/license-2.0 -# -# unless required by applicable law or agreed to in writing, software -# distributed under the license is distributed on an "as is" basis, -# without warranties or conditions of any kind, either express or implied. -# see the license for the specific language governing permissions and -# limitations under the license. - -import unittest -import random -import numpy as np -import paddle.fluid as fluid -import six -from paddle.fluid.framework import Program -from paddle.fluid.framework import IrGraph -from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass -from paddle.fluid import core - - -def linear_fc(num): - data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32') - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - hidden = data - for _ in six.moves.xrange(num): - hidden = fluid.layers.fc(hidden, size=128, act='relu') - loss = fluid.layers.cross_entropy(input=hidden, label=label) - loss = fluid.layers.mean(loss) - return loss - - -def residual_block(num): - def conv_bn_layer(input, - ch_out, - filter_size, - stride, - padding, - act='relu', - bias_attr=False): - tmp = fluid.layers.conv2d( - input=input, - filter_size=filter_size, - num_filters=ch_out, - stride=stride, - padding=padding, - act=None, - bias_attr=bias_attr) - return fluid.layers.batch_norm(input=tmp, act=act) - - data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32') - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - hidden = data - for _ in six.moves.xrange(num): - conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True) - short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) - hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') - fc = 
fluid.layers.fc(input=hidden, size=10) - loss = fluid.layers.cross_entropy(input=fc, label=label) - loss = fluid.layers.mean(loss) - return loss - - -class TestQuantizationTransformPass(unittest.TestCase): - def setUp(self): - self.quantizable_op_and_inputs = { - 'conv2d': ['Input', 'Filter'], - 'depthwise_conv2d': ['Input', 'Filter'], - 'mul': ['X', 'Y'] - } - self.quantizable_grad_op_inputs = { - 'conv2d_grad': ['Input', 'Filter'], - 'depthwise_conv2d_grad': ['Input', 'Filter'], - 'mul_grad': ['X', 'Y'] - } - - def check_program(self, transform_pass, program): - quantized_ops = set() - for block in program.blocks: - for op in block.ops: - # check forward - if op.type in self.quantizable_op_and_inputs: - for arg_name in op.input_arg_names: - self.assertTrue( - arg_name.endswith('.quantized.dequantized')) - quantized_ops.add(arg_name) - - for op in block.ops: - # check backward - if op.type in self.quantizable_grad_op_inputs: - for pname in self.quantizable_grad_op_inputs[op.type]: - arg_name = op.input(pname)[0] - self.assertTrue( - arg_name.endswith('.quantized.dequantized')) - self.assertTrue(arg_name in quantized_ops) - - def linear_fc_quant(self, quant_type): - main = fluid.Program() - startup = fluid.Program() - with fluid.program_guard(main, startup): - loss = linear_fc(3) - opt = fluid.optimizer.Adam(learning_rate=0.001) - opt.minimize(loss) - exe = fluid.Executor(fluid.CPUPlace()) - graph = IrGraph(core.Graph(main.desc), for_test=False) - transform_pass = QuantizationTransformPass( - scope=fluid.global_scope(), - program_exe=exe, - activation_quantize_type=quant_type) - transform_pass.apply(graph) - marked_nodes = set() - for op in graph.all_ops(): - if op.name().find('quantize') > -1: - marked_nodes.add(op) - graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes) - program = graph.to_program() - self.check_program(transform_pass, program) - val_graph = IrGraph(core.Graph(program.desc), for_test=False) - val_marked_nodes = set() - for op in 
val_graph.all_ops(): - if op.name().find('quantize') > -1: - val_marked_nodes.add(op) - val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes) - - def test_linear_fc_quant_abs_max(self): - self.act_quant_op_type = 'fake_quantize_abs_max' - self.linear_fc_quant('abs_max') - - def test_linear_fc_quant_range_abs_max(self): - self.act_quant_op_type = 'fake_quantize_range_abs_max' - self.linear_fc_quant('range_abs_max') - - def residual_block_quant(self, quant_type): - main = fluid.Program() - startup = fluid.Program() - with fluid.program_guard(main, startup): - loss = residual_block(2) - opt = fluid.optimizer.Adam(learning_rate=0.001) - opt.minimize(loss) - exe = fluid.Executor(fluid.CPUPlace()) - graph = IrGraph(core.Graph(main.desc), for_test=False) - transform_pass = QuantizationTransformPass( - scope=fluid.global_scope(), - program_exe=exe, - activation_quantize_type=quant_type) - transform_pass.apply(graph) - marked_nodes = set() - for op in graph.all_ops(): - if op.name().find('quantize') > -1: - marked_nodes.add(op) - graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes) - program = graph.to_program() - self.check_program(transform_pass, program) - val_graph = IrGraph(core.Graph(program.desc), for_test=False) - val_marked_nodes = set() - for op in val_graph.all_ops(): - if op.name().find('quantize') > -1: - val_marked_nodes.add(op) - val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes) - - def test_residual_block_abs_max(self): - self.act_quant_op_type = 'fake_quantize_abs_max' - self.residual_block_quant('abs_max') - - def test_residual_block_range_abs_max(self): - self.act_quant_op_type = 'fake_quantize_range_abs_max' - self.residual_block_quant('range_abs_max') - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py b/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py index 
86fa84ad4bd7a55fb27f4e43128f0bfda6dfe6db..77fdf0087b93c3ad44a2492de68f8f57ce243ef3 100644 --- a/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py +++ b/python/paddle/fluid/contrib/tests/test_quantize_transpiler.py @@ -204,9 +204,11 @@ class TestQuantizeTranspiler(unittest.TestCase): build_program(test_program, startup, True) test_program = test_program.clone(for_test=True) - quant_transpiler = QuantizeTranspiler() - quant_transpiler.training_transpile(main) - quant_transpiler.training_transpile(test_program) + quant_type = 'range_abs_max' # 'range_abs_max' or 'abs_max' + quant_transpiler = QuantizeTranspiler( + activation_quantize_type=quant_type) + quant_transpiler.training_transpile(main, startup) + quant_transpiler.training_transpile(test_program, startup) place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 832c97c7deb49b4e118e15989ab7a34da6ce57a0..ef304b11106628f8541b348fb263274a0c4b31e9 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -16,6 +16,8 @@ from __future__ import print_function import collections from collections import defaultdict +from collections import Iterable +import contextlib from .wrapped_decorator import signature_safe_contextmanager import os import re @@ -1529,12 +1531,16 @@ class Block(object): class IrGraph(object): """ - IrGraph uses core.Graph as the delegation to accomplish the manipulation. + Python IrGraph. Beneath it is a core.Graph, which is used for + create a c++ Ir Pass Graph. An IrGraph is just a graph view of + a Program. In an IrGraph, both Variables and Operators are graph + nodes. """ def __init__(self, graph, for_test=False): """ - Construct the IrGraph using core.Graph. + Construct an IrGraph using core.Graph. + Args: graph(core.Graph): C++ Graph. for_test(bool): True for the test graph and false for the train graph. 
@@ -1545,23 +1551,81 @@ class IrGraph(object): self._for_test = for_test def is_test(self): + """ + If the graph is used for testing, the function returns true. Otherwise, returns false. + """ return self._for_test - def all_parameters(self): - param_nodes = set() - for node in self.graph.nodes(): - if node.is_var() and node.var() is not None and node.var( - ).persistable(): - param_nodes.add(node) - return param_nodes + def all_nodes(self): + """ + Return all nodes included in the graph as a set. + """ + return {node for node in self.graph.nodes()} def all_vars(self): + """ + Return all variable nodes included in the graph as a set. + """ return {node for node in self.graph.nodes() if node.is_var()} + def all_persistable_vars(self): + """ + Return all persistable variable nodes included in the graph as a set. + """ + persistable_nodes = set() + for node in self.graph.nodes(): + if node.is_var() and node.var() is not None and node.var( + ).persistable(): + persistable_nodes.add(node) + return persistable_nodes + def all_ops(self): + """ + Return all operator nodes included in the graph as a set. + """ return {node for node in self.graph.nodes() if node.is_op()} + def var_node(self, name): + """ + Get a variable node by name from the graph. + + Args: + name(str): the name of the variable node. + + Raises: + ValueError: The If input's type is not str, or this graph + doesn't have a variable with the giving name. + + Returns: + core.Node: the variable node with the giving name. + """ + if not isinstance(name, six.string_types): + raise TypeError( + "var require string as parameter, but get %s instead." 
% + (type(name))) + target_var_node = None + var_nodes = self.all_vars() + for var_node in var_nodes: + if var_node.name() == name: + target_var_node = var_node + if target_var_node is None: + raise ValueError("var_node %s not in this graph" % name) + return target_var_node + def create_param_node(self, name, var_type, shape, var_dtype): + """ + Create a persistable variable node in the graph. In IrGraph, + it can not distinguish between persistable variables and parameters. + + Args: + name(str): the name of the persistable variable node. + vart_type(core.VarDesc.VarType): the type of the persistable variable node. + shape(list): the shape of the persistable variable node. + var_dtype(core.VarDesc.VarType): the data type of the persistable variable node. + + Returns: + core.Node: the created persistable variable node. + """ var_desc = core.VarDesc(name) var_desc.set_type(var_type) var_desc.set_shape(shape) @@ -1570,6 +1634,20 @@ class IrGraph(object): return self.graph.create_var_node(var_desc) def create_var_node(self, name, var_type, shape, var_dtype): + """ + Create a variable node in the graph. The created variable node is + not persistable. + + Args: + name(str): the name of the variable node. + vart_type(core.VarDesc.VarType): the type of the variable node. + shape(list): the shape of the variable node. + var_dtype(core.VarDesc.VarType): the data type of the variable node. + + Returns: + core.Node: the created variable node. + """ + var_desc = core.VarDesc(name) var_desc.set_type(var_type) var_desc.set_shape(shape) @@ -1577,19 +1655,41 @@ class IrGraph(object): return self.graph.create_var_node(var_desc) def create_var_node_from_desc(self, var_desc): + """ + Create a variable node by using an existing VarDesc in the graph. + Depend on the giving VarDesc, the created variable node may be persistable. + + Args: + var_desc(core.VarDesc): the giving variable description. + + Returns: + core.Node: the created variable node. 
+ """ return self.graph.create_var_node(var_desc) def create_op_node(self, op_type, attrs, inputs, outputs): + """ + Create a operator node in the graph. + + Args: + op_type(str): the type of the operator node. + attrs(dict): the attributes of the operator node. + inputs(dict): the inputs of the operator node. + outputs(dict): the outpus of the operator node. + + Returns: + core.Node: the created operator node. + """ op_desc = core.OpDesc() op_desc.set_type(op_type) - for attr, value in attrs.iteritems(): + for attr, value in six.iteritems(attrs): self._update_desc_attr(op_desc, attr, value) - for input_name, var_nodes in inputs.iteritems(): + for input_name, var_nodes in six.iteritems(inputs): if not isinstance(var_nodes, list): var_nodes = [var_nodes] op_desc.set_input(input_name, [var_node.name() for var_node in var_nodes]) - for output_name, var_nodes in outputs.iteritems(): + for output_name, var_nodes in six.iteritems(outputs): if not isinstance(var_nodes, list): var_nodes = [var_nodes] op_desc.set_output(output_name, @@ -1597,11 +1697,29 @@ class IrGraph(object): return self.graph.create_op_node(op_desc) def create_op_node_from_desc(self, op_desc): + """ + Create a operator node by using an existing OpDesc in the graph. + + Args: + op_desc(core.VarDesc): the giving operator description. + + Returns: + core.Node: the created operator node. + """ return self.graph.create_op_node(op_desc) def update_input_link(self, old_input_node, new_input_node, op_node): - assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \ - op_node in self.graph.nodes(), 'Th three arguments must be in the graph nodes.' + """ + Update the input's link of a operator node. + + Args: + old_input_node(core.Node): the old input node of the giving op_node. + new_input_node(core.Node): the new input node of the giving op_node. + op_node(core.Node): the operator node that is needed to update input's link. 
+ """ + assert old_input_node in self.graph.nodes() and new_input_node in \ + self.graph.nodes() and op_node in self.graph.nodes(), \ + 'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.' old_input_node.outputs_remove(op_node) op_node.inputs_remove(old_input_node) new_input_node.outputs_append(op_node) @@ -1609,17 +1727,85 @@ class IrGraph(object): op_node.op()._rename_input(old_input_node.name(), new_input_node.name()) def link_to(self, node_in, node_out): + """ + Connect two nodes. + + Args: + node_in(core.Node): the input node. + node_out(core.Node): the output node. + """ assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \ - 'Th two arguments must be in the graph nodes.' + 'The two arguments(node_in&node_out) must be in the graph nodes.' node_in.outputs_append(node_out) node_out.inputs_append(node_in) def safe_remove_nodes(self, remove_nodes): + """ + Remove nodes safely since links connected to these removed nodes are + also removed. + + Args: + remove_nodes(set): the nodes prepared to be removed. + """ if not isinstance(remove_nodes, set): - remove_nodes = set(remove_nodes) + if isinstance(remove_nodes, Iterable): + remove_nodes = set(remove_nodes) + else: + remove_nodes = {remove_nodes} core.graph_safe_remove_nodes(self.graph, remove_nodes) - def draw(self, save_path, name, marked_nodes=None): + def has_circle(self): + """ + Check if the graph has a circle. + + Returns: + bool: True if the graph has a circle else False. + """ + return core.has_circle(self.graph) + + def graph_num(self): + """ + Count the number of unconnected graphs in this graph. + + Returns: + int: the number of unconnected graphs. + """ + return core.graph_num(self.graph) + + def topology_sort(self): + """ + Perform the topology sort operation on the graph. + + Notes: the `graph` cannot contain a circle. + + Returns: + set(core.Node): nodes in topology order. 
+ """ + return core.topology_sort(self.graph) + + def build_adjacency_list(self): + """ + Build an adjacency list of operations for the `graph`. + + Returns: + dict{core.Node: set(core.Node)}: the adjacency list. + """ + return core.build_adjacency_list(self.graph) + + def draw(self, save_path, name, marked_nodes=None, remove_ctr_var=True): + """ + Draw the graph. If `dot` command is installed, the drawn graph + will be saved as pdf file type, otherwise dot file type is used. + + Args: + save_path(str): the save path of drawn graph. + name(str): the name of drawn graph. + marked_nodes(set(core.Node)): nodes that are needed to be marked. + Default value is None. + remove_ctr_var(bool): If it is set True, all control variable nodes + in the graph will be removed. Default value is True. + """ + def _convert_to_pdf(dot_file_path): pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf' exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \ @@ -1629,15 +1815,17 @@ class IrGraph(object): print('The {} is saved as the dot filetype.'.format( dot_file_path)) - remove_ctr_vars = set() + if remove_ctr_var: + remove_ctr_vars = set() + for node in self.graph.nodes(): + if node.is_ctrl_var(): + remove_ctr_vars.add(node) + self.safe_remove_nodes(remove_ctr_vars) ops_num = 0 for node in self.graph.nodes(): - if node.is_ctrl_var(): - remove_ctr_vars.add(node) - elif node.is_op(): + if node.is_op(): ops_num += 1 print('Total ops num = {}.'.format(ops_num)) - self.safe_remove_nodes(remove_ctr_vars) if marked_nodes is not None: if not isinstance(marked_nodes, set): marked_nodes = set(marked_nodes) @@ -1652,10 +1840,20 @@ class IrGraph(object): _convert_to_pdf(viz_dot_path) def to_program(self): + """ + Convert the graph into a Program. + + Notes: When the graph includes backward operator nodes, the + conversion process may be failed. Usually, this function is + only used to convert a test graph. + + Returns: + Program: a program converted from the graph. 
+ """ convert_pass = core.get_pass('graph_to_program_pass') - convert_pass.set('program', Program().desc) + desc = core.ProgramDesc() + convert_pass.set_not_owned('program', desc) convert_pass.apply(self.graph) - desc = convert_pass.get_program('program') program = Program._construct_from_desc(desc) return program