diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc index 68c7227e5a7123e1e751dd55e243ee481bf36540..4a8937ba1c7ef9827ecc9bf575d9893c95a3b22b 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cc +++ b/paddle/fluid/operators/fake_dequantize_op.cc @@ -33,8 +33,51 @@ struct DequantizeFunctor { } }; +template +struct ChannelDequantizeFunctor { + void operator()(const platform::CPUDeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor** scales, + const int scale_num, T max_range, framework::Tensor* out) { + if (scale_num == 1) { + const int channel = in->dims()[0]; + const T* scale_factor = scales[0]->data(); + for (int i = 0; i < channel; i++) { + T s = scale_factor[i]; + framework::Tensor one_channel_in = in->Slice(i, i + 1); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto in_e = framework::EigenVector::Flatten(one_channel_in); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = (s / max_range) * in_e; + } + } else if (scale_num == 2) { + int batch_size = in->dims()[0]; + int channel = in->dims()[1]; + const T* scale_one = scales[0]->data(); + const T* scale_two = scales[1]->data(); + for (int i = 0; i < batch_size; i++) { + framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize( + framework::slice_ddim(in->dims(), 1, in->dims().size())); + framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize( + framework::slice_ddim(out->dims(), 1, out->dims().size())); + for (int j = 0; j < channel; j++) { + T s = scale_one[j]; + framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1); + framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1); + auto in_e = framework::EigenVector::Flatten(one_channel_in); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = (s * scale_two[0] / max_range) * in_e; + } + } + } + } +}; + template struct DequantizeFunctor; template struct DequantizeFunctor; +template struct ChannelDequantizeFunctor; +template struct ChannelDequantizeFunctor; class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel { public: diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index 35dcc69279d0119e75c4c5072e7817c839b9e819..02f9dc827d68cbb58447ed1557ff4bf310b2c017 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -44,8 +44,66 @@ struct DequantizeFunctor { } }; +template +__global__ void DequantizeOneScale(const T* in, const T* scale, T max_range, + int num, int channel, T* out) { + int tid = threadIdx.x; + int channel_size = num / channel; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + for (int i = tid; i < channel_size; i += blockDim.x) { + out_c[i] = in_c[i] * scale[blockIdx.x] / max_range; + } +} + +template +__global__ void DequantizeTwoScale(const T* in, const T* scale_one, + const T* scale_two, T max_range, int num, + int batch_size, int channel, T* out) { + int tid = threadIdx.x; + int channel_size = num / (batch_size * channel); + int scale_index = blockIdx.x % channel; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + for (int i = tid; i < channel_size; i += blockDim.x) { + out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range; + } +} + +template +struct 
ChannelDequantizeFunctor { + void operator()(const platform::CUDADeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor** scales, + const int scale_num, T max_range, framework::Tensor* out) { + const T* in_data = in->data(); + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + if (scale_num == 1) { + int num = in->numel(); + int channel = in->dims()[0]; + const T* scale_factor = scales[0]->data(); + int block = 1024; + int grid = channel; + DequantizeOneScale<<>>( + in_data, scale_factor, max_range, num, channel, out_data); + } else if (scale_num == 2) { + int num = in->numel(); + int batch_size = in->dims()[0]; + int channel = in->dims()[1]; + const T* scale_one = scales[0]->data(); + const T* scale_two = scales[1]->data(); + int block = 1024; + int grid = batch_size * channel; + DequantizeTwoScale<<>>( + in_data, scale_one, scale_two, max_range, num, batch_size, channel, + out_data); + } + } +}; + template struct DequantizeFunctor; template struct DequantizeFunctor; +template struct ChannelDequantizeFunctor; +template struct ChannelDequantizeFunctor; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h index d05f2038531bbe9c35da54c94d2ef4d659acca70..ed9a0a4d65fab5ce1ef48835c332fade978d2bae 100644 --- a/paddle/fluid/operators/fake_dequantize_op.h +++ b/paddle/fluid/operators/fake_dequantize_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -28,6 +29,13 @@ struct DequantizeFunctor { framework::Tensor* out); }; +template +struct ChannelDequantizeFunctor { + void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, + const framework::Tensor** scales, const int scale_num, + T max_range, framework::Tensor* out); +}; + template class FakeDequantizeMaxAbsKernel : public framework::OpKernel { public: @@ -54,32 +62,33 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { auto scales = ctx.MultiInput("Scales"); auto* out = ctx.Output("Out"); - PADDLE_ENFORCE_EQ(scales[0]->numel(), in->dims()[0], - "The number of first scale values must be the same with " - "first dimension value of Input(X)."); - auto quant_bits = ctx.Attr>("quant_bits"); - int max_range = std::pow(2, quant_bits[0] - 1) - 1; + int max_range = 1; auto& dev_ctx = ctx.template device_context(); out->mutable_data(dev_ctx.GetPlace()); - - auto dequant = DequantizeFunctor(); - for (int64_t i = 0; i < in->dims()[0]; i++) { - framework::Tensor one_channel_in = in->Slice(i, i + 1); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1); - dequant(dev_ctx, &one_channel_in, &one_channel_scale, - static_cast(max_range), &one_channel_out); - } - - if (scales.size() == 2) { + int scale_num = scales.size(); + if (scale_num == 1) { + PADDLE_ENFORCE_EQ( + scales[0]->numel(), in->dims()[0], + "The number of first scale values must be the same with " + "first dimension value of Input(X) when the `Scales` has only one " + "element."); + max_range *= (std::pow(2, quant_bits[0] - 1) - 1); + } else if (scale_num == 2) { + PADDLE_ENFORCE_EQ( + scales[0]->numel(), in->dims()[1], + "The number of first scale values must be the same with " + "second dimension value of Input(X) when the `Scales` has two " + "elements."); PADDLE_ENFORCE_EQ( scales[1]->numel(), 1, 
"The second scale tensor should only have one value at now."); - max_range = std::pow(2, quant_bits[1] - 1) - 1; - dequant(dev_ctx, out, scales[1], static_cast(max_range), out); + max_range *= (std::pow(2, quant_bits[0] - 1) - 1) * + (std::pow(2, quant_bits[1] - 1) - 1); } + ChannelDequantizeFunctor()( + dev_ctx, in, scales.data(), scale_num, static_cast(max_range), out); } }; diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index d51d51b4953073e9a350806f041bb3112fad239c..054ef4658cc0c4448d49870849017d3191d57db9 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -37,6 +37,21 @@ struct FindAbsMaxFunctor { template struct FindAbsMaxFunctor; +template +struct FindChannelAbsMaxFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const T* in, + const int num, const int channel, T* out) { + const int channel_size = num / channel; + for (int i = 0; i < channel; i++) { + auto* start = in + i * channel_size; + auto* end = in + (i + 1) * channel_size; + out[i] = std::abs(*(std::max_element(start, end, Compare()))); + } + } +}; + +template struct FindChannelAbsMaxFunctor; + template struct ClipAndFakeQuantFunctor { void operator()(const platform::CPUDeviceContext& ctx, @@ -53,6 +68,36 @@ struct ClipAndFakeQuantFunctor { template struct ClipAndFakeQuantFunctor; +template +struct ChannelClipAndFakeQuantFunctor { + void operator()(const platform::CPUDeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, const int channel, + framework::Tensor* out) { + auto* scale_data = scale.data(); + auto* in_data = in.data(); + auto* out_data = out->mutable_data(ctx.GetPlace()); + const int channel_size = in.numel() / channel; + platform::Transform trans; + for (int i = 0; i < channel; i++) { + T s = scale_data[i]; + auto* start = in_data + i * channel_size; + auto* end = in_data + (i + 1) * channel_size; + trans(ctx, start, end, out_data + i * channel_size, + ClipFunctor(-s, s)); + } + for (int i = 0; i < channel; i++) { + T s = scale_data[i]; + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round(); + } + } +}; + +template struct ChannelClipAndFakeQuantFunctor; + template struct FindRangeAbsMaxFunctor { void operator()(const platform::CPUDeviceContext& ctx, @@ -169,10 +214,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { ctx->HasOutput("Out"), "Output(Out) of FakeChannelWiseQuantizeOp should not be null."); PADDLE_ENFORCE( - ctx->HasOutput("OutScales"), - "Output(Scales) of FakeChannelWiseQuantizeOp should not be null."); + ctx->HasOutput("OutScale"), + "Output(Scale) of FakeChannelWiseQuantizeOp should not be null."); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]}); + ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]}); ctx->ShareLoD("X", /*->*/ "Out"); } @@ -192,7 +237,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker AddOutput("Out", "(Tensor) Output of quantized low level tensor, " "but also saved as float data type."); - AddOutput("OutScales", "(Tensor) Current channel wise scale"); + AddOutput("OutScale", "(Tensor) Current channel wise scale"); AddAttr("bit_length", "(int, default 8)") .SetDefault(8) .AddCustomChecker([](const int& bit_length) { diff --git a/paddle/fluid/operators/fake_quantize_op.cu 
b/paddle/fluid/operators/fake_quantize_op.cu index 3707f6772eac0d568c170d60c17d431e254d0b6b..33bd275e5cc507ec700b3694cd8b1df9672ec512 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -74,6 +74,45 @@ struct FindAbsMaxFunctor { template struct FindAbsMaxFunctor; +template +__global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c, + T* out) { + int tid = threadIdx.x; + int channel_size = n / c; + const T* in_c = in + blockIdx.x * channel_size; + extern __shared__ T shared_max_data[]; + shared_max_data[tid] = T(0); + for (int i = tid; i < channel_size; i += blockDim.x) { + T tmp = fabs(in_c[i]); + if (tmp > shared_max_data[tid]) { + shared_max_data[tid] = tmp; + } + } + __syncthreads(); + for (int i = blockDim.x / 2; i > 0; i >>= 1) { + if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { + shared_max_data[tid] = shared_max_data[tid + i]; + } + __syncthreads(); + } + if (tid == 0) { + out[blockIdx.x] = shared_max_data[0]; + } +} + +template +struct FindChannelAbsMaxFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const T* in, + const int num, const int channel, T* out) { + int block = 1024; + int grid = channel; + FindChannelAbsMaxKernel<<>>( + in, num, channel, out); + } +}; + +template struct FindChannelAbsMaxFunctor; + template __global__ void ClipAndQuantKernel(const T* in, const T* scale, const int bin_cnt, const int n, T* out) { @@ -82,14 +121,76 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, T s = scale[0]; for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - T x = in[bid]; + T x = in[i]; T v = x > s ? s : x; v = v < -s ? -s : v; v = bin_cnt / s * v; - out[bid] = round(v); + out[i] = round(v); } } +template +struct ClipAndFakeQuantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, framework::Tensor* out) { + int num = in.numel(); + int block = 1024; + int grid = (block - 1 + num) / block; + + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + ClipAndQuantKernel<<>>( + in_data, scale_data, bin_cnt, num, out_data); + } +}; + +template struct ClipAndFakeQuantFunctor; + +template +__global__ void ChannelClipAndQuantKernel(const T* in, const T* scale, + const int bin_cnt, const int n, + const int c, T* out) { + int tid = threadIdx.x; + + int channel_size = n / c; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + + T s = scale[blockIdx.x]; + for (int i = tid; i < channel_size; i += blockDim.x) { + T x = in_c[i]; + T v = x > s ? s : x; + v = v < -s ? 
-s : v; + v = bin_cnt / s * v; + out_c[i] = round(v); + } +} + +template +struct ChannelClipAndFakeQuantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, const int channel, + framework::Tensor* out) { + int num = in.numel(); + int block = 1024; + int grid = channel; + + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + ChannelClipAndQuantKernel<<>>( + in_data, scale_data, bin_cnt, num, channel, out_data); + } +}; + +template struct ChannelClipAndFakeQuantFunctor; + template __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, const T* last_scale, @@ -182,26 +283,6 @@ struct FindMovingAverageAbsMaxFunctor { template struct FindMovingAverageAbsMaxFunctor; -template -struct ClipAndFakeQuantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { - int num = in.numel(); - int block = 1024; - int grid = (block - 1 + num) / block; - - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - - ClipAndQuantKernel<<>>( - in_data, scale_data, bin_cnt, num, out_data); - } -}; - -template struct ClipAndFakeQuantFunctor; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index ec667e89e7699d87db9423f17014a2761ce62763..5ab38b086df7f9df33996ec83b5ec07047c204ba 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -42,6 +42,19 @@ struct FindRangeAbsMaxFunctor { framework::Tensor* scales_arr, framework::Tensor* out_scale); }; +template +struct FindChannelAbsMaxFunctor { + void operator()(const DeviceContext& ctx, const T* in, const int num, + const int channel, T* out); +}; + +template +struct ChannelClipAndFakeQuantFunctor { + void operator()(const DeviceContext& ctx, const framework::Tensor& in, + const framework::Tensor& scale, const int bin_cnt, + const int channel, framework::Tensor* out); +}; + template struct FindMovingAverageAbsMaxFunctor { void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum, @@ -78,29 +91,18 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel { auto* in = context.Input("X"); auto* out = context.Output("Out"); - auto* out_scales = context.Output("OutScales"); - T* out_scales_data = out_scales->mutable_data(context.GetPlace()); + auto* out_scale = context.Output("OutScale"); + T* out_scale_data = out_scale->mutable_data(context.GetPlace()); out->mutable_data(context.GetPlace()); int bit_length = context.Attr("bit_length"); int bin_cnt = std::pow(2, bit_length - 1) - 1; auto& dev_ctx = context.template device_context(); - auto find_abs_max = FindAbsMaxFunctor(); - for (int64_t i = 0; i < in->dims()[0]; i++) { - framework::Tensor one_channel = in->Slice(i, i + 1); - const T* one_channel_data = one_channel.data(); - find_abs_max(dev_ctx, one_channel_data, one_channel.numel(), - &out_scales_data[i]); - } - auto clip_quant = ClipAndFakeQuantFunctor(); - for (int64_t i = 0; i < in->dims()[0]; i++) { - framework::Tensor one_channel_in = in->Slice(i, i + 1); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - framework::Tensor one_channel_scale = out_scales->Slice(i, i + 1); - clip_quant(dev_ctx, one_channel_in, 
one_channel_scale, bin_cnt, - &one_channel_out); - } + FindChannelAbsMaxFunctor()( + dev_ctx, in->data(), in->numel(), in->dims()[0], out_scale_data); + ChannelClipAndFakeQuantFunctor()( + dev_ctx, *in, *out_scale, bin_cnt, in->dims()[0], out); } }; diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 919db4c78e52edc9a8be44744f4b7704e3f62de4..5dcef506711b78c2aef30d16719f8766359ae8f3 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -22,6 +22,7 @@ from ....framework import IrGraph from ....framework import IrNode from ....framework import Program from ....initializer import Constant +from ....initializer import NumpyArrayInitializer from .... import unique_name __all__ = [ @@ -54,14 +55,15 @@ class QuantizationTransformPass(object): the bias is not quantized. activation_bits (int): quantization bit number for activation. activation_quantize_type (str): quantization type for activation, - now support 'abs_max', 'range_abs_max'. If use 'abs_max' mode, - the quantization scale will be calculated dynamically each step - in both training and testing period. If use 'range_abs_max', - a static quantization scale will be calculated during training - and used in inference. + now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'. + If use 'abs_max' mode, the quantization scale will be calculated + dynamically each step in both training and testing period. If use + 'range_abs_max', a static quantization scale will be calculated + during training and used in inference. weight_quantize_type (str): quantization type for weights, - support 'abs_max'. The 'range_abs_max' usually is not used for - weight, since weights are fixed once the model is well trained. + support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max' + usually is not used for weight, since weights are fixed once the + model is well trained. window_size (int): the window size for 'range_abs_max' quantization. Examples: @@ -84,7 +86,11 @@ class QuantizationTransformPass(object): self._weight_bits = weight_bits self._activation_bits = activation_bits - quant_type = ['abs_max', 'range_abs_max', 'moving_average_abs_max'] + quant_type = [ + 'abs_max', 'channel_wise_abs_max', 'range_abs_max', + 'moving_average_abs_max' + ] + assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'." if activation_quantize_type not in quant_type: raise ValueError( "Unknown activation_quantize_type : '%s'. It can only be ", @@ -93,7 +99,7 @@ class QuantizationTransformPass(object): if weight_quantize_type not in quant_type: raise ValueError( "Unknown weight_quantize_type: '%s'. 
It can only be ", - "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.", + "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'.", str(weight_quantize_type)) self._activation_quantize_type = activation_quantize_type @@ -103,6 +109,7 @@ class QuantizationTransformPass(object): self._need_initialized = collections.OrderedDict() self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] + self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._quantizable_grad_ops = [ '%s_grad' % (op) for op in self._quantizable_ops ] @@ -135,10 +142,26 @@ class QuantizationTransformPass(object): else self._activation_bits quant_type = self._weight_quantize_type if var_node.name() \ in persistable_vars else self._activation_quantize_type - quant_var_node, scale_var_node = self._insert_quant_op( - graph, var_node, quant_bits, quant_type) - dequant_var_node = self._insert_dequant_op( - graph, quant_var_node, scale_var_node, quant_bits) + if quant_type == 'channel_wise_abs_max': + assert var_node.name( + ) in persistable_vars, "'channel_wise_abs_max' can only be applied on weights." + if op.name() in self._conv_ops: + quant_var_node, scale_var_node = self._insert_channel_quant_op( + graph, var_node, quant_bits) + dequant_var_node = self._insert_channel_dequant_op( + graph, quant_var_node, [scale_var_node], + [quant_bits]) + else: + quant_var_node, scale_var_node = self._insert_quant_op( + graph, var_node, quant_bits, 'abs_max') + dequant_var_node = self._insert_dequant_op( + graph, quant_var_node, scale_var_node, + quant_bits) + else: + quant_var_node, scale_var_node = self._insert_quant_op( + graph, var_node, quant_bits, quant_type) + dequant_var_node = self._insert_dequant_op( + graph, quant_var_node, scale_var_node, quant_bits) dequantized_vars[var_node.name()] = dequant_var_node graph.update_input_link(var_node, dequant_var_node, op) @@ -244,7 +267,7 @@ class QuantizationTransformPass(object): scale_var_node = graph.create_var_node( name=self._quantized_scale_name(var_node.name()), var_type=var_node.type(), - shape=var_node.shape(), + shape=[1], var_dtype=var_node.dtype()) quant_op_node = graph.create_op_node( op_type='fake_quantize_abs_max', @@ -384,6 +407,36 @@ class QuantizationTransformPass(object): return quant_var_node, scale_out_node + def _insert_channel_quant_op(self, graph, var_node, quant_bits): + """ + Insert fake_channel_wise_quantize_abs_max op in the graph. + """ + assert var_node.is_var(), '{} is not a var'.format(var_node.name()) + + quant_var_node = graph.create_var_node( + name=self._quantized_var_name(var_node.name()), + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=var_node.dtype()) + scale_var_node = graph.create_var_node( + name=self._quantized_scale_name(var_node.name()), + var_type=var_node.type(), + shape=[var_node.shape()[0]], + var_dtype=var_node.dtype()) + quant_op_node = graph.create_op_node( + op_type='fake_channel_wise_quantize_abs_max', + attrs={ + 'bit_length': quant_bits, + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + }, + inputs={'X': var_node}, + outputs={'Out': quant_var_node, + 'OutScale': scale_var_node}) + graph.link_to(var_node, quant_op_node) + graph.link_to(quant_op_node, quant_var_node) + graph.link_to(quant_op_node, scale_var_node) + return quant_var_node, scale_var_node + def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits): """ Insert fake_dequantize_op in the graph. 
@@ -410,6 +463,33 @@ class QuantizationTransformPass(object): graph.link_to(dequant_op_node, dequant_var_node) return dequant_var_node + def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes, + quant_bits): + """ + Insert fake_channel_wise_dequantize_max_abs in the graph. + """ + assert var_node.is_var(), '{} is not a var'.format(var_node.name()) + + dequant_var_node = graph.create_var_node( + name=self._dequantized_var_name(var_node.name()), + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=var_node.dtype()) + dequant_op_node = graph.create_op_node( + op_type='fake_channel_wise_dequantize_max_abs', + attrs={ + 'quant_bits': quant_bits, + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + }, + inputs={'X': var_node, + 'Scales': scale_var_nodes}, + outputs={'Out': dequant_var_node}) + graph.link_to(var_node, dequant_op_node) + for scale_n in scale_var_nodes: + graph.link_to(scale_n, dequant_op_node) + graph.link_to(dequant_op_node, dequant_var_node) + return dequant_var_node + def _quantized_var_name(self, var_name): """ Return quantized variable name for the input `var_name`. @@ -442,7 +522,7 @@ class QuantizationFreezePass(object): place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors. weight_bits (int): quantization bit number for weights. activation_bits (int): quantization bit number for activation. - weight_quantize_type (str): quantization type for weights, support 'abs_max'. + weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained. """ @@ -463,11 +543,15 @@ class QuantizationFreezePass(object): self._activation_bits = activation_bits self._weight_quantize_type = weight_quantize_type self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] + self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._fake_quant_op_names = [ 'fake_quantize_abs_max', 'fake_quantize_range_abs_max', - 'fake_quantize_moving_average_abs_max' + 'fake_quantize_moving_average_abs_max', + 'fake_channel_wise_quantize_abs_max' + ] + self._fake_dequant_op_names = [ + 'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs' ] - self._fake_dequant_op_names = ['fake_dequantize_max_abs'] self._op_input_rename_map = collections.OrderedDict() self._op_output_rename_map = collections.OrderedDict() self._var_scale_map = collections.OrderedDict() @@ -489,20 +573,27 @@ class QuantizationFreezePass(object): if self._weight_quantize_type == 'abs_max': param = self._load_var(input_arg_name) scale_v = np.max(np.abs(param)) + elif self._weight_quantize_type == 'channel_wise_abs_max': + param = self._load_var(input_arg_name) + if len(param.shape) == 4: # conv2d or depthwise_conv2d + scale_v = [] + for i in range(param.shape[0]): + scale_v.append(np.max(np.abs(param[i]))) + else: + scale_v = np.max(np.abs(param)) else: scale_v = self._load_var( op_node.output('OutScale')[0])[0] self._var_scale_map[input_arg_name] = scale_v - else: - scale_v = graph.var_node(op_node.output('OutScale')[0]) - self._var_scale_map[input_arg_name] = scale_v - if input_arg_name in persistable_vars: self._remove_fake_quant_and_dequant_op(graph, op_node) # quantize weight and restore param_v = self._load_var(input_arg_name) quantized_param_v = self._quant(param_v, scale_v, self._weight_bits) self._restore_var(input_arg_name, quantized_param_v) + else: + scale_v = graph.var_node(op_node.output('OutScale')[0]) + 
self._var_scale_map[input_arg_name] = scale_v ops = graph.all_op_nodes() for op_node in ops: @@ -514,7 +605,10 @@ class QuantizationFreezePass(object): for op_node in ops: op_name = op_node.name() if op_name in self._quantizable_ops: - self._insert_post_dequant_op(graph, op_node) + if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops: + self._insert_post_channel_dequant_op(graph, op_node) + else: + self._insert_post_dequant_op(graph, op_node) for op_node in ops: # insert dequant_op after fc/conv, need to rename inputs of the followed ops @@ -538,9 +632,73 @@ class QuantizationFreezePass(object): self._op_input_rename_map[k] = self._op_input_rename_map[v] graph.safe_remove_nodes(op_node) + def _insert_post_channel_dequant_op(self, graph, op_node): + persistable_vars = [p.name() for p in graph.all_persistable_nodes()] + for var_node in op_node.inputs: + name = var_node.name() + if name in self._op_input_rename_map: + old_in = graph.var_node(name) + new_in = graph.var_node(self._op_input_rename_map[name]) + new_in.clear_outputs() + graph.update_input_link(old_in, new_in, op_node) + original_var_name = self._original_var_name(name) + scale_v = self._var_scale_map[original_var_name] + if original_var_name in persistable_vars: + assert isinstance( + scale_v, + list), 'The scale of parameter %s is not a list.' % ( + original_var_name) + channel_scale = np.array(scale_v) + else: + assert isinstance(scale_v, IrNode) + scale_var_node = self._var_scale_map[original_var_name] + + if len(op_node.outputs) != 1: + raise ValueError("Only support one output, but op %s has" + " more than one output." % (op_node.name())) + + output_var_node = op_node.outputs[0] + weight_scale_node = graph.create_persistable_node( + name=unique_name.generate('channel_scale'), + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=[channel_scale.shape[0]], + var_dtype=output_var_node.dtype()) + init_program = Program() + weight_scale_var = init_program.global_block().create_var( + name=weight_scale_node.name(), + shape=weight_scale_node.shape(), + dtype=weight_scale_node.dtype(), + type=weight_scale_node.type(), + lod_level=weight_scale_node.var().lod_level(), + persistable=weight_scale_node.persistable()) + initializer = NumpyArrayInitializer(value=channel_scale) + initializer(weight_scale_var, init_program.global_block()) + exe = Executor(self._place) + exe.run(program=init_program, scope=self._scope) + dequant_var_node = graph.create_var_node( + name=self._dequantized_var_name(output_var_node.name()), + var_type=output_var_node.type(), + shape=output_var_node.shape(), + var_dtype=output_var_node.dtype()) + dequant_op_node = graph.create_op_node( + op_type='fake_channel_wise_dequantize_max_abs', + attrs={ + 'quant_bits': [self._weight_bits, self._activation_bits], + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + }, + inputs={ + 'X': output_var_node, + 'Scales': [weight_scale_node, scale_var_node] + }, + outputs={'Out': dequant_var_node}) + graph.link_to(output_var_node, dequant_op_node) + graph.link_to(scale_var_node, dequant_op_node) + graph.link_to(weight_scale_node, dequant_op_node) + graph.link_to(dequant_op_node, dequant_var_node) + self._op_output_rename_map[output_var_node.name()] = dequant_var_node + return dequant_var_node + def _insert_post_dequant_op(self, graph, op_node): - max_range = None - scale_var_node = None persistable_vars = [p.name() for p in graph.all_persistable_nodes()] for var_node in op_node.inputs: name = var_node.name() @@ -637,7 +795,12 @@ class 
QuantizationFreezePass(object): or isinstance(v, np.float64) def _quant(self, x, scale, num_bits): - return np.round(x / scale * ((1 << (num_bits - 1)) - 1)) + if isinstance(scale, list): + for i, s in enumerate(scale): + x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1)) + return x + else: + return np.round(x / scale * ((1 << (num_bits - 1)) - 1)) class ConvertToInt8Pass(object): @@ -731,9 +894,13 @@ class TransformForMobilePass(object): def __init__(self): self._fake_quant_op_names = [ - 'fake_quantize_abs_max', 'fake_quantize_range_abs_max' + 'fake_quantize_abs_max', 'fake_quantize_range_abs_max', + 'fake_quantize_moving_average_abs_max', + 'fake_channel_wise_quantize_abs_max' + ] + self._fake_dequant_op_names = [ + 'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs' ] - self._fake_dequant_op_names = ['fake_dequantize_max_abs'] def apply(self, graph): """ diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py index 0b4b2a285f5de2596b5d30c6b2a6213762a64e7a..c7feca0b82606cdba9a05fb6de821aa6d347d4e6 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py @@ -127,7 +127,7 @@ class TestQuantizationTransformPass(unittest.TestCase): arg_name.endswith('.quantized.dequantized')) self.assertTrue(arg_name in quantized_ops) - def linear_fc_quant(self, quant_type, for_ci=False): + def linear_fc_quant(self, activation_quant_type, for_ci=False): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -140,14 +140,15 @@ class TestQuantizationTransformPass(unittest.TestCase): transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), place=place, - activation_quantize_type=quant_type) + activation_quantize_type=activation_quant_type) transform_pass.apply(graph) if not for_ci: marked_nodes = set() for op in graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes) + graph.draw('.', 'quantize_fc_' + activation_quant_type, + marked_nodes) program = graph.to_program() self.check_program(transform_pass, program) val_graph = IrGraph(core.Graph(program.desc), for_test=False) @@ -156,7 +157,8 @@ class TestQuantizationTransformPass(unittest.TestCase): for op in val_graph.all_op_nodes(): if op.name().find('quantize') > -1: val_marked_nodes.add(op) - val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes) + val_graph.draw('.', 'val_fc_' + activation_quant_type, + val_marked_nodes) def test_linear_fc_quant_abs_max(self): self.linear_fc_quant('abs_max', for_ci=True) @@ -167,7 +169,7 @@ class TestQuantizationTransformPass(unittest.TestCase): def test_linear_fc_quant_moving_average_abs_max(self): self.linear_fc_quant('moving_average_abs_max', for_ci=True) - def residual_block_quant(self, quant_type, for_ci=False): + def residual_block_quant(self, activation_quant_type, for_ci=False): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -180,14 +182,15 @@ class TestQuantizationTransformPass(unittest.TestCase): transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), place=place, - activation_quantize_type=quant_type) + activation_quantize_type=activation_quant_type) transform_pass.apply(graph) if not for_ci: marked_nodes = set() for op in graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - 
graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes) + graph.draw('.', 'quantize_residual_' + activation_quant_type, + marked_nodes) program = graph.to_program() self.check_program(transform_pass, program) val_graph = IrGraph(core.Graph(program.desc), for_test=False) @@ -196,7 +199,8 @@ class TestQuantizationTransformPass(unittest.TestCase): for op in val_graph.all_op_nodes(): if op.name().find('quantize') > -1: val_marked_nodes.add(op) - val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes) + val_graph.draw('.', 'val_residual_' + activation_quant_type, + val_marked_nodes) def test_residual_block_abs_max(self): self.residual_block_quant('abs_max', for_ci=True) @@ -209,7 +213,12 @@ class TestQuantizationTransformPass(unittest.TestCase): class TestQuantizationFreezePass(unittest.TestCase): - def freeze_graph(self, use_cuda, seed, quant_type, for_ci=False): + def freeze_graph(self, + use_cuda, + seed, + activation_quant_type, + weight_quant_type='abs_max', + for_ci=False): def build_program(main, startup, is_test): main.random_seed = seed startup.random_seed = seed @@ -243,7 +252,12 @@ class TestQuantizationFreezePass(unittest.TestCase): with fluid.scope_guard(scope): exe.run(startup) transform_pass = QuantizationTransformPass( - scope=scope, place=place, activation_quantize_type=quant_type) + scope=scope, + place=place, + activation_quantize_type=activation_quant_type, + weight_quantize_type=weight_quant_type) + #transform_pass = QuantizationTransformPass( + # scope=scope, place=place, activation_quantize_type=activation_quant_type) transform_pass.apply(main_graph) transform_pass.apply(test_graph) dev_name = '_gpu_' if use_cuda else '_cpu_' @@ -252,12 +266,14 @@ class TestQuantizationFreezePass(unittest.TestCase): for op in main_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes) + main_graph.draw('.', 'main' + dev_name + activation_quant_type + '_' + + weight_quant_type, marked_nodes) marked_nodes = set() for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes) + test_graph.draw('.', 'test' + dev_name + activation_quant_type + '_' + + weight_quant_type, marked_nodes) build_strategy = fluid.BuildStrategy() build_strategy.memory_optimize = False @@ -282,8 +298,9 @@ class TestQuantizationFreezePass(unittest.TestCase): feed=feeder.feed(data), fetch_list=[loss]) if not for_ci: - print('{}: {}'.format('loss' + dev_name + quant_type, - loss_v)) + print('{}: {}'.format('loss' + dev_name + + activation_quant_type + '_' + + weight_quant_type, loss_v)) test_data = next(test_reader()) with fluid.program_guard(quantized_test_program): @@ -296,14 +313,17 @@ class TestQuantizationFreezePass(unittest.TestCase): fetch_list=[loss, w_var]) # Freeze graph for inference, but the weight of fc/conv is still float type. 
- freeze_pass = QuantizationFreezePass(scope=scope, place=place) + freeze_pass = QuantizationFreezePass( + scope=scope, place=place, weight_quantize_type=weight_quant_type) + #freeze_pass = QuantizationFreezePass(scope=scope, place=place) freeze_pass.apply(test_graph) if not for_ci: marked_nodes = set() for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - test_graph.draw('.', 'test_freeze' + dev_name + quant_type, + test_graph.draw('.', 'test_freeze' + dev_name + + activation_quant_type + '_' + weight_quant_type, marked_nodes) server_program = test_graph.to_program() @@ -313,18 +333,20 @@ class TestQuantizationFreezePass(unittest.TestCase): fetch_list=[loss]) self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3) if not for_ci: - print('{}: {}'.format('test_loss1' + dev_name + quant_type, - test_loss1)) - print('{}: {}'.format('test_loss2' + dev_name + quant_type, - test_loss2)) + print( + '{}: {}'.format('test_loss1' + dev_name + activation_quant_type + + '_' + weight_quant_type, test_loss1)) + print( + '{}: {}'.format('test_loss2' + dev_name + activation_quant_type + + '_' + weight_quant_type, test_loss2)) w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor()) # Maybe failed, this is due to the calculation precision # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant)) if not for_ci: - print('{}: {}'.format('w_freeze' + dev_name + quant_type, - np.sum(w_freeze))) - print('{}: {}'.format('w_quant' + dev_name + quant_type, - np.sum(w_quant))) + print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type + + '_' + weight_quant_type, np.sum(w_freeze))) + print('{}: {}'.format('w_quant' + dev_name + activation_quant_type + + '_' + weight_quant_type, np.sum(w_quant))) # Convert parameter to 8-bit. convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place) @@ -334,26 +356,28 @@ class TestQuantizationFreezePass(unittest.TestCase): for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - test_graph.draw('.', 'test_int8' + dev_name + quant_type, - marked_nodes) + test_graph.draw('.', 'test_int8' + dev_name + activation_quant_type + + '_' + weight_quant_type, marked_nodes) server_program_int8 = test_graph.to_program() # Save the 8-bit parameter and model file. with fluid.scope_guard(scope): - fluid.io.save_inference_model('server_int8' + dev_name + quant_type, - ['image', 'label'], [loss], exe, - server_program_int8) + fluid.io.save_inference_model( + 'server_int8' + dev_name + activation_quant_type + '_' + + weight_quant_type, ['image', 'label'], [loss], exe, + server_program_int8) # Test whether the 8-bit parameter and model file can be loaded successfully. [infer, feed, fetch] = fluid.io.load_inference_model( - 'server_int8' + dev_name + quant_type, exe) + 'server_int8' + dev_name + activation_quant_type + '_' + + weight_quant_type, exe) # Check the loaded 8-bit weight. 
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor()) self.assertEqual(w_8bit.dtype, np.int8) self.assertEqual(np.sum(w_8bit), np.sum(w_freeze)) if not for_ci: - print('{}: {}'.format('w_8bit' + dev_name + quant_type, - np.sum(w_8bit))) - print('{}: {}'.format('w_freeze' + dev_name + quant_type, - np.sum(w_freeze))) + print('{}: {}'.format('w_8bit' + dev_name + activation_quant_type + + '_' + weight_quant_type, np.sum(w_8bit))) + print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type + + '_' + weight_quant_type, np.sum(w_freeze))) mobile_pass = TransformForMobilePass() mobile_pass.apply(test_graph) @@ -362,42 +386,103 @@ class TestQuantizationFreezePass(unittest.TestCase): for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - test_graph.draw('.', 'test_mobile' + dev_name + quant_type, + test_graph.draw('.', 'test_mobile' + dev_name + + activation_quant_type + '_' + weight_quant_type, marked_nodes) mobile_program = test_graph.to_program() with fluid.scope_guard(scope): - fluid.io.save_inference_model('mobile_int8' + dev_name + quant_type, - ['image', 'label'], [loss], exe, - mobile_program) + fluid.io.save_inference_model( + 'mobile_int8' + dev_name + activation_quant_type + '_' + + weight_quant_type, ['image', 'label'], [loss], exe, + mobile_program) def test_freeze_graph_cuda_dynamic(self): if fluid.core.is_compiled_with_cuda(): with fluid.unique_name.guard(): self.freeze_graph( - True, seed=1, quant_type='abs_max', for_ci=True) + True, + seed=1, + activation_quant_type='abs_max', + weight_quant_type='abs_max', + for_ci=True) + with fluid.unique_name.guard(): + self.freeze_graph( + True, + seed=1, + activation_quant_type='abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) def test_freeze_graph_cpu_dynamic(self): with fluid.unique_name.guard(): - self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=True) + self.freeze_graph( + False, + seed=2, + activation_quant_type='abs_max', + weight_quant_type='abs_max', + for_ci=True) + self.freeze_graph( + False, + seed=2, + activation_quant_type='abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) def test_freeze_graph_cuda_static(self): if fluid.core.is_compiled_with_cuda(): with fluid.unique_name.guard(): self.freeze_graph( - True, seed=1, quant_type='range_abs_max', for_ci=True) + True, + seed=1, + activation_quant_type='range_abs_max', + weight_quant_type='abs_max', + for_ci=True) + self.freeze_graph( + True, + seed=1, + activation_quant_type='moving_average_abs_max', + weight_quant_type='abs_max', + for_ci=True) self.freeze_graph( True, seed=1, - quant_type='moving_average_abs_max', + activation_quant_type='range_abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) + self.freeze_graph( + True, + seed=1, + activation_quant_type='moving_average_abs_max', + weight_quant_type='channel_wise_abs_max', for_ci=True) def test_freeze_graph_cpu_static(self): with fluid.unique_name.guard(): self.freeze_graph( - False, seed=2, quant_type='range_abs_max', for_ci=True) + False, + seed=2, + activation_quant_type='range_abs_max', + weight_quant_type='abs_max', + for_ci=True) + self.freeze_graph( + False, + seed=2, + activation_quant_type='moving_average_abs_max', + weight_quant_type='abs_max', + for_ci=True) + self.freeze_graph( + False, + seed=2, + activation_quant_type='range_abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) self.freeze_graph( - False, seed=2, quant_type='moving_average_abs_max', for_ci=True) + 
False, + seed=2, + activation_quant_type='moving_average_abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py index 32cb23cbfa9bdef4728e85d0014123652e4aefea..0812b02b47db7fa2d43e1d3bbd0a3f7b59911326 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py @@ -31,15 +31,27 @@ def dequantize_max_abs(x, scale, max_range): return y -def channel_wise_quantize_max_abs(x, quant_bit=8): +def channel_wise_quantize_max_abs(x, quant_bit=8, use_second_dim=False): scales = [] - for i in range(x.shape[0]): - scales.append(np.max(np.abs(x[i])).astype("float32")) - - y = x.copy() - max_range = math.pow(2, quant_bit - 1) - 1 - for i, scale in enumerate(scales): - y[i] = np.round(y[i] / scale * max_range) + if not use_second_dim: + for i in range(x.shape[0]): + scales.append(np.max(np.abs(x[i])).astype("float32")) + y = x.copy() + max_range = math.pow(2, quant_bit - 1) - 1 + for i, scale in enumerate(scales): + y[i] = np.round(x[i] / scale * max_range) + else: + for i in range(x.shape[0]): + s = [] + for j in range(x.shape[1]): + s.append(np.max(np.abs(x[i][j])).astype("float32")) + scales.append(s) + scales = np.amax(np.array(scales), axis=0) + y = x.copy() + max_range = math.pow(2, quant_bit - 1) - 1 + for i in range(x.shape[0]): + for j, scale in enumerate(scales): + y[i][j] = np.round(x[i][j] / scale * max_range) return y, scales @@ -47,10 +59,16 @@ def channel_wise_dequantize_max_abs(x, scales, quant_bits, activation_scale=None): - y = x.copy() - for i in range(x.shape[0]): - y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * y[i] - if activation_scale is not None: + if activation_scale is None: + y = x.copy() + for i in range(x.shape[0]): + y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * x[i] + else: + y = x.copy() + for i in range(x.shape[0]): + for j in range(x.shape[1]): + y[i][j] = (scales[j] / + (math.pow(2, quant_bits[0] - 1) - 1)) * x[i][j] y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1) return y @@ -65,7 +83,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest): self.set_args() self.op_type = "fake_channel_wise_dequantize_max_abs" x = np.random.randn(4, 3, 64, 64).astype(self.data_type) - yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0]) + yq, scales = channel_wise_quantize_max_abs( + x, self.quant_bits[0], use_second_dim=True) ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, self.activation_scale) diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index cf8f01edb9a6a2b6d91080248553491c54e7707b..07038b0441d0dc37a42cbf2058c1b5f41b47a5da 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -53,7 +53,7 @@ class TestFakeChannelWiseQuantizeOp(OpTest): self.outputs = { 'Out': outputs, - 'OutScales': np.array(scales).astype("float32"), + 'OutScale': np.array(scales).astype("float32"), } def test_check_output(self):
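
Reviewer note, not part of the patch: below is a minimal NumPy sketch of the arithmetic the new fake_channel_wise_quantize_abs_max / fake_channel_wise_dequantize_max_abs kernels implement, assuming a conv weight laid out as [Cout, Cin, kh, kw] and the one-scale dequantize branch. The helper names are hypothetical and exist only to illustrate the per-channel abs-max scale and the bin_cnt / max_range scaling used in the diff; the two-scale branch additionally multiplies by scale_two[0] and divides by (2^(qb0-1)-1)*(2^(qb1-1)-1), as the test_fake_dequantize_op.py reference shows.

import numpy as np

def channel_wise_quantize(w, bits=8):
    # Per-output-channel abs-max quantization (axis 0), mirroring
    # fake_channel_wise_quantize_abs_max: scale[i] = max(|w[i]|),
    # out[i] = round(w[i] / scale[i] * (2^(bits-1) - 1)).
    bin_cnt = (1 << (bits - 1)) - 1
    scales = np.abs(w).reshape(w.shape[0], -1).max(axis=1).astype("float32")
    q = np.round(w / scales.reshape(-1, 1, 1, 1) * bin_cnt)
    return q, scales

def channel_wise_dequantize(q, scales, bits=8):
    # One-scale branch of fake_channel_wise_dequantize_max_abs:
    # out[i] = q[i] * scale[i] / max_range, with max_range = 2^(bits-1) - 1.
    max_range = (1 << (bits - 1)) - 1
    return q * scales.reshape(-1, 1, 1, 1) / max_range

if __name__ == "__main__":
    w = np.random.randn(8, 3, 3, 3).astype("float32")  # conv2d weight [Cout, Cin, kh, kw]
    q, scales = channel_wise_quantize(w)
    w_dq = channel_wise_dequantize(q, scales)
    # The round trip should agree with the original within one quantization
    # step per channel (scale / bin_cnt), which is what the freeze-pass test
    # tolerances rely on.
    step = scales / ((1 << 7) - 1)
    assert np.all(np.abs(w_dq - w) <= step.reshape(-1, 1, 1, 1) + 1e-6)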