未验证 提交 58727e8e 编写于 作者: Z Zhen Wang 提交者: GitHub

Merge pull request #15455 from wzzju/graph_quantization

Graph quantization pass. TODO(Add public API comments.)
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle {
......@@ -243,3 +244,4 @@ USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(graph_to_program_pass);
......@@ -28,10 +28,14 @@ std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.",
attr);
}
auto* native_graph = graph.get();
auto applied_graph = ApplyImpl(std::move(graph));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*applied_graph),
"Illegal Pass. Generated graph shouldn't has cycle.");
PADDLE_ENFORCE(applied_graph.get() == native_graph,
"Pass::Apply() cannot delete the passed graph and shouldn't "
"return a new graph.(For the need of pybind11)");
applied_ = true;
return applied_graph;
}
......
......@@ -15,7 +15,9 @@
#include "paddle/fluid/pybind/ir.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
......@@ -24,6 +26,7 @@
namespace py = pybind11;
using paddle::framework::ir::Graph;
using paddle::framework::ir::Node;
using paddle::framework::ir::GraphSafeRemoveNodes;
using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc;
using paddle::framework::VarDesc;
......@@ -32,6 +35,7 @@ using pybind11::return_value_policy;
namespace paddle {
namespace pybind {
void BindGraph(py::module *m) {
m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
py::class_<Graph, std::shared_ptr<Graph>>(
*m, "Graph",
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
......@@ -42,6 +46,8 @@ void BindGraph(py::module *m) {
.def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>)
.def("get_program", &Graph::Get<ProgramDesc>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
.def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); })
.def("set",
......@@ -57,6 +63,17 @@ void BindGraph(py::module *m) {
[](Graph &self, const std::string &attr_name, double attr) {
return self.Set(attr_name, new double(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name,
const std::unordered_set<const Node *> &attr) {
return self.Set(attr_name,
new std::unordered_set<const Node *>(attr));
})
.def("erase", &Graph::Erase)
.def("nodes", &Graph::Nodes, return_value_policy::reference)
.def("create_var_node",
......@@ -85,12 +102,52 @@ void BindNode(py::module *m) {
py::class_<Node> node(*m, "Node");
node.def("name", &Node::Name)
.def("node_type", &Node::NodeType)
.def("var", &Node::Var)
.def("op", &Node::Op)
.def("var", &Node::Var, return_value_policy::reference)
.def("op", &Node::Op, return_value_policy::reference)
.def("id", &Node::id)
.def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar)
.def("inputs_remove",
[](Node &self, int node_id) {
for (auto it = self.inputs.begin(); it != self.inputs.end();
it++) {
if ((*it)->id() == node_id) {
self.inputs.erase(it);
}
}
})
.def("inputs_remove",
[](Node &self, Node &node) {
for (auto it = self.inputs.begin(); it != self.inputs.end();
it++) {
if (*it == &node) {
self.inputs.erase(it);
}
}
})
.def("inputs_append",
[](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("outputs_remove",
[](Node &self, int node_id) {
for (auto it = self.outputs.begin(); it != self.outputs.end();
it++) {
if ((*it)->id() == node_id) {
self.outputs.erase(it);
}
}
})
.def("outputs_remove",
[](Node &self, Node &node) {
for (auto it = self.outputs.begin(); it != self.outputs.end();
it++) {
if (*it == &node) {
self.outputs.erase(it);
}
}
})
.def("outputs_append",
[](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs);
......
......@@ -228,7 +228,7 @@ void BindBlockDesc(pybind11::module *m) {
void BindVarDsec(pybind11::module *m) {
pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
var_desc
var_desc.def(pybind11::init<const std::string &>())
.def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
.def("set_name", &pd::VarDesc::SetName)
.def("set_shape", &pd::VarDesc::SetShape)
......
......@@ -788,21 +788,33 @@ All parameter, weight, gradient are variables in Paddle.
m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler);
m.def("get_pass", [](const py::bytes &binary_str) {
std::string pass_type(binary_str);
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
return std::shared_ptr<framework::ir::Pass>(std::move(pass));
});
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("has", &ir::Pass::Has)
.def("set",
[](ir::Pass &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
})
.def(
"set_str",
"set",
[](ir::Pass &self, const std::string &name, const std::string &attr) {
self.Set<std::string>(name, new std::string(attr));
})
.def("set_int", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("get_program", &ir::Pass::Get<ProgramDesc>)
.def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
std::unique_ptr<ir::Graph> origin_graph(graph.get());
auto optim_graph = self.Apply(std::move(origin_graph));
graph.reset(optim_graph.release());
optim_graph.release();
});
py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(
......
......@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import subprocess
from ....framework import Program
from ....framework import Block
from .... import core
__all__ = ['Graph', 'ImitationGraph', 'IRGraph']
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import quantization_pass
from .quantization_pass import *
__all__ = quantization_pass.__all__
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from .... import core
from ....framework import IrGraph
from ....framework import Program
from ....framework import Variable
from ....initializer import Constant
from .... import unique_name
__all__ = ['QuantizationTransformPass']
class QuantizationTransformPass(object):
def __init__(self,
scope=None,
program_exe=None,
weight_bits=8,
activation_bits=8,
activation_quantize_type='abs_max',
weight_quantize_type='abs_max',
window_size=10000):
"""
Convert and rewrite the IrGraph according to weight and
activation quantization type.
Args:
weight_bits (int): quantization bit number for weights,
the bias is not quantized.
activation_bits (int): quantization bit number for activation.
activation_quantize_type (str): quantization type for activation,
now support 'abs_max', 'range_abs_max'. If use 'abs_max' mode,
the quantization scale will be calculated dynamically each step
in both training and testing period. If use 'range_abs_max',
a static quantization scale will be calculated during training
and used in inference.
weight_quantize_type (str): quantization type for weights,
support 'abs_max'. The 'range_abs_max' usually is not used for
weight, since weights are fixed once the model is well trained.
window_size (int): the window size for 'range_abs_max' quantization.
Examples:
.. code-block:: python
# The original graph will be rewrite.
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization \
import QuantizationTransformPass
from paddle.fluid.contrib.slim.graph import IrGraph
from paddle.fluid import core
graph = IrGraph(core.Graph(program.desc), for_test=False)
exe = fluid.Executor(fluid.CPUPlace())
transform_pass = QuantizationTransformPass(fluid.global_scope(),
exe)
transform_pass.apply(graph)
"""
self._scope = scope
self._program_exe = program_exe
self._weight_bits = weight_bits
self._activation_bits = activation_bits
quant_type = ['abs_max', 'range_abs_max']
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be ",
"'abs_max' or 'range_abs_max'.", str(activation_quantize_type))
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be ",
"'abs_max' or 'range_abs_max'.", str(weight_quantize_type))
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
self._window_size = window_size
self._need_initialized = collections.OrderedDict()
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._quantizable_grad_ops = [
'%s_grad' % (op) for op in self._quantizable_ops
]
self._fake_quant_op_types = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
]
self._fake_dequant_op_types = ['fake_dequantize_max_abs']
self._is_test = None
self._global_step = None
def apply(self, graph):
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
self._need_initialized.clear()
self._is_test = graph.is_test()
# marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict()
params = [p.name() for p in graph.all_parameters()]
def _transform_forward(graph, op):
for var_node in op.inputs:
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
else:
quant_bits = self._weight_bits if var_node.name() in params \
else self._activation_bits
quant_type = self._weight_quantize_type if var_node.name() \
in params else self._activation_quantize_type
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, quant_bits, quant_type)
dequant_var_node = self._insert_dequant_op(
graph, quant_var_node, scale_var_node, quant_bits)
dequantized_vars[var_node.name()] = dequant_var_node
graph.update_input_link(var_node, dequant_var_node, op)
def _transform_backward(graph, op):
no_dequanted_input_vars = True
for var_node in op.inputs:
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
graph.update_input_link(var_node, dequant_var_node, op)
no_dequanted_input_vars = False
if no_dequanted_input_vars:
raise ValueError("There is no dequanted inputs for op %s." %
(op.name()))
if not self._is_test:
self._create_global_step(graph)
ops = graph.all_ops()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
for op in ops:
if op.name() in self._quantizable_ops:
_transform_forward(graph, op)
# The loop for renaming the inputs of backward op.
for op in ops:
if op.name() in self._quantizable_grad_ops:
_transform_backward(graph, op)
if len(self._need_initialized) > 0:
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._program_exe is not None, \
'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.'
init_program = Program()
for var_desc, initializer in self._need_initialized.iteritems():
var = Variable(init_program.global_block())
var._set_desc(var_desc)
initializer(var, init_program.global_block())
self._program_exe.run(program=init_program, scope=self._scope)
return graph
def _create_global_step(self, graph):
if self._weight_quantize_type == 'range_abs_max' or \
self._activation_quantize_type == 'range_abs_max':
counter_name = '@STEP_COUNTER@'
for node in graph.all_vars():
if node.name() == counter_name:
self._global_step = node
if self._global_step is None:
global_step_in = graph.create_param_node(
name=counter_name,
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=core.VarDesc.VarType.INT64)
self._need_initialized[global_step_in.var()] = \
Constant(value=0, force_cpu=True)
global_step_out = graph.create_var_node_from_desc(
global_step_in.var())
increment_op = graph.create_op_node(
op_type='increment',
attrs={'step': 1.0},
inputs={'X': global_step_in},
outputs={'Out': global_step_out})
graph.link_to(global_step_in, increment_op)
graph.link_to(increment_op, global_step_out)
self._global_step = global_step_out
def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
"""
Insert fake_quantize_op in the graph.
"""
if quant_type == 'abs_max':
return self._insert_quant_abs_max_op(graph, var_node, quant_bits)
elif quant_type == 'range_abs_max':
return self._insert_quant_range_abs_max_op(graph, var_node,
quant_bits)
def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
"""
Insert fake_quantize_abs_max op in the graph.
"""
assert var_node.is_var(), '{} is not a var'.format(var_node.name())
quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
scale_var_node = graph.create_var_node(
name=self._quantized_scale_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max',
attrs={'bit_length': quant_bits},
inputs={'X': var_node},
outputs={'Out': quant_var_node,
'OutScale': scale_var_node})
graph.link_to(var_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node)
graph.link_to(quant_op_node, scale_var_node)
return quant_var_node, scale_var_node
def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
"""
Insert fake_quantize_range_abs_max on the graph.
"""
assert var_node.is_var(), '{} is not a var'.format(var_node.name())
quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
scale_in_node = graph.create_param_node(
name=self._quantized_scale_name(var_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.var().dtype())
self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
inputs = {'X': var_node, 'InScale': scale_in_node}
outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}
if not self._is_test:
# The name of scales_var_node maybe 'scales_0', 'scales_1', etc.
scales_node = graph.create_param_node(
name=unique_name.generate('scales'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[self._window_size],
var_dtype=var_node.var().dtype())
self._need_initialized[scales_node.var()] = Constant(value=0)
inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node
attrs = {
'window_size': self._window_size,
'bit_length': quant_bits,
'is_test': self._is_test
}
quant_op_node = graph.create_op_node(
op_type='fake_quantize_range_abs_max',
attrs=attrs,
inputs=inputs,
outputs=outputs)
graph.link_to(var_node, quant_op_node)
graph.link_to(scale_in_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node)
graph.link_to(quant_op_node, scale_out_node)
if not self._is_test:
graph.link_to(self._global_step, quant_op_node)
graph.link_to(quant_op_node, scales_node)
return quant_var_node, scale_out_node
def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
"""
Insert fake_dequantize_op in the graph.
"""
assert var_node.is_var(), '{} is not a var'.format(var_node.name())
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
attrs={'max_range': float(max_range)},
inputs={'X': var_node,
'Scale': scale_var_node},
outputs={'Out': dequant_var_node})
graph.link_to(var_node, dequant_op_node)
graph.link_to(scale_var_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node
def _quantized_var_name(self, var_name):
"""
Return quantized variable name for the input `var_name`.
"""
return "%s.quantized" % (var_name)
def _dequantized_var_name(self, var_name):
"""
Return dequantized variable name for the input `var_name`.
"""
return "%s.dequantized" % (var_name)
def _quantized_scale_name(self, var_name):
"""
Return the scale name of quantized variable for the input `var_name`.
"""
return "%s.scale" % (var_name)
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import random
import numpy as np
import paddle.fluid as fluid
import six
from paddle.fluid.framework import Program
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid import core
def linear_fc(num):
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in six.moves.xrange(num):
hidden = fluid.layers.fc(hidden, size=128, act='relu')
loss = fluid.layers.cross_entropy(input=hidden, label=label)
loss = fluid.layers.mean(loss)
return loss
def residual_block(num):
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
fc = fluid.layers.fc(input=hidden, size=10)
loss = fluid.layers.cross_entropy(input=fc, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestQuantizationTransformPass(unittest.TestCase):
def setUp(self):
self.quantizable_op_and_inputs = {
'conv2d': ['Input', 'Filter'],
'depthwise_conv2d': ['Input', 'Filter'],
'mul': ['X', 'Y']
}
self.quantizable_grad_op_inputs = {
'conv2d_grad': ['Input', 'Filter'],
'depthwise_conv2d_grad': ['Input', 'Filter'],
'mul_grad': ['X', 'Y']
}
def check_program(self, transform_pass, program):
quantized_ops = set()
for block in program.blocks:
for op in block.ops:
# check forward
if op.type in self.quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
self.assertTrue(
arg_name.endswith('.quantized.dequantized'))
quantized_ops.add(arg_name)
for op in block.ops:
# check backward
if op.type in self.quantizable_grad_op_inputs:
for pname in self.quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
self.assertTrue(
arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, quant_type):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
def test_linear_fc_quant_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
self.linear_fc_quant('abs_max')
def test_linear_fc_quant_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max'
self.linear_fc_quant('range_abs_max')
def residual_block_quant(self, quant_type):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
def test_residual_block_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
self.residual_block_quant('abs_max')
def test_residual_block_range_abs_max(self):
self.act_quant_op_type = 'fake_quantize_range_abs_max'
self.residual_block_quant('range_abs_max')
if __name__ == '__main__':
unittest.main()
......@@ -23,6 +23,7 @@ import traceback
import six
import numpy as np
import subprocess
from .. import compat as cpt
from .proto import framework_pb2
......@@ -1512,6 +1513,154 @@ class Block(object):
return ret_var
class IrGraph(object):
"""
IrGraph uses core.Graph as the delegation to accomplish the manipulation.
"""
def __init__(self, graph, for_test=False):
"""
Construct the IrGraph using core.Graph.
Args:
graph(core.Graph): C++ Graph.
for_test(bool): True for the test graph and false for the train graph.
"""
assert isinstance(
graph, core.Graph), 'graph must be the instance of core.Graph.'
self.graph = graph
self._for_test = for_test
def is_test(self):
return self._for_test
def all_parameters(self):
param_nodes = set()
for node in self.graph.nodes():
if node.is_var() and node.var() is not None and node.var(
).persistable():
param_nodes.add(node)
return param_nodes
def all_vars(self):
return {node for node in self.graph.nodes() if node.is_var()}
def all_ops(self):
return {node for node in self.graph.nodes() if node.is_op()}
def create_param_node(self, name, var_type, shape, var_dtype):
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
var_desc.set_dtype(var_dtype)
var_desc.set_persistable(True)
return self.graph.create_var_node(var_desc)
def create_var_node(self, name, var_type, shape, var_dtype):
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
var_desc.set_dtype(var_dtype)
return self.graph.create_var_node(var_desc)
def create_var_node_from_desc(self, var_desc):
return self.graph.create_var_node(var_desc)
def create_op_node(self, op_type, attrs, inputs, outputs):
op_desc = core.OpDesc()
op_desc.set_type(op_type)
for attr, value in attrs.iteritems():
self._update_desc_attr(op_desc, attr, value)
for input_name, var_nodes in inputs.iteritems():
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_input(input_name,
[var_node.name() for var_node in var_nodes])
for output_name, var_nodes in outputs.iteritems():
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_output(output_name,
[var_node.name() for var_node in var_nodes])
return self.graph.create_op_node(op_desc)
def create_op_node_from_desc(self, op_desc):
return self.graph.create_op_node(op_desc)
def update_input_link(self, old_input_node, new_input_node, op_node):
assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \
op_node in self.graph.nodes(), 'Th three arguments must be in the graph nodes.'
old_input_node.outputs_remove(op_node)
op_node.inputs_remove(old_input_node)
new_input_node.outputs_append(op_node)
op_node.inputs_append(new_input_node)
op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
def link_to(self, node_in, node_out):
assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
'Th two arguments must be in the graph nodes.'
node_in.outputs_append(node_out)
node_out.inputs_append(node_in)
def safe_remove_nodes(self, remove_nodes):
if not isinstance(remove_nodes, set):
remove_nodes = set(remove_nodes)
core.graph_safe_remove_nodes(self.graph, remove_nodes)
def draw(self, save_path, name, marked_nodes=None):
def _convert_to_pdf(dot_file_path):
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
+ ' -o ' + pdf_save_path, shell=True)
if exited_code != 0:
print('The dot command is needed for creating pdf files.')
print('The {} is saved as the dot filetype.'.format(
dot_file_path))
remove_ctr_vars = set()
ops_num = 0
for node in self.graph.nodes():
if node.is_ctrl_var():
remove_ctr_vars.add(node)
elif node.is_op():
ops_num += 1
print('Total ops num = {}.'.format(ops_num))
self.safe_remove_nodes(remove_ctr_vars)
if marked_nodes is not None:
if not isinstance(marked_nodes, set):
marked_nodes = set(marked_nodes)
marked_nodes = marked_nodes - remove_ctr_vars
if self.graph.has('__graphviz__marked_node__'):
self.graph.erase('__graphviz__marked_node__')
self.graph.set('__graphviz__marked_node__', marked_nodes)
viz_dot_path = os.path.join(save_path, name) + '.dot'
viz_pass = core.get_pass('graph_viz_pass')
viz_pass.set('graph_viz_path', viz_dot_path)
viz_pass.apply(self.graph)
_convert_to_pdf(viz_dot_path)
def to_program(self):
convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set('program', Program().desc)
convert_pass.apply(self.graph)
desc = convert_pass.get_program('program')
program = Program._construct_from_desc(desc)
return program
def _update_desc_attr(self, desc, name, val):
"""
Update the value of desc's attribute by attribute's name.
"""
if isinstance(val, Block):
desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and val and all(
isinstance(v, Block) for v in val):
desc.set_blocks_attr(name, [v.desc for v in val])
elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc):
desc.set_serialized_attr(name, val.serialize_to_string())
else:
desc._set_attr(name, val)
class Program(object):
"""
Python Program. Beneath it is a ProgramDesc, which is used for
......@@ -1936,6 +2085,23 @@ class Program(object):
p._sync_with_cpp()
return p
@staticmethod
def _construct_from_desc(desc):
"""
Construct a program from program desc.
Args:
desc(core.ProgramDesc): The program desc for constructing.
Returns:
Program: A program.
"""
p = Program()
p.desc = desc
p.blocks = [Block(p, i) for i in six.moves.range(p.desc.num_blocks())]
p._sync_with_cpp()
return p
@property
def random_seed(self):
"""
......
......@@ -123,7 +123,7 @@ class TestDistRunnerBase(object):
pass_builder = build_stra._finalize_strategy_and_create_passes()
mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat)
mypass.set("num_repeats", args.batch_merge_repeat)
if args.update_method == "nccl2":
build_stra.num_trainers = len(args.endpoints.split(","))
......
......@@ -111,7 +111,7 @@ class TestPassBuilder(unittest.TestCase):
pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
viz_pass.set_str("graph_viz_path", "/tmp/test_viz_pass")
viz_pass.set("graph_viz_path", "/tmp/test_viz_pass")
self.check_network_convergence(
use_cuda=core.is_compiled_with_cuda(),
......
......@@ -113,6 +113,7 @@ packages=['paddle',
'paddle.fluid.contrib.slim.core',
'paddle.fluid.contrib.slim.graph',
'paddle.fluid.contrib.slim.prune',
'paddle.fluid.contrib.slim.quantization',
'paddle.fluid.contrib.utils',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册