diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index df0ff772c9d35c88ec5a6112525c56aa92d359b9..ad73085f52e9f17f82ae4e9e4666cd83cb6c509f 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -233,3 +233,4 @@ USE_PASS(sequential_execution_pass); USE_PASS(all_reduce_deps_pass); USE_PASS(modify_op_lock_and_record_event_pass); USE_PASS(lock_free_optimize_pass); +USE_PASS(graph_to_program_pass); diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 6cf405efe63d2bc284c4650771a747b27bb4a9f6..33ccee6aa0a94b8fd8308214d6144ae832d40bab 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -28,10 +28,14 @@ std::unique_ptr Pass::Apply(std::unique_ptr graph) const { PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.", attr); } + auto* native_graph = graph.get(); auto applied_graph = ApplyImpl(std::move(graph)); // TODO(panyx0718): Add more verifications. PADDLE_ENFORCE(!HasCircle(*applied_graph), "Illegal Pass. Generated graph shouldn't has cycle."); + PADDLE_ENFORCE(applied_graph.get() == native_graph, + "Pass::Apply() cannot delete the passed graph and shouldn't " + "return a new graph.(For the need of pybind11)"); applied_ = true; return applied_graph; } diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index d32fe58f8695a5c14f276ef038416f5c47f3400f..1205ccf7f025c3aea62255f1d42d34df4fd9826d 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -42,6 +42,7 @@ void BindGraph(py::module *m) { .def("get_float", &Graph::Get) .def("get_double", &Graph::Get) .def("get_string", &Graph::Get) + .def("get_program", &Graph::Get) .def("set", [](Graph &self, const std::string &attr_name, int attr) { return self.Set(attr_name, new int(attr)); }) .def("set", @@ -57,6 +58,11 @@ void BindGraph(py::module *m) { [](Graph &self, const std::string &attr_name, double attr) { return self.Set(attr_name, new double(attr)); }) + .def("set", + [](Graph &self, const std::string &attr_name, + const ProgramDesc &attr) { + return self.Set(attr_name, new ProgramDesc(attr)); + }) .def("erase", &Graph::Erase) .def("nodes", &Graph::Nodes, return_value_policy::reference) .def("create_var_node", diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 4b218fb3a2af0933ea1e87abe20e7e031c32f721..09c08f1ffc884a0f6f30acfd6613e6e2f5a8fe75 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -229,6 +229,12 @@ void BindBlockDesc(pybind11::module *m) { void BindVarDsec(pybind11::module *m) { pybind11::class_ var_desc(*m, "VarDesc", ""); var_desc + .def("__init__", + [](pd::VarDesc &self, const pybind11::bytes &binary_str) { + std::string str(binary_str); + new (&self) pd::VarDesc(str); + }, + pybind11::return_value_policy::reference) .def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference) .def("set_name", &pd::VarDesc::SetName) .def("set_shape", &pd::VarDesc::SetShape) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f3f4854a9efbcf5ab325e7f6aec81135c018dcd5..ae50f3885f68a3687689d7e2047ecf3c09c344a6 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -786,9 +786,20 @@ All parameter, weight, gradient are variables in Paddle. m.def("disable_profiler", platform::DisableProfiler); m.def("is_profiler_enabled", platform::IsProfileEnabled); m.def("reset_profiler", platform::ResetProfiler); + m.def("get_pass", [](const py::bytes &binary_str) { + std::string pass_type(binary_str); + auto pass = framework::ir::PassRegistry::Instance().Get(pass_type); + return std::shared_ptr(std::move(pass)); + }); py::class_> pass(m, "Pass"); pass.def(py::init()) + .def("has", &ir::Pass::Has) + .def("set_program", + [](ir::Pass &self, const std::string &attr_name, + const ProgramDesc &attr) { + return self.Set(attr_name, new ProgramDesc(attr)); + }) .def( "set_str", [](ir::Pass &self, const std::string &name, const std::string &attr) { @@ -796,11 +807,12 @@ All parameter, weight, gradient are variables in Paddle. }) .def("set_int", [](ir::Pass &self, const std::string &name, int val) { self.Set(name, new int(val)); }) + .def("get_program", &ir::Pass::Get) .def("type", &ir::Pass::Type) .def("apply", [](ir::Pass &self, std::shared_ptr graph) { std::unique_ptr origin_graph(graph.get()); auto optim_graph = self.Apply(std::move(origin_graph)); - graph.reset(optim_graph.release()); + optim_graph.release(); }); py::class_> pb( diff --git a/python/paddle/fluid/contrib/slim/graph/graph.py b/python/paddle/fluid/contrib/slim/graph/graph.py index 7d6b0702035d49189c0919f976ea3c0c52663547..774da2d1ef182fe5f7d2cdd306b1beeb7772a6c9 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph.py +++ b/python/paddle/fluid/contrib/slim/graph/graph.py @@ -13,8 +13,81 @@ # limitations under the License. from ....framework import Program +from ....framework import Block +from .... import core -__all__ = ['Graph', 'ImitationGraph', 'IRGraph'] +__all__ = ['Graph', 'ImitationGraph', 'PyGraph'] + + +class PyGraph(object): + """ + PyGraph uses core.Graph as the delegation to accomplish the manipulation. + """ + + def __init__(self, graph): + assert isinstance( + graph, core.Graph), 'graph must be the instance of core.Graph.' + self.graph = graph + + def all_parameters(self): + params = [] + for node in self.graph.nodes(): + if node.is_var() and node.var().persistable(): + params.append(node) + return params + + def all_vars(self): + return [node for node in self.graph.nodes() if node.is_var()] + + def all_ops(self): + return [node for node in self.graph.nodes() if node.is_op()] + + def create_param_node(self, name, var_type, shape, var_dtype): + var_desc = core.VarDesc(name) + var_desc.set_type(var_type) + var_desc.set_shape(shape) + var_desc.set_dtype(var_dtype) + var_desc.set_persistable(True) + return self.graph.create_var_node(var_desc) + + def create_var_node(self, name, var_type, shape, var_dtype): + var_desc = core.VarDesc(name) + var_desc.set_type(var_type) + var_desc.set_shape(shape) + var_desc.set_dtype(var_dtype) + return self.graph.create_var_node(var_desc) + + def create_var_node_from_desc(self, var_desc): + return self.graph.create_var_node(var_desc) + + def create_op_node(self, op_type, attrs, inputs, outputs): + op_desc = core.OpDesc() + op_desc.set_type(op_type) + for attr, value in attrs.iteritems(): + self._update_desc_attr(op_desc, attr, value) + for input_name, var_node in inputs.iteritems(): + op_desc.set_input(input_name, [var_node.name()]) + for output_name, var_node in outputs.iteritems(): + op_desc.set_output(output_name, [var_node.name()]) + return self.graph.create_op_node(op_desc) + + def create_op_node_from_desc(self, op_desc): + return self.graph.create_op_node(op_desc) + + def _update_desc_attr(self, desc, name, val): + """ + Update the value of desc's attribute by attribute's name. + """ + if isinstance(val, Block): + desc.set_block_attr(name, val.desc) + elif isinstance(val, list) and val and all( + isinstance(v, Block) for v in val): + desc.set_blocks_attr(name, [v.desc for v in val]) + elif isinstance(val, core.BlockDesc) or \ + isinstance(val, core.ProgramDesc): + desc.set_serialized_attr(name, val.serialize_to_string()) + else: + desc._set_attr(name, val) class Graph(object): @@ -39,7 +112,3 @@ class ImitationGraph(Graph): def all_parameters(self): return self.program.global_block().all_parameters() - - -class IRGraph(Graph): - pass diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_performer.py b/python/paddle/fluid/contrib/slim/quantization/quantization_performer.py new file mode 100644 index 0000000000000000000000000000000000000000..7d9207dfbc97879e0fded81b6381d456d85efb28 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_performer.py @@ -0,0 +1,287 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import numpy as np +from .... import core +from ....initializer import Constant +from .... import unique_name +from ..graph import PyGraph + + +class QuantizationPerformer(object): + def __init__(self, + weight_bits=8, + activation_bits=8, + activation_quantize_type='abs_max', + weight_quantize_type='abs_max', + window_size=10000): + """ + Convert and rewrite the IRGraph according to weight and + activation quantization type. + Args: + weight_bits (int): quantization bit number for weights, + the bias is not quantized. + activation_bits (int): quantization bit number for activation. + activation_quantize_type (str): quantization type for activation, + now support 'abs_max', 'range_abs_max'. If use 'abs_max' mode, + the quantization scale will be calculated dynamically each step + in both training and testing period. If use 'range_abs_max', + a static quantization scale will be calculated during training + and used in inference. + weight_quantize_type (str): quantization type for weights, + support 'abs_max'. The 'range_abs_max' usually is not used for + weight, since weights are fixed once the model is well trained. + window_size (int): the window size for 'range_abs_max' quantization. + Examples: + .. code-block:: python + # the original graph will be rewrite, if you don't want to + # change it, please clone at first. + # graph = graph.clone() + from paddle.fluid.contrib.slim import * + from paddle.fluid.contrib.quantize import * + graph = IRGraph(program) + performer = QuantizationPerformer() + performer.quantize_transform(graph) + """ + self.weight_bits = weight_bits + self.activation_bits = activation_bits + + quant_type = ['abs_max', 'range_abs_max'] + if activation_quantize_type not in quant_type: + raise ValueError( + "Unknown activation_quantize_type : '%s'. It can only be ", + "'abs_max' or 'range_abs_max'.", str(activation_quantize_type)) + if weight_quantize_type not in quant_type: + raise ValueError( + "Unknown weight_quantize_type: '%s'. It can only be ", + "'abs_max' or 'range_abs_max'.", str(weight_quantize_type)) + + self.activation_quantize_type = activation_quantize_type + self.weight_quantize_type = weight_quantize_type + self.window_size = window_size + + self.need_inited_outer = collections.OrderedDict() + self.quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] + self.quantizable_grad_ops = [ + '%s_grad' % (op) for op in self.quantizable_ops + ] + self.fake_quant_op_types = [ + 'fake_quantize_abs_max', 'fake_quantize_range_abs_max' + ] + self.fake_dequant_op_types = ['fake_dequantize_max_abs'] + self.is_test = None + self.global_step = None + + def quantize_transform(self, graph, is_test): + self.need_inited_outer.clear() + self.is_test = is_test + assert isinstance(graph, + PyGraph), 'graph must be the instance of PyGraph.' + # marked the variable which has been dequantized. + dequantized_vars = collections.OrderedDict() + params = [p.name() for p in graph.all_parameters()] + + def _transform_forward(graph, op): + for var_node in op.inputs: + if var_node.name() in dequantized_vars: + dequant_var_node = dequantized_vars[var_node.name()] + else: + quant_bits = self.weight_bits if var_node.name() in params \ + else self.activation_bits + quant_type = self.weight_quantize_type if var_node.name() \ + in params else self.activation_quantize_type + quant_var_node, scale_var_node = self._insert_quant_op( + graph, var_node, quant_bits, quant_type) + dequant_var_node = self._insert_dequant_op( + graph, quant_var_node, scale_var_node, quant_bits) + dequantized_vars[var_node.name()] = dequant_var_node + self._update_input(var_node, dequant_var_node, op) + + if not self.is_test: + self._create_global_step(graph) + ops = graph.all_ops() + for op in ops: + # transform the forward graph + if op.name() in self.quantizable_ops: + _transform_forward(graph, op) + # rename the inputs of backward op + if op.name() in self.quantizable_grad_ops: + _transform_backward(graph, op) + return self.need_inited_outer + + def _insert_quant_op(self, graph, var_node, quant_bits, quant_type): + """ + Insert fake_quantize_op in the graph. + """ + if quant_type == 'abs_max': + return self._insert_quant_abs_max_op(graph, var_node, quant_bits) + elif quant_type == 'range_abs_max': + return self._inser_quant_range_abs_max_op(graph, var_node, + quant_bits) + + def _insert_quant_abs_max_op(self, graph, var_node, quant_bits): + """ + Insert fake_quantize_abs_max op in the graph. + """ + assert var_node.is_var(), '{} is not a var'.format(var_node.name()) + + quant_var_node = graph.create_var_node( + name=self._quantized_var_name(var_node.name()), + var_type=var_node.var().type(), + shape=var_node.var().shape(), + var_dtype=var_node.var().dtype()) + scale_var_node = graph.create_var_node( + name=self._quantized_scale_name(var_node.name()), + var_type=var_node.var().type(), + shape=var_node.var().shape(), + var_dtype=var_node.var().dtype()) + quant_op_node = graph.create_op_node( + op_type='fake_quantize_abs_max', + attrs={'bit_length': quant_bits}, + inputs={'X': var_node}, + outputs={'Out': quant_var_node, + 'OutScale': scale_var_node}) + self._link_to(var_node, quant_op_node) + self._link_to(quant_op_node, quant_var_node) + self._link_to(quant_op_node, scale_var_node) + return quant_var_node, scale_var_node + + def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits): + """ + Insert fake_quantize_range_abs_max on the graph. + """ + assert var_node.is_var(), '{} is not a var'.format(var_node.name()) + + quant_var_node = graph.create_var_node( + name=self._quantized_var_name(var_node.name()), + var_type=var_node.var().type(), + shape=var_node.var().shape(), + var_dtype=var_node.var().dtype()) + + scale_in_node = graph.create_param_node( + name=self._quantized_scale_name(var_node.name()), + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=[1], + var_dtype=var_node.var().dtype()) + self.need_inited_outer[scale_in_node.var()] = Constant(value=0.001) + + scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) + inputs = {'X': var_node, 'InScale': scale_in_node} + outputs = {'Out': quant_var_node, 'OutScale': scale_out_node} + + if not self.is_test: + # The name of scales_var_node maybe 'scales_0', 'scales_1', etc. + scales_node = graph.create_param_node( + name=unique_name.generate('scales'), + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=[self.window_size], + var_dtype=var_node.var().dtype()) + self.need_inited_outer[scales_node.var()] = Constant(value=0) + inputs['Iter'] = self.global_step + outputs['OutScales'] = scales_node + attrs = { + 'window_size': self.window_size, + 'bit_length': quant_bits, + 'is_test': self.is_test + } + quant_op_node = graph.create_op_node( + op_type='fake_quantize_range_abs_max', + attrs=attrs, + inputs=inputs, + outputs=outputs) + + self._link_to(var_node, quant_op_node) + self._link_to(scale_in_node, quant_op_node) + self._link_to(quant_op_node, quant_var_node) + self._link_to(quant_op_node, scale_out_node) + + if not self.is_test: + self._link_to(self.global_step, quant_op_node) + self._link_to(quant_op_node, scales_node) + + return quant_var_node, scale_out_node + + def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits): + """ + Insert fake_dequantize_op in the graph. + """ + assert var_node.is_var(), '{} is not a var'.format(var_node.name()) + + dequant_var_node = graph.create_var_node( + name=self._dequantized_var_name(var_node.name()), + var_type=var_node.var().type(), + shape=var_node.var().shape(), + var_dtype=var_node.var().dtype()) + max_range = (1 << (quant_bits - 1)) - 1 + dequant_op_node = graph.create_op_node( + op_type='fake_dequantize_max_abs', + attrs={'max_range': float(max_range)}, + inputs={'X': var_node, + 'Scale': scale_var_node}, + outputs={'Out': dequant_var_node}) + self._link_to(var_node, dequant_op_node) + self._link_to(scale_var_node, dequant_op_node) + self._link_to(dequant_op_node, dequant_var_node) + return dequant_var_node + + def _update_input(self, old_input_node, new_input_node, op_node): + old_input_node.outputs.remove(op_node) + op_node.inputs.remove(old_input_node) + new_input_node.outputs.append(op_node) + op_node.inputs.append(new_input_node) + + def _link_to(node_in, node_out): + node_in.outputs.append(node_out) + node_out.inputs.append(node_in) + + def _quantized_var_name(self, var_name): + """ + Return quantized variable name for the input `var_name`. + """ + return "%s.quantized" % (var_name) + + def _dequantized_var_name(self, var_name): + """ + Return dequantized variable name for the input `var_name`. + """ + return "%s.dequantized" % (var_name) + + def _quantized_scale_name(self, var_name): + """ + Return quantized variable name for the input `var_name`. + """ + return "%s.scale" % (var_name) + + def _original_var_name(self, var_name): + """ + Return the original variable name. + """ + if var_name.endswith('.quantized.dequantized'): + return var_name[:-len('.quantized.dequantized')] + if var_name.endswith('.quantized'): + return var_name[:-len('.quantized')] + if var_name.endswith('.dequantized'): + return var_name[:-len('.dequantized')] + if var_name.endswith('.scale'): + return var_name[:-len('.scale')] + else: + return var_name + + def _is_float(self, v): + return isinstance(v, float) or isinstance(v, np.float32) + + def _quant(self, x, scale, num_bits): + y = np.round(x / scale * ((1 << (num_bits - 1)) - 1)) + return y