diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 756470c5b0bfb2ce2a90531442211bcb19172dae..ce5731a1f414e8ef6d8af22a3bb17109e82beb87 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/sequential_execution_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
 
 namespace paddle {
@@ -243,3 +244,4 @@ USE_PASS(sequential_execution_pass);
 USE_PASS(all_reduce_deps_pass);
 USE_PASS(modify_op_lock_and_record_event_pass);
 USE_PASS(lock_free_optimize_pass);
+USE_PASS(graph_to_program_pass);
diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc
index 6cf405efe63d2bc284c4650771a747b27bb4a9f6..33ccee6aa0a94b8fd8308214d6144ae832d40bab 100644
--- a/paddle/fluid/framework/ir/pass.cc
+++ b/paddle/fluid/framework/ir/pass.cc
@@ -28,10 +28,14 @@ std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
     PADDLE_ENFORCE(graph->Has(attr), "Required graph attribute %s not set.",
                    attr);
   }
+  auto *native_graph = graph.get();
   auto applied_graph = ApplyImpl(std::move(graph));
   // TODO(panyx0718): Add more verifications.
   PADDLE_ENFORCE(!HasCircle(*applied_graph),
                  "Illegal Pass. Generated graph shouldn't have a cycle.");
+  PADDLE_ENFORCE(applied_graph.get() == native_graph,
+                 "Pass::Apply() cannot delete the passed graph and shouldn't "
+                 "return a new graph. (Required because pybind11 keeps "
+                 "holding the original graph object.)");
   applied_ = true;
   return applied_graph;
 }
diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc
index d32fe58f8695a5c14f276ef038416f5c47f3400f..24059140ab20e24917b93a5f60936b1087797ff9 100644
--- a/paddle/fluid/pybind/ir.cc
+++ b/paddle/fluid/pybind/ir.cc
@@ -15,7 +15,9 @@
 #include "paddle/fluid/pybind/ir.h"
 #include
 #include
+#include <unordered_set>
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
@@ -24,6 +26,7 @@
 namespace py = pybind11;
 using paddle::framework::ir::Graph;
 using paddle::framework::ir::Node;
+using paddle::framework::ir::GraphSafeRemoveNodes;
 using paddle::framework::OpDesc;
 using paddle::framework::ProgramDesc;
 using paddle::framework::VarDesc;
@@ -32,6 +35,7 @@ using pybind11::return_value_policy;
 
 namespace paddle {
 namespace pybind {
 void BindGraph(py::module *m) {
+  m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
   py::class_<Graph, std::shared_ptr<Graph>>(
       *m, "Graph",
       "The graph is a Directed Acyclic Single Static Assignment Graph, see "
@@ -42,6 +46,8 @@
       .def("get_float", &Graph::Get<float>)
       .def("get_double", &Graph::Get<double>)
       .def("get_string", &Graph::Get<std::string>)
+      .def("get_program", &Graph::Get<ProgramDesc>)
+      .def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
       .def("set", [](Graph &self, const std::string &attr_name,
                      int attr) { return self.Set(attr_name, new int(attr)); })
       .def("set",
@@ -57,6 +63,17 @@
           [](Graph &self, const std::string &attr_name, double attr) {
             return self.Set(attr_name, new double(attr));
           })
+      .def("set",
+           [](Graph &self, const std::string &attr_name,
+              const ProgramDesc &attr) {
+             return self.Set(attr_name, new ProgramDesc(attr));
+           })
+      .def("set",
+           [](Graph &self, const std::string &attr_name,
+              const std::unordered_set<const Node *> &attr) {
+             return self.Set(attr_name,
+                             new std::unordered_set<const Node *>(attr));
+           })
       .def("erase", &Graph::Erase)
       .def("nodes", &Graph::Nodes, return_value_policy::reference)
       .def("create_var_node",
@@ -85,12 +102,52 @@ void BindNode(py::module *m) {
   py::class_<Node> node(*m, "Node");
   node.def("name", &Node::Name)
       .def("node_type", &Node::NodeType)
-      .def("var", &Node::Var)
-      .def("op", &Node::Op)
+      .def("var", &Node::Var, return_value_policy::reference)
+      .def("op", &Node::Op, return_value_policy::reference)
       .def("id", &Node::id)
       .def("is_op", &Node::IsOp)
       .def("is_var", &Node::IsVar)
       .def("is_ctrl_var", &Node::IsCtrlVar)
+      .def("inputs_remove",
+           [](Node &self, int node_id) {
+             // erase() invalidates the iterator, so stop after the first
+             // match instead of continuing to advance it.
+             for (auto it = self.inputs.begin(); it != self.inputs.end();
+                  ++it) {
+               if ((*it)->id() == node_id) {
+                 self.inputs.erase(it);
+                 break;
+               }
+             }
+           })
+      .def("inputs_remove",
+           [](Node &self, Node &node) {
+             for (auto it = self.inputs.begin(); it != self.inputs.end();
+                  ++it) {
+               if (*it == &node) {
+                 self.inputs.erase(it);
+                 break;
+               }
+             }
+           })
+      .def("inputs_append",
+           [](Node &self, Node &node) { self.inputs.push_back(&node); })
+      .def("outputs_remove",
+           [](Node &self, int node_id) {
+             for (auto it = self.outputs.begin(); it != self.outputs.end();
+                  ++it) {
+               if ((*it)->id() == node_id) {
+                 self.outputs.erase(it);
+                 break;
+               }
+             }
+           })
+      .def("outputs_remove",
+           [](Node &self, Node &node) {
+             for (auto it = self.outputs.begin(); it != self.outputs.end();
+                  ++it) {
+               if (*it == &node) {
+                 self.outputs.erase(it);
+                 break;
+               }
+             }
+           })
+      .def("outputs_append",
+           [](Node &self, Node &node) { self.outputs.push_back(&node); })
       .def_readwrite("inputs", &Node::inputs)
       .def_readwrite("outputs", &Node::outputs);
diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc
index 4b218fb3a2af0933ea1e87abe20e7e031c32f721..e729be4a95a58510f1e0162af4216feaa400d971 100644
--- a/paddle/fluid/pybind/protobuf.cc
+++ b/paddle/fluid/pybind/protobuf.cc
@@ -228,7 +228,7 @@ void BindBlockDesc(pybind11::module *m) {
 
 void BindVarDsec(pybind11::module *m) {
   pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
-  var_desc
+  var_desc.def(pybind11::init<const std::string &>())
       .def("name", &pd::VarDesc::Name,
            pybind11::return_value_policy::reference)
       .def("set_name", &pd::VarDesc::SetName)
       .def("set_shape", &pd::VarDesc::SetShape)
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index b086c218988b09955a06743325d00f9bd57f7d90..c470483756659b55329c022e0c43002182db815b 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -788,21 +788,33 @@ All parameter, weight, gradient are variables in Paddle.
m.def("disable_profiler", platform::DisableProfiler); m.def("is_profiler_enabled", platform::IsProfileEnabled); m.def("reset_profiler", platform::ResetProfiler); + m.def("get_pass", [](const py::bytes &binary_str) { + std::string pass_type(binary_str); + auto pass = framework::ir::PassRegistry::Instance().Get(pass_type); + return std::shared_ptr(std::move(pass)); + }); py::class_> pass(m, "Pass"); pass.def(py::init()) + .def("has", &ir::Pass::Has) + .def("set", + [](ir::Pass &self, const std::string &attr_name, + const ProgramDesc &attr) { + return self.Set(attr_name, new ProgramDesc(attr)); + }) .def( - "set_str", + "set", [](ir::Pass &self, const std::string &name, const std::string &attr) { self.Set(name, new std::string(attr)); }) - .def("set_int", [](ir::Pass &self, const std::string &name, - int val) { self.Set(name, new int(val)); }) + .def("set", [](ir::Pass &self, const std::string &name, + int val) { self.Set(name, new int(val)); }) + .def("get_program", &ir::Pass::Get) .def("type", &ir::Pass::Type) .def("apply", [](ir::Pass &self, std::shared_ptr graph) { std::unique_ptr origin_graph(graph.get()); auto optim_graph = self.Apply(std::move(origin_graph)); - graph.reset(optim_graph.release()); + optim_graph.release(); }); py::class_> pb( diff --git a/python/paddle/fluid/contrib/slim/graph/graph.py b/python/paddle/fluid/contrib/slim/graph/graph.py index 7d6b0702035d49189c0919f976ea3c0c52663547..f38d9783413a01cd1005a014c0aba5ecf5cc79c2 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph.py +++ b/python/paddle/fluid/contrib/slim/graph/graph.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from __future__ import print_function +import os +import subprocess from ....framework import Program +from ....framework import Block +from .... import core __all__ = ['Graph', 'ImitationGraph', 'IRGraph'] diff --git a/python/paddle/fluid/contrib/slim/quantization/__init__.py b/python/paddle/fluid/contrib/slim/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c26475f48855674d97abf5778a631646734fcf8 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from . import quantization_pass +from .quantization_pass import * + +__all__ = quantization_pass.__all__ diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..266a106bc507104c0a8db1c882b55ac59e88195e --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -0,0 +1,318 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import six
+from .... import core
+from ....framework import IrGraph
+from ....framework import Program
+from ....framework import Variable
+from ....initializer import Constant
+from .... import unique_name
+
+__all__ = ['QuantizationTransformPass']
+
+
+class QuantizationTransformPass(object):
+    def __init__(self,
+                 scope=None,
+                 program_exe=None,
+                 weight_bits=8,
+                 activation_bits=8,
+                 activation_quantize_type='abs_max',
+                 weight_quantize_type='abs_max',
+                 window_size=10000):
+        """
+        Convert and rewrite the IrGraph according to the weight and
+        activation quantization types.
+        Args:
+            scope(fluid.Scope): the scope holding the variables this pass
+                initializes; required for 'range_abs_max'.
+            program_exe(fluid.Executor): the executor used to initialize
+                those variables; required for 'range_abs_max'.
+            weight_bits (int): quantization bit number for weights;
+                the bias is not quantized.
+            activation_bits (int): quantization bit number for activations.
+            activation_quantize_type (str): quantization type for
+                activations; 'abs_max' and 'range_abs_max' are supported.
+                With 'abs_max', the quantization scale is computed
+                dynamically at each step, during both training and testing.
+                With 'range_abs_max', a static quantization scale is
+                accumulated during training and used in inference.
+            weight_quantize_type (str): quantization type for weights; only
+                'abs_max' is supported. 'range_abs_max' is usually not used
+                for weights, since weights are fixed once the model is well
+                trained.
+            window_size (int): the window size for 'range_abs_max'
+                quantization.
+        Examples:
+        .. code-block:: python
+            # The original graph will be rewritten.
+            import paddle.fluid as fluid
+            from paddle.fluid.contrib.slim.quantization \
+                import QuantizationTransformPass
+            from paddle.fluid.contrib.slim.graph import IrGraph
+            from paddle.fluid import core
+
+            graph = IrGraph(core.Graph(program.desc), for_test=False)
+            exe = fluid.Executor(fluid.CPUPlace())
+            transform_pass = QuantizationTransformPass(fluid.global_scope(),
+                                                       exe)
+            transform_pass.apply(graph)
+        """
+        self._scope = scope
+        self._program_exe = program_exe
+        self._weight_bits = weight_bits
+        self._activation_bits = activation_bits
+
+        quant_type = ['abs_max', 'range_abs_max']
+        if activation_quantize_type not in quant_type:
+            raise ValueError(
+                "Unknown activation_quantize_type: '%s'. It can only be "
+                "'abs_max' or 'range_abs_max'." % activation_quantize_type)
+        if weight_quantize_type not in quant_type:
+            raise ValueError(
+                "Unknown weight_quantize_type: '%s'. It can only be "
+                "'abs_max' or 'range_abs_max'." % weight_quantize_type)
+
+        self._activation_quantize_type = activation_quantize_type
+        self._weight_quantize_type = weight_quantize_type
+        self._window_size = window_size
+
+        self._need_initialized = collections.OrderedDict()
+        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
+        self._quantizable_grad_ops = [
+            '%s_grad' % (op) for op in self._quantizable_ops
+        ]
+        self._fake_quant_op_types = [
+            'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
+        ]
+        self._fake_dequant_op_types = ['fake_dequantize_max_abs']
+        self._is_test = None
+        self._global_step = None
+
+    def apply(self, graph):
+        assert isinstance(graph,
+                          IrGraph), 'graph must be the instance of IrGraph.'
+        self._need_initialized.clear()
+        self._is_test = graph.is_test()
+        # Maps the name of each quantized input variable to its dequantized
+        # version, so a variable feeding several ops is only rewritten once.
+        dequantized_vars = collections.OrderedDict()
+        params = [p.name() for p in graph.all_parameters()]
+
+        def _transform_forward(graph, op):
+            for var_node in op.inputs:
+                if var_node.name() in dequantized_vars:
+                    dequant_var_node = dequantized_vars[var_node.name()]
+                else:
+                    quant_bits = self._weight_bits if var_node.name() in params \
+                        else self._activation_bits
+                    quant_type = self._weight_quantize_type if var_node.name() \
+                        in params else self._activation_quantize_type
+                    quant_var_node, scale_var_node = self._insert_quant_op(
+                        graph, var_node, quant_bits, quant_type)
+                    dequant_var_node = self._insert_dequant_op(
+                        graph, quant_var_node, scale_var_node, quant_bits)
+                    dequantized_vars[var_node.name()] = dequant_var_node
+                graph.update_input_link(var_node, dequant_var_node, op)
+
+        def _transform_backward(graph, op):
+            no_dequantized_input_vars = True
+            for var_node in op.inputs:
+                if var_node.name() in dequantized_vars:
+                    dequant_var_node = dequantized_vars[var_node.name()]
+                    graph.update_input_link(var_node, dequant_var_node, op)
+                    no_dequantized_input_vars = False
+            if no_dequantized_input_vars:
+                raise ValueError("There are no dequantized inputs for op %s."
+                                 % (op.name()))
+
+        if not self._is_test:
+            self._create_global_step(graph)
+        ops = graph.all_ops()
+        # _transform_forward and _transform_backward must run in two separate
+        # loops: the backward rewrite relies on the dequantized variables
+        # collected while transforming the whole forward graph first.
+        # The loop for transforming the forward graph:
+        for op in ops:
+            if op.name() in self._quantizable_ops:
+                _transform_forward(graph, op)
+        # The loop for renaming the inputs of backward ops:
+        for op in ops:
+            if op.name() in self._quantizable_grad_ops:
+                _transform_backward(graph, op)
+
+        if len(self._need_initialized) > 0:
+            assert self._scope is not None, \
+                'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
+            assert self._program_exe is not None, \
+                'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.'
+            init_program = Program()
+            for var_desc, initializer in six.iteritems(self._need_initialized):
+                var = Variable(init_program.global_block())
+                var._set_desc(var_desc)
+                initializer(var, init_program.global_block())
+            self._program_exe.run(program=init_program, scope=self._scope)
+
+        return graph
+
+    def _create_global_step(self, graph):
+        if self._weight_quantize_type == 'range_abs_max' or \
+                self._activation_quantize_type == 'range_abs_max':
+            counter_name = '@STEP_COUNTER@'
+            for node in graph.all_vars():
+                if node.name() == counter_name:
+                    self._global_step = node
+            if self._global_step is None:
+                global_step_in = graph.create_param_node(
+                    name=counter_name,
+                    var_type=core.VarDesc.VarType.LOD_TENSOR,
+                    shape=[1],
+                    var_dtype=core.VarDesc.VarType.INT64)
+                self._need_initialized[global_step_in.var()] = \
+                    Constant(value=0, force_cpu=True)
+                global_step_out = graph.create_var_node_from_desc(
+                    global_step_in.var())
+                increment_op = graph.create_op_node(
+                    op_type='increment',
+                    attrs={'step': 1.0},
+                    inputs={'X': global_step_in},
+                    outputs={'Out': global_step_out})
+                graph.link_to(global_step_in, increment_op)
+                graph.link_to(increment_op, global_step_out)
+                self._global_step = global_step_out
+
+    def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
+        """
+        Insert a fake_quantize op in the graph.
+        """
+        if quant_type == 'abs_max':
+            return self._insert_quant_abs_max_op(graph, var_node, quant_bits)
+        elif quant_type == 'range_abs_max':
+            return self._insert_quant_range_abs_max_op(graph, var_node,
+                                                       quant_bits)
+
+    def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
+        """
+        Insert a fake_quantize_abs_max op in the graph.
+        """
+        assert var_node.is_var(), '{} is not a var'.format(var_node.name())
+
+        quant_var_node = graph.create_var_node(
+            name=self._quantized_var_name(var_node.name()),
+            var_type=var_node.var().type(),
+            shape=var_node.var().shape(),
+            var_dtype=var_node.var().dtype())
+        scale_var_node = graph.create_var_node(
+            name=self._quantized_scale_name(var_node.name()),
+            var_type=var_node.var().type(),
+            shape=var_node.var().shape(),
+            var_dtype=var_node.var().dtype())
+        quant_op_node = graph.create_op_node(
+            op_type='fake_quantize_abs_max',
+            attrs={'bit_length': quant_bits},
+            inputs={'X': var_node},
+            outputs={'Out': quant_var_node,
+                     'OutScale': scale_var_node})
+        graph.link_to(var_node, quant_op_node)
+        graph.link_to(quant_op_node, quant_var_node)
+        graph.link_to(quant_op_node, scale_var_node)
+        return quant_var_node, scale_var_node
+
+    def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
+        """
+        Insert a fake_quantize_range_abs_max op in the graph.
+        """
+        assert var_node.is_var(), '{} is not a var'.format(var_node.name())
+
+        quant_var_node = graph.create_var_node(
+            name=self._quantized_var_name(var_node.name()),
+            var_type=var_node.var().type(),
+            shape=var_node.var().shape(),
+            var_dtype=var_node.var().dtype())
+
+        scale_in_node = graph.create_param_node(
+            name=self._quantized_scale_name(var_node.name()),
+            var_type=core.VarDesc.VarType.LOD_TENSOR,
+            shape=[1],
+            var_dtype=var_node.var().dtype())
+        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
+
+        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
+        inputs = {'X': var_node, 'InScale': scale_in_node}
+        outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}
+
+        if not self._is_test:
+            # The name of the scales variable may be 'scales_0', 'scales_1',
+            # etc.
+            scales_node = graph.create_param_node(
+                name=unique_name.generate('scales'),
+                var_type=core.VarDesc.VarType.LOD_TENSOR,
+                shape=[self._window_size],
+                var_dtype=var_node.var().dtype())
+            self._need_initialized[scales_node.var()] = Constant(value=0)
+            inputs['Iter'] = self._global_step
+            outputs['OutScales'] = scales_node
+        attrs = {
+            'window_size': self._window_size,
+            'bit_length': quant_bits,
+            'is_test': self._is_test
+        }
+        quant_op_node = graph.create_op_node(
+            op_type='fake_quantize_range_abs_max',
+            attrs=attrs,
+            inputs=inputs,
+            outputs=outputs)
+
+        graph.link_to(var_node, quant_op_node)
+        graph.link_to(scale_in_node, quant_op_node)
+        graph.link_to(quant_op_node, quant_var_node)
+        graph.link_to(quant_op_node, scale_out_node)
+
+        if not self._is_test:
+            graph.link_to(self._global_step, quant_op_node)
+            graph.link_to(quant_op_node, scales_node)
+
+        return quant_var_node, scale_out_node
+
+    def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
+        """
+        Insert a fake_dequantize op in the graph.
+        """
+        assert var_node.is_var(), '{} is not a var'.format(var_node.name())
+
+        dequant_var_node = graph.create_var_node(
+            name=self._dequantized_var_name(var_node.name()),
+            var_type=var_node.var().type(),
+            shape=var_node.var().shape(),
+            var_dtype=var_node.var().dtype())
+        max_range = (1 << (quant_bits - 1)) - 1
+        dequant_op_node = graph.create_op_node(
+            op_type='fake_dequantize_max_abs',
+            attrs={'max_range': float(max_range)},
+            inputs={'X': var_node,
+                    'Scale': scale_var_node},
+            outputs={'Out': dequant_var_node})
+        graph.link_to(var_node, dequant_op_node)
+        graph.link_to(scale_var_node, dequant_op_node)
+        graph.link_to(dequant_op_node, dequant_var_node)
+        return dequant_var_node
+
+    def _quantized_var_name(self, var_name):
+        """
+        Return the quantized variable name for the input `var_name`.
+        """
+        return "%s.quantized" % (var_name)
+
+    def _dequantized_var_name(self, var_name):
+        """
+        Return the dequantized variable name for the input `var_name`.
+        """
+        return "%s.dequantized" % (var_name)
+
+    def _quantized_scale_name(self, var_name):
+        """
+        Return the scale name of the quantized variable for the input
+        `var_name`.
+        """
+        return "%s.scale" % (var_name)
diff --git a/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd4b95d6b90b7f16d507061190f0b463f6c4cc5
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py
@@ -0,0 +1,175 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
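+
+# The tests below build a small network (a stack of FC layers or a residual
+# block), wrap its ProgramDesc in an IrGraph, apply
+# QuantizationTransformPass, and check that every quantizable op reads
+# '.quantized.dequantized' inputs in both the forward and backward graphs.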
+
+import unittest
+import random
+import numpy as np
+import paddle.fluid as fluid
+import six
+from paddle.fluid.framework import Program
+from paddle.fluid.framework import IrGraph
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
+from paddle.fluid import core
+
+
+def linear_fc(num):
+    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    hidden = data
+    for _ in six.moves.xrange(num):
+        hidden = fluid.layers.fc(hidden, size=128, act='relu')
+    loss = fluid.layers.cross_entropy(input=hidden, label=label)
+    loss = fluid.layers.mean(loss)
+    return loss
+
+
+def residual_block(num):
+    def conv_bn_layer(input,
+                      ch_out,
+                      filter_size,
+                      stride,
+                      padding,
+                      act='relu',
+                      bias_attr=False):
+        tmp = fluid.layers.conv2d(
+            input=input,
+            filter_size=filter_size,
+            num_filters=ch_out,
+            stride=stride,
+            padding=padding,
+            act=None,
+            bias_attr=bias_attr)
+        return fluid.layers.batch_norm(input=tmp, act=act)
+
+    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    hidden = data
+    for _ in six.moves.xrange(num):
+        conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
+        short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
+        hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
+    fc = fluid.layers.fc(input=hidden, size=10)
+    loss = fluid.layers.cross_entropy(input=fc, label=label)
+    loss = fluid.layers.mean(loss)
+    return loss
+
+
+class TestQuantizationTransformPass(unittest.TestCase):
+    def setUp(self):
+        self.quantizable_op_and_inputs = {
+            'conv2d': ['Input', 'Filter'],
+            'depthwise_conv2d': ['Input', 'Filter'],
+            'mul': ['X', 'Y']
+        }
+        self.quantizable_grad_op_inputs = {
+            'conv2d_grad': ['Input', 'Filter'],
+            'depthwise_conv2d_grad': ['Input', 'Filter'],
+            'mul_grad': ['X', 'Y']
+        }
+
+    def check_program(self, transform_pass, program):
+        quantized_ops = set()
+        for block in program.blocks:
+            for op in block.ops:
+                # check forward
+                if op.type in self.quantizable_op_and_inputs:
+                    for arg_name in op.input_arg_names:
+                        self.assertTrue(
+                            arg_name.endswith('.quantized.dequantized'))
+                        quantized_ops.add(arg_name)
+
+            for op in block.ops:
+                # check backward
+                if op.type in self.quantizable_grad_op_inputs:
+                    for pname in self.quantizable_grad_op_inputs[op.type]:
+                        arg_name = op.input(pname)[0]
+                        self.assertTrue(
+                            arg_name.endswith('.quantized.dequantized'))
+                        self.assertTrue(arg_name in quantized_ops)
+
+    def linear_fc_quant(self, quant_type):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            loss = linear_fc(3)
+            opt = fluid.optimizer.Adam(learning_rate=0.001)
+            opt.minimize(loss)
+        exe = fluid.Executor(fluid.CPUPlace())
+        graph = IrGraph(core.Graph(main.desc), for_test=False)
+        transform_pass = QuantizationTransformPass(
+            scope=fluid.global_scope(),
+            program_exe=exe,
+            activation_quantize_type=quant_type)
+        transform_pass.apply(graph)
+        marked_nodes = set()
+        for op in graph.all_ops():
+            if op.name().find('quantize') > -1:
+                marked_nodes.add(op)
+        graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
+        program = graph.to_program()
+        self.check_program(transform_pass, program)
+        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
+        val_marked_nodes = set()
+        for op in val_graph.all_ops():
+            if op.name().find('quantize') > -1:
+                val_marked_nodes.add(op)
+        val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
+
+    def test_linear_fc_quant_abs_max(self):
+        self.act_quant_op_type = 'fake_quantize_abs_max'
+        self.linear_fc_quant('abs_max')
+
+    def test_linear_fc_quant_range_abs_max(self):
+        self.act_quant_op_type = 'fake_quantize_range_abs_max'
+        self.linear_fc_quant('range_abs_max')
+
+    def residual_block_quant(self, quant_type):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            loss = residual_block(2)
+            opt = fluid.optimizer.Adam(learning_rate=0.001)
+            opt.minimize(loss)
+        exe = fluid.Executor(fluid.CPUPlace())
+        graph = IrGraph(core.Graph(main.desc), for_test=False)
+        transform_pass = QuantizationTransformPass(
+            scope=fluid.global_scope(),
+            program_exe=exe,
+            activation_quantize_type=quant_type)
+        transform_pass.apply(graph)
+        marked_nodes = set()
+        for op in graph.all_ops():
+            if op.name().find('quantize') > -1:
+                marked_nodes.add(op)
+        graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
+        program = graph.to_program()
+        self.check_program(transform_pass, program)
+        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
+        val_marked_nodes = set()
+        for op in val_graph.all_ops():
+            if op.name().find('quantize') > -1:
+                val_marked_nodes.add(op)
+        val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
+
+    def test_residual_block_abs_max(self):
+        self.act_quant_op_type = 'fake_quantize_abs_max'
+        self.residual_block_quant('abs_max')
+
+    def test_residual_block_range_abs_max(self):
+        self.act_quant_op_type = 'fake_quantize_range_abs_max'
+        self.residual_block_quant('range_abs_max')
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 77239f2e8ee5243611352d275c0d595670e24185..fc5e471ae3041b0245c873d85efa8b49f3a43678 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -23,6 +23,7 @@
 import traceback
 import six
 import numpy as np
+import subprocess
 
 from .. import compat as cpt
 from .proto import framework_pb2
@@ -1512,6 +1513,154 @@ class Block(object):
         return ret_var
 
 
+class IrGraph(object):
+    """
+    IrGraph delegates graph manipulation to the underlying core.Graph.
+    """
+
+    def __init__(self, graph, for_test=False):
+        """
+        Construct an IrGraph from a core.Graph.
+        Args:
+            graph(core.Graph): the C++ Graph.
+            for_test(bool): True for a test graph, False for a train graph.
+        """
+        assert isinstance(
+            graph, core.Graph), 'graph must be the instance of core.Graph.'
+        self.graph = graph
+        self._for_test = for_test
+
+    def is_test(self):
+        return self._for_test
+
+    def all_parameters(self):
+        param_nodes = set()
+        for node in self.graph.nodes():
+            if node.is_var() and node.var() is not None and node.var(
+            ).persistable():
+                param_nodes.add(node)
+        return param_nodes
+
+    def all_vars(self):
+        return {node for node in self.graph.nodes() if node.is_var()}
+
+    def all_ops(self):
+        return {node for node in self.graph.nodes() if node.is_op()}
+
+    def create_param_node(self, name, var_type, shape, var_dtype):
+        var_desc = core.VarDesc(name)
+        var_desc.set_type(var_type)
+        var_desc.set_shape(shape)
+        var_desc.set_dtype(var_dtype)
+        var_desc.set_persistable(True)
+        return self.graph.create_var_node(var_desc)
+
+    def create_var_node(self, name, var_type, shape, var_dtype):
+        var_desc = core.VarDesc(name)
+        var_desc.set_type(var_type)
+        var_desc.set_shape(shape)
+        var_desc.set_dtype(var_dtype)
+        return self.graph.create_var_node(var_desc)
+
+    def create_var_node_from_desc(self, var_desc):
+        return self.graph.create_var_node(var_desc)
+
+    def create_op_node(self, op_type, attrs, inputs, outputs):
+        op_desc = core.OpDesc()
+        op_desc.set_type(op_type)
+        # dict.iteritems() is Python 2 only; six.iteritems works on both.
+        for attr, value in six.iteritems(attrs):
+            self._update_desc_attr(op_desc, attr, value)
+        for input_name, var_nodes in six.iteritems(inputs):
+            if not isinstance(var_nodes, list):
+                var_nodes = [var_nodes]
+            op_desc.set_input(input_name,
+                              [var_node.name() for var_node in var_nodes])
+        for output_name, var_nodes in six.iteritems(outputs):
+            if not isinstance(var_nodes, list):
+                var_nodes = [var_nodes]
+            op_desc.set_output(output_name,
+                               [var_node.name() for var_node in var_nodes])
+        return self.graph.create_op_node(op_desc)
+
+    def create_op_node_from_desc(self, op_desc):
+        return self.graph.create_op_node(op_desc)
+
+    def update_input_link(self, old_input_node, new_input_node, op_node):
+        assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \
+            op_node in self.graph.nodes(), 'The three arguments must be in the graph nodes.'
+        old_input_node.outputs_remove(op_node)
+        op_node.inputs_remove(old_input_node)
+        new_input_node.outputs_append(op_node)
+        op_node.inputs_append(new_input_node)
+        op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
+
+    def link_to(self, node_in, node_out):
+        assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
+            'The two arguments must be in the graph nodes.'
+        node_in.outputs_append(node_out)
+        node_out.inputs_append(node_in)
+
+    def safe_remove_nodes(self, remove_nodes):
+        if not isinstance(remove_nodes, set):
+            remove_nodes = set(remove_nodes)
+        core.graph_safe_remove_nodes(self.graph, remove_nodes)
+
+    def draw(self, save_path, name, marked_nodes=None):
+        def _convert_to_pdf(dot_file_path):
+            pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
+            exit_code = subprocess.call(
+                'dot -Tpdf ' + dot_file_path + ' -o ' + pdf_save_path,
+                shell=True)
+            if exit_code != 0:
+                print('The dot command is needed for creating pdf files.')
+                print('{} is saved in dot format instead.'.format(
+                    dot_file_path))
+
+        remove_ctr_vars = set()
+        ops_num = 0
+        for node in self.graph.nodes():
+            if node.is_ctrl_var():
+                remove_ctr_vars.add(node)
+            elif node.is_op():
+                ops_num += 1
+        print('Total ops num = {}.'.format(ops_num))
+        self.safe_remove_nodes(remove_ctr_vars)
+        if marked_nodes is not None:
+            if not isinstance(marked_nodes, set):
+                marked_nodes = set(marked_nodes)
+            marked_nodes = marked_nodes - remove_ctr_vars
+            if self.graph.has('__graphviz__marked_node__'):
+                self.graph.erase('__graphviz__marked_node__')
+            self.graph.set('__graphviz__marked_node__', marked_nodes)
+        viz_dot_path = os.path.join(save_path, name) + '.dot'
+        viz_pass = core.get_pass('graph_viz_pass')
+        viz_pass.set('graph_viz_path', viz_dot_path)
+        viz_pass.apply(self.graph)
+        _convert_to_pdf(viz_dot_path)
+
+    def to_program(self):
+        convert_pass = core.get_pass('graph_to_program_pass')
+        # graph_to_program_pass fills the ProgramDesc handed in through the
+        # 'program' attribute and exposes the result via get_program().
+        convert_pass.set('program', Program().desc)
+        convert_pass.apply(self.graph)
+        desc = convert_pass.get_program('program')
+        program = Program._construct_from_desc(desc)
+        return program
+
+    def _update_desc_attr(self, desc, name, val):
+        """
+        Update the value of the desc's attribute by the attribute's name.
+        """
+        if isinstance(val, Block):
+            desc.set_block_attr(name, val.desc)
+        elif isinstance(val, list) and val and all(
+                isinstance(v, Block) for v in val):
+            desc.set_blocks_attr(name, [v.desc for v in val])
+        elif isinstance(val, core.BlockDesc) or \
+                isinstance(val, core.ProgramDesc):
+            desc.set_serialized_attr(name, val.serialize_to_string())
+        else:
+            desc._set_attr(name, val)
+
+
 class Program(object):
     """
     Python Program. Beneath it is a ProgramDesc, which is used for
@@ -1936,6 +2085,23 @@ class Program(object):
         p._sync_with_cpp()
         return p
 
+    @staticmethod
+    def _construct_from_desc(desc):
+        """
+        Construct a program from a program desc.
+
+        Args:
+            desc(core.ProgramDesc): The program desc for constructing.
+
+        Returns:
+            Program: A program.
+ """ + p = Program() + p.desc = desc + p.blocks = [Block(p, i) for i in six.moves.range(p.desc.num_blocks())] + p._sync_with_cpp() + return p + @property def random_seed(self): """ diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 3fcdc57906c214bdc8179c55b576e2e9e8d80973..69a38618cde7f100849a30e1cb1a2f4738e55e0d 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -123,7 +123,7 @@ class TestDistRunnerBase(object): pass_builder = build_stra._finalize_strategy_and_create_passes() mypass = pass_builder.insert_pass( len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass") - mypass.set_int("num_repeats", args.batch_merge_repeat) + mypass.set("num_repeats", args.batch_merge_repeat) if args.update_method == "nccl2": build_stra.num_trainers = len(args.endpoints.split(",")) diff --git a/python/paddle/fluid/tests/unittests/test_pass_builder.py b/python/paddle/fluid/tests/unittests/test_pass_builder.py index 8c9e489e02839e25cfabe14c16bfd91a908bd734..7e1c2572f08598b8b600517e4a82b48ca71cc20d 100644 --- a/python/paddle/fluid/tests/unittests/test_pass_builder.py +++ b/python/paddle/fluid/tests/unittests/test_pass_builder.py @@ -111,7 +111,7 @@ class TestPassBuilder(unittest.TestCase): pass_builder.remove_pass(len(pass_builder.all_passes()) - 1) self.assertEqual(origin_len + 1, len(pass_builder.all_passes())) - viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass") + viz_pass.set("graph_viz_path", "/tmp/test_viz_pass") self.check_network_convergence( use_cuda=core.is_compiled_with_cuda(), diff --git a/python/setup.py.in b/python/setup.py.in index e00c88b3a6e49bd806bcc6125fa8dc0f69928227..fb4b273a0676fcbcb4402eaf54ddf73d37a2754f 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -113,6 +113,7 @@ packages=['paddle', 'paddle.fluid.contrib.slim.core', 'paddle.fluid.contrib.slim.graph', 'paddle.fluid.contrib.slim.prune', + 'paddle.fluid.contrib.slim.quantization', 'paddle.fluid.contrib.utils', 'paddle.fluid.transpiler', 'paddle.fluid.transpiler.details']