提交 ec88b6cc 编写于 作者: Z Zhen Wang

Add channel-wise quantization in the IR pass.

上级 81b4fad8
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -54,10 +55,6 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> { ...@@ -54,10 +55,6 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto scales = ctx.MultiInput<framework::Tensor>("Scales"); auto scales = ctx.MultiInput<framework::Tensor>("Scales");
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
PADDLE_ENFORCE_EQ(scales[0]->numel(), in->dims()[0],
"The number of first scale values must be the same with "
"first dimension value of Input(X).");
auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits"); auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
int max_range = std::pow(2, quant_bits[0] - 1) - 1; int max_range = std::pow(2, quant_bits[0] - 1) - 1;
...@@ -65,15 +62,38 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> { ...@@ -65,15 +62,38 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
out->mutable_data<T>(dev_ctx.GetPlace()); out->mutable_data<T>(dev_ctx.GetPlace());
auto dequant = DequantizeFunctor<DeviceContext, T>(); auto dequant = DequantizeFunctor<DeviceContext, T>();
for (int64_t i = 0; i < in->dims()[0]; i++) { if (scales.size() == 1) {
framework::Tensor one_channel_in = in->Slice(i, i + 1); PADDLE_ENFORCE_EQ(
framework::Tensor one_channel_out = out->Slice(i, i + 1); scales[0]->numel(), in->dims()[0],
framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1); "The number of first scale values must be the same with "
dequant(dev_ctx, &one_channel_in, &one_channel_scale, "first dimension value of Input(X) when the `Scales` has only one "
static_cast<T>(max_range), &one_channel_out); "element.");
} for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_channel_in = in->Slice(i, i + 1);
if (scales.size() == 2) { framework::Tensor one_channel_out = out->Slice(i, i + 1);
framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
dequant(dev_ctx, &one_channel_in, &one_channel_scale,
static_cast<T>(max_range), &one_channel_out);
}
} else if (scales.size() == 2) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[1],
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements.");
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
framework::slice_ddim(in->dims(), 1, in->dims().size()));
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
framework::slice_ddim(out->dims(), 1, out->dims().size()));
for (int64_t j = 0; j < in->dims()[1]; j++) {
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
framework::Tensor one_channel_scale = scales[0]->Slice(j, j + 1);
dequant(dev_ctx, &one_channel_in, &one_channel_scale,
static_cast<T>(max_range), &one_channel_out);
}
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scales[1]->numel(), 1, scales[1]->numel(), 1,
"The second scale tensor should only have one value at now."); "The second scale tensor should only have one value at now.");
......
...@@ -169,10 +169,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -169,10 +169,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
ctx->HasOutput("Out"), ctx->HasOutput("Out"),
"Output(Out) of FakeChannelWiseQuantizeOp should not be null."); "Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->HasOutput("OutScales"), ctx->HasOutput("OutScale"),
"Output(Scales) of FakeChannelWiseQuantizeOp should not be null."); "Output(Scale) of FakeChannelWiseQuantizeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]}); ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
...@@ -192,7 +192,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker ...@@ -192,7 +192,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
AddOutput("Out", AddOutput("Out",
"(Tensor) Output of quantized low level tensor, " "(Tensor) Output of quantized low level tensor, "
"but also saved as float data type."); "but also saved as float data type.");
AddOutput("OutScales", "(Tensor) Current channel wise scale"); AddOutput("OutScale", "(Tensor) Current channel wise scale");
AddAttr<int>("bit_length", "(int, default 8)") AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int& bit_length) {
......
...@@ -78,8 +78,8 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> { ...@@ -78,8 +78,8 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
auto* out_scales = context.Output<framework::Tensor>("OutScales"); auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_scales_data = out_scales->mutable_data<T>(context.GetPlace()); T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
...@@ -91,13 +91,13 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> { ...@@ -91,13 +91,13 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
framework::Tensor one_channel = in->Slice(i, i + 1); framework::Tensor one_channel = in->Slice(i, i + 1);
const T* one_channel_data = one_channel.data<T>(); const T* one_channel_data = one_channel.data<T>();
find_abs_max(dev_ctx, one_channel_data, one_channel.numel(), find_abs_max(dev_ctx, one_channel_data, one_channel.numel(),
&out_scales_data[i]); &out_scale_data[i]);
} }
auto clip_quant = ClipAndFakeQuantFunctor<DeviceContext, T>(); auto clip_quant = ClipAndFakeQuantFunctor<DeviceContext, T>();
for (int64_t i = 0; i < in->dims()[0]; i++) { for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_channel_in = in->Slice(i, i + 1); framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1); framework::Tensor one_channel_out = out->Slice(i, i + 1);
framework::Tensor one_channel_scale = out_scales->Slice(i, i + 1); framework::Tensor one_channel_scale = out_scale->Slice(i, i + 1);
clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt, clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt,
&one_channel_out); &one_channel_out);
} }
......
...@@ -22,6 +22,7 @@ from ....framework import IrGraph ...@@ -22,6 +22,7 @@ from ....framework import IrGraph
from ....framework import IrNode from ....framework import IrNode
from ....framework import Program from ....framework import Program
from ....initializer import Constant from ....initializer import Constant
from ....initializer import NumpyArrayInitializer
from .... import unique_name from .... import unique_name
__all__ = [ __all__ = [
...@@ -54,14 +55,15 @@ class QuantizationTransformPass(object): ...@@ -54,14 +55,15 @@ class QuantizationTransformPass(object):
the bias is not quantized. the bias is not quantized.
activation_bits (int): quantization bit number for activation. activation_bits (int): quantization bit number for activation.
activation_quantize_type (str): quantization type for activation, activation_quantize_type (str): quantization type for activation,
now support 'abs_max', 'range_abs_max'. If use 'abs_max' mode, now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
the quantization scale will be calculated dynamically each step If use 'abs_max' mode, the quantization scale will be calculated
in both training and testing period. If use 'range_abs_max', dynamically each step in both training and testing period. If use
a static quantization scale will be calculated during training 'range_abs_max', a static quantization scale will be calculated
and used in inference. during training and used in inference.
weight_quantize_type (str): quantization type for weights, weight_quantize_type (str): quantization type for weights,
support 'abs_max'. The 'range_abs_max' usually is not used for support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
weight, since weights are fixed once the model is well trained. usually is not used for weight, since weights are fixed once the
model is well trained.
window_size (int): the window size for 'range_abs_max' quantization. window_size (int): the window size for 'range_abs_max' quantization.
Examples: Examples:
...@@ -84,7 +86,11 @@ class QuantizationTransformPass(object): ...@@ -84,7 +86,11 @@ class QuantizationTransformPass(object):
self._weight_bits = weight_bits self._weight_bits = weight_bits
self._activation_bits = activation_bits self._activation_bits = activation_bits
quant_type = ['abs_max', 'range_abs_max', 'moving_average_abs_max'] quant_type = [
'abs_max', 'channel_wise_abs_max', 'range_abs_max',
'moving_average_abs_max'
]
assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'."
if activation_quantize_type not in quant_type: if activation_quantize_type not in quant_type:
raise ValueError( raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be ", "Unknown activation_quantize_type : '%s'. It can only be ",
...@@ -93,7 +99,7 @@ class QuantizationTransformPass(object): ...@@ -93,7 +99,7 @@ class QuantizationTransformPass(object):
if weight_quantize_type not in quant_type: if weight_quantize_type not in quant_type:
raise ValueError( raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be ", "Unknown weight_quantize_type: '%s'. It can only be ",
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.", "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
str(weight_quantize_type)) str(weight_quantize_type))
self._activation_quantize_type = activation_quantize_type self._activation_quantize_type = activation_quantize_type
...@@ -103,6 +109,7 @@ class QuantizationTransformPass(object): ...@@ -103,6 +109,7 @@ class QuantizationTransformPass(object):
self._need_initialized = collections.OrderedDict() self._need_initialized = collections.OrderedDict()
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._quantizable_grad_ops = [ self._quantizable_grad_ops = [
'%s_grad' % (op) for op in self._quantizable_ops '%s_grad' % (op) for op in self._quantizable_ops
] ]
...@@ -135,10 +142,26 @@ class QuantizationTransformPass(object): ...@@ -135,10 +142,26 @@ class QuantizationTransformPass(object):
else self._activation_bits else self._activation_bits
quant_type = self._weight_quantize_type if var_node.name() \ quant_type = self._weight_quantize_type if var_node.name() \
in persistable_vars else self._activation_quantize_type in persistable_vars else self._activation_quantize_type
quant_var_node, scale_var_node = self._insert_quant_op( if quant_type == 'channel_wise_abs_max':
graph, var_node, quant_bits, quant_type) assert var_node.name(
dequant_var_node = self._insert_dequant_op( ) in persistable_vars, "'channel_wise_abs_max' can only be applied on weights."
graph, quant_var_node, scale_var_node, quant_bits) if op.name() in self._conv_ops:
quant_var_node, scale_var_node = self._insert_channel_quant_op(
graph, var_node, quant_bits)
dequant_var_node = self._insert_channel_dequant_op(
graph, quant_var_node, [scale_var_node],
[quant_bits])
else:
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, quant_bits, 'abs_max')
dequant_var_node = self._insert_dequant_op(
graph, quant_var_node, scale_var_node,
quant_bits)
else:
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, quant_bits, quant_type)
dequant_var_node = self._insert_dequant_op(
graph, quant_var_node, scale_var_node, quant_bits)
dequantized_vars[var_node.name()] = dequant_var_node dequantized_vars[var_node.name()] = dequant_var_node
graph.update_input_link(var_node, dequant_var_node, op) graph.update_input_link(var_node, dequant_var_node, op)
...@@ -244,7 +267,7 @@ class QuantizationTransformPass(object): ...@@ -244,7 +267,7 @@ class QuantizationTransformPass(object):
scale_var_node = graph.create_var_node( scale_var_node = graph.create_var_node(
name=self._quantized_scale_name(var_node.name()), name=self._quantized_scale_name(var_node.name()),
var_type=var_node.type(), var_type=var_node.type(),
shape=var_node.shape(), shape=[1],
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max', op_type='fake_quantize_abs_max',
...@@ -384,6 +407,36 @@ class QuantizationTransformPass(object): ...@@ -384,6 +407,36 @@ class QuantizationTransformPass(object):
return quant_var_node, scale_out_node return quant_var_node, scale_out_node
def _insert_channel_quant_op(self, graph, var_node, quant_bits):
    """
    Add a fake_channel_wise_quantize_abs_max op that consumes `var_node`.

    Args:
        graph: graph that receives the new nodes; only its create_var_node,
            create_op_node and link_to methods are used here.
        var_node: variable node to be quantized; must report is_var().
        quant_bits: value stored in the op's 'bit_length' attribute.

    Returns:
        tuple: (quantized variable node, scale variable node). The scale
        node is created with shape [var_node.shape()[0]].
    """
    assert var_node.is_var(), '{} is not a var'.format(var_node.name())

    src_name = var_node.name()
    src_shape = var_node.shape()
    # Both new variables inherit type and dtype from the source variable.
    inherited = dict(var_type=var_node.type(), var_dtype=var_node.dtype())

    quantized_out = graph.create_var_node(
        name=self._quantized_var_name(src_name), shape=src_shape, **inherited)
    # One scale value per slice along the first dimension of the source.
    scale_out = graph.create_var_node(
        name=self._quantized_scale_name(src_name),
        shape=[src_shape[0]],
        **inherited)

    op_attrs = {
        'bit_length': quant_bits,
        'op_role': core.op_proto_and_checker_maker.OpRole.Forward
    }
    quant_op = graph.create_op_node(
        op_type='fake_channel_wise_quantize_abs_max',
        attrs=op_attrs,
        inputs={'X': var_node},
        outputs={'Out': quantized_out,
                 'OutScale': scale_out})

    # Wire the op into the graph: source -> op -> (quantized var, scale var).
    for producer, consumer in ((var_node, quant_op), (quant_op, quantized_out),
                               (quant_op, scale_out)):
        graph.link_to(producer, consumer)
    return quantized_out, scale_out
def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits): def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
""" """
Insert fake_dequantize_op in the graph. Insert fake_dequantize_op in the graph.
...@@ -410,6 +463,33 @@ class QuantizationTransformPass(object): ...@@ -410,6 +463,33 @@ class QuantizationTransformPass(object):
graph.link_to(dequant_op_node, dequant_var_node) graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node return dequant_var_node
def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
                               quant_bits):
    """
    Add a fake_channel_wise_dequantize_max_abs op that consumes `var_node`.

    Args:
        graph: graph that receives the new nodes; only its create_var_node,
            create_op_node and link_to methods are used here.
        var_node: variable node to be dequantized; must report is_var().
        scale_var_nodes: iterable of scale nodes passed as the op's 'Scales'
            input; each one is linked to the new op.
        quant_bits: value stored in the op's 'quant_bits' attribute.

    Returns:
        The newly created dequantized variable node.
    """
    assert var_node.is_var(), '{} is not a var'.format(var_node.name())

    # The output variable mirrors the input's type, shape and dtype.
    result_var = graph.create_var_node(
        name=self._dequantized_var_name(var_node.name()),
        var_type=var_node.type(),
        shape=var_node.shape(),
        var_dtype=var_node.dtype())

    op_attrs = {
        'quant_bits': quant_bits,
        'op_role': core.op_proto_and_checker_maker.OpRole.Forward
    }
    op_inputs = {'X': var_node, 'Scales': scale_var_nodes}
    dequant_op = graph.create_op_node(
        op_type='fake_channel_wise_dequantize_max_abs',
        attrs=op_attrs,
        inputs=op_inputs,
        outputs={'Out': result_var})

    # Link order: data input, then every scale input, then the output.
    graph.link_to(var_node, dequant_op)
    for scale_node in scale_var_nodes:
        graph.link_to(scale_node, dequant_op)
    graph.link_to(dequant_op, result_var)
    return result_var
def _quantized_var_name(self, var_name): def _quantized_var_name(self, var_name):
""" """
Return quantized variable name for the input `var_name`. Return quantized variable name for the input `var_name`.
...@@ -442,7 +522,7 @@ class QuantizationFreezePass(object): ...@@ -442,7 +522,7 @@ class QuantizationFreezePass(object):
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors. place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
weight_bits (int): quantization bit number for weights. weight_bits (int): quantization bit number for weights.
activation_bits (int): quantization bit number for activation. activation_bits (int): quantization bit number for activation.
weight_quantize_type (str): quantization type for weights, support 'abs_max'. weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the The 'range_abs_max' usually is not used for weight, since weights are fixed once the
model is well trained. model is well trained.
""" """
...@@ -463,11 +543,15 @@ class QuantizationFreezePass(object): ...@@ -463,11 +543,15 @@ class QuantizationFreezePass(object):
self._activation_bits = activation_bits self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._fake_quant_op_names = [ self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max', 'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max' 'fake_quantize_moving_average_abs_max',
'fake_channel_wise_quantize_abs_max'
]
self._fake_dequant_op_names = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
] ]
self._fake_dequant_op_names = ['fake_dequantize_max_abs']
self._op_input_rename_map = collections.OrderedDict() self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict() self._op_output_rename_map = collections.OrderedDict()
self._var_scale_map = collections.OrderedDict() self._var_scale_map = collections.OrderedDict()
...@@ -489,20 +573,29 @@ class QuantizationFreezePass(object): ...@@ -489,20 +573,29 @@ class QuantizationFreezePass(object):
if self._weight_quantize_type == 'abs_max': if self._weight_quantize_type == 'abs_max':
param = self._load_var(input_arg_name) param = self._load_var(input_arg_name)
scale_v = np.max(np.abs(param)) scale_v = np.max(np.abs(param))
elif self._weight_quantize_type == 'channel_wise_abs_max':
param = self._load_var(input_arg_name)
if len(param.shape) == 4: # conv2d or depthwise_conv2d
print('DEBUG**************************: %s' %
input_arg_name)
scale_v = []
for i in range(param.shape[0]):
scale_v.append(np.max(np.abs(param[i])))
else:
scale_v = np.max(np.abs(param))
else: else:
scale_v = self._load_var( scale_v = self._load_var(
op_node.output('OutScale')[0])[0] op_node.output('OutScale')[0])[0]
self._var_scale_map[input_arg_name] = scale_v self._var_scale_map[input_arg_name] = scale_v
else:
scale_v = graph.var_node(op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
if input_arg_name in persistable_vars:
self._remove_fake_quant_and_dequant_op(graph, op_node) self._remove_fake_quant_and_dequant_op(graph, op_node)
# quantize weight and restore # quantize weight and restore
param_v = self._load_var(input_arg_name) param_v = self._load_var(input_arg_name)
quantized_param_v = self._quant(param_v, scale_v, quantized_param_v = self._quant(param_v, scale_v,
self._weight_bits) self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v) self._restore_var(input_arg_name, quantized_param_v)
else:
scale_v = graph.var_node(op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
ops = graph.all_op_nodes() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
...@@ -514,7 +607,10 @@ class QuantizationFreezePass(object): ...@@ -514,7 +607,10 @@ class QuantizationFreezePass(object):
for op_node in ops: for op_node in ops:
op_name = op_node.name() op_name = op_node.name()
if op_name in self._quantizable_ops: if op_name in self._quantizable_ops:
self._insert_post_dequant_op(graph, op_node) if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
self._insert_post_channel_dequant_op(graph, op_node)
else:
self._insert_post_dequant_op(graph, op_node)
for op_node in ops: for op_node in ops:
# insert dequant_op after fc/conv, need to rename inputs of the followed ops # insert dequant_op after fc/conv, need to rename inputs of the followed ops
...@@ -538,9 +634,73 @@ class QuantizationFreezePass(object): ...@@ -538,9 +634,73 @@ class QuantizationFreezePass(object):
self._op_input_rename_map[k] = self._op_input_rename_map[v] self._op_input_rename_map[k] = self._op_input_rename_map[v]
graph.safe_remove_nodes(op_node) graph.safe_remove_nodes(op_node)
def _insert_post_channel_dequant_op(self, graph, op_node):
    """
    Append a fake_channel_wise_dequantize_max_abs op to the output of `op_node`.

    Each input of `op_node` is first re-linked according to
    self._op_input_rename_map, then its scale is looked up in
    self._var_scale_map. A persistable input must map to a list of scales
    (stored as a new persistable 'channel_scale' node); any other input must
    map to an IrNode holding its scale.

    Args:
        graph: graph being rewritten.
        op_node: op node whose single output gets the dequantize op.

    Returns:
        The variable node holding the dequantized output. It is also
        recorded in self._op_output_rename_map under the original output name.

    Raises:
        ValueError: if `op_node` does not have exactly one output, or if the
            inputs did not supply both a weight scale and an activation scale.
    """
    persistable_vars = {p.name() for p in graph.all_persistable_nodes()}
    # Initialised up front so a missing scale is reported explicitly below
    # rather than surfacing later as an UnboundLocalError.
    channel_scale = None
    scale_var_node = None
    for var_node in op_node.inputs:
        name = var_node.name()
        if name in self._op_input_rename_map:
            old_in = graph.var_node(name)
            new_in = graph.var_node(self._op_input_rename_map[name])
            new_in.clear_outputs()
            graph.update_input_link(old_in, new_in, op_node)
        original_var_name = self._original_var_name(name)
        scale_v = self._var_scale_map[original_var_name]
        if original_var_name in persistable_vars:
            assert isinstance(
                scale_v,
                list), 'The scale of parameter %s is not a list.' % (
                    original_var_name)
            channel_scale = np.array(scale_v)
        else:
            assert isinstance(scale_v, IrNode)
            scale_var_node = scale_v

    if len(op_node.outputs) != 1:
        raise ValueError("Only support one output, but op %s has"
                         " more than one output." % (op_node.name()))
    if channel_scale is None or scale_var_node is None:
        raise ValueError(
            "Op %s needs both a persistable (weight) input and a "
            "non-persistable (activation) input with recorded scales." %
            (op_node.name()))

    output_var_node = op_node.outputs[0]
    weight_scale_node = graph.create_persistable_node(
        name=unique_name.generate('channel_scale'),
        var_type=core.VarDesc.VarType.LOD_TENSOR,
        shape=[channel_scale.shape[0]],
        var_dtype=output_var_node.dtype())
    # Fill the new persistable node in self._scope by running a one-off
    # program that contains only its initializer.
    init_program = Program()
    weight_scale_var = init_program.global_block().create_var(
        name=weight_scale_node.name(),
        shape=weight_scale_node.shape(),
        dtype=weight_scale_node.dtype(),
        type=weight_scale_node.type(),
        lod_level=weight_scale_node.var().lod_level(),
        persistable=weight_scale_node.persistable())
    initializer = NumpyArrayInitializer(value=channel_scale)
    initializer(weight_scale_var, init_program.global_block())
    exe = Executor(self._place)
    exe.run(program=init_program, scope=self._scope)

    dequant_var_node = graph.create_var_node(
        name=self._dequantized_var_name(output_var_node.name()),
        var_type=output_var_node.type(),
        shape=output_var_node.shape(),
        var_dtype=output_var_node.dtype())
    dequant_op_node = graph.create_op_node(
        op_type='fake_channel_wise_dequantize_max_abs',
        attrs={
            # Order matches 'Scales': weight bits first, activation bits second.
            'quant_bits': [self._weight_bits, self._activation_bits],
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
        },
        inputs={
            'X': output_var_node,
            'Scales': [weight_scale_node, scale_var_node]
        },
        outputs={'Out': dequant_var_node})
    graph.link_to(output_var_node, dequant_op_node)
    graph.link_to(scale_var_node, dequant_op_node)
    graph.link_to(weight_scale_node, dequant_op_node)
    graph.link_to(dequant_op_node, dequant_var_node)
    self._op_output_rename_map[output_var_node.name()] = dequant_var_node
    return dequant_var_node
def _insert_post_dequant_op(self, graph, op_node): def _insert_post_dequant_op(self, graph, op_node):
max_range = None
scale_var_node = None
persistable_vars = [p.name() for p in graph.all_persistable_nodes()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs: for var_node in op_node.inputs:
name = var_node.name() name = var_node.name()
...@@ -637,7 +797,12 @@ class QuantizationFreezePass(object): ...@@ -637,7 +797,12 @@ class QuantizationFreezePass(object):
or isinstance(v, np.float64) or isinstance(v, np.float64)
def _quant(self, x, scale, num_bits): def _quant(self, x, scale, num_bits):
return np.round(x / scale * ((1 << (num_bits - 1)) - 1)) if isinstance(scale, list):
for i, s in enumerate(scale):
x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1))
return x
else:
return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
class ConvertToInt8Pass(object): class ConvertToInt8Pass(object):
...@@ -731,9 +896,13 @@ class TransformForMobilePass(object): ...@@ -731,9 +896,13 @@ class TransformForMobilePass(object):
def __init__(self): def __init__(self):
self._fake_quant_op_names = [ self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max' 'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max',
'fake_channel_wise_quantize_abs_max'
]
self._fake_dequant_op_names = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
] ]
self._fake_dequant_op_names = ['fake_dequantize_max_abs']
def apply(self, graph): def apply(self, graph):
""" """
......
...@@ -243,7 +243,12 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -243,7 +243,12 @@ class TestQuantizationFreezePass(unittest.TestCase):
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
exe.run(startup) exe.run(startup)
transform_pass = QuantizationTransformPass( transform_pass = QuantizationTransformPass(
scope=scope, place=place, activation_quantize_type=quant_type) scope=scope,
place=place,
activation_quantize_type=quant_type,
weight_quantize_type='channel_wise_abs_max')
#transform_pass = QuantizationTransformPass(
# scope=scope, place=place, activation_quantize_type=quant_type)
transform_pass.apply(main_graph) transform_pass.apply(main_graph)
transform_pass.apply(test_graph) transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_' dev_name = '_gpu_' if use_cuda else '_cpu_'
...@@ -296,7 +301,11 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -296,7 +301,11 @@ class TestQuantizationFreezePass(unittest.TestCase):
fetch_list=[loss, w_var]) fetch_list=[loss, w_var])
# Freeze graph for inference, but the weight of fc/conv is still float type. # Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass(scope=scope, place=place) freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_quantize_type='channel_wise_abs_max')
#freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph) freeze_pass.apply(test_graph)
if not for_ci: if not for_ci:
marked_nodes = set() marked_nodes = set()
...@@ -375,29 +384,32 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -375,29 +384,32 @@ class TestQuantizationFreezePass(unittest.TestCase):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph( self.freeze_graph(
True, seed=1, quant_type='abs_max', for_ci=True) True, seed=1, quant_type='abs_max', for_ci=False)
def test_freeze_graph_cpu_dynamic(self): def test_freeze_graph_cpu_dynamic(self):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=True) self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=False)
def test_freeze_graph_cuda_static(self): def test_freeze_graph_cuda_static(self):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph( self.freeze_graph(
True, seed=1, quant_type='range_abs_max', for_ci=True) True, seed=1, quant_type='range_abs_max', for_ci=False)
self.freeze_graph( self.freeze_graph(
True, True,
seed=1, seed=1,
quant_type='moving_average_abs_max', quant_type='moving_average_abs_max',
for_ci=True) for_ci=False)
def test_freeze_graph_cpu_static(self): def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
self.freeze_graph( self.freeze_graph(
False, seed=2, quant_type='range_abs_max', for_ci=True) False, seed=2, quant_type='range_abs_max', for_ci=False)
self.freeze_graph( self.freeze_graph(
False, seed=2, quant_type='moving_average_abs_max', for_ci=True) False,
seed=2,
quant_type='moving_average_abs_max',
for_ci=False)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -31,15 +31,27 @@ def dequantize_max_abs(x, scale, max_range): ...@@ -31,15 +31,27 @@ def dequantize_max_abs(x, scale, max_range):
return y return y
def channel_wise_quantize_max_abs(x, quant_bit=8): def channel_wise_quantize_max_abs(x, quant_bit=8, use_second_dim=False):
scales = [] scales = []
for i in range(x.shape[0]): if not use_second_dim:
scales.append(np.max(np.abs(x[i])).astype("float32")) for i in range(x.shape[0]):
scales.append(np.max(np.abs(x[i])).astype("float32"))
y = x.copy() y = x.copy()
max_range = math.pow(2, quant_bit - 1) - 1 max_range = math.pow(2, quant_bit - 1) - 1
for i, scale in enumerate(scales): for i, scale in enumerate(scales):
y[i] = np.round(y[i] / scale * max_range) y[i] = np.round(x[i] / scale * max_range)
else:
for i in range(x.shape[0]):
s = []
for j in range(x.shape[1]):
s.append(np.max(np.abs(x[i][j])).astype("float32"))
scales.append(s)
scales = np.amax(np.array(scales), axis=0)
y = x.copy()
max_range = math.pow(2, quant_bit - 1) - 1
for i in range(x.shape[0]):
for j, scale in enumerate(scales):
y[i][j] = np.round(x[i][j] / scale * max_range)
return y, scales return y, scales
...@@ -47,10 +59,16 @@ def channel_wise_dequantize_max_abs(x, ...@@ -47,10 +59,16 @@ def channel_wise_dequantize_max_abs(x,
scales, scales,
quant_bits, quant_bits,
activation_scale=None): activation_scale=None):
y = x.copy() if activation_scale is None:
for i in range(x.shape[0]): y = x.copy()
y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * y[i] for i in range(x.shape[0]):
if activation_scale is not None: y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * x[i]
else:
y = x.copy()
for i in range(x.shape[0]):
for j in range(x.shape[1]):
y[i][j] = (scales[j] /
(math.pow(2, quant_bits[0] - 1) - 1)) * x[i][j]
y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1) y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1)
return y return y
...@@ -65,7 +83,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest): ...@@ -65,7 +83,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
self.set_args() self.set_args()
self.op_type = "fake_channel_wise_dequantize_max_abs" self.op_type = "fake_channel_wise_dequantize_max_abs"
x = np.random.randn(4, 3, 64, 64).astype(self.data_type) x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0]) yq, scales = channel_wise_quantize_max_abs(
x, self.quant_bits[0], use_second_dim=True)
ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
self.activation_scale) self.activation_scale)
......
...@@ -53,7 +53,7 @@ class TestFakeChannelWiseQuantizeOp(OpTest): ...@@ -53,7 +53,7 @@ class TestFakeChannelWiseQuantizeOp(OpTest):
self.outputs = { self.outputs = {
'Out': outputs, 'Out': outputs,
'OutScales': np.array(scales).astype("float32"), 'OutScale': np.array(scales).astype("float32"),
} }
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册