From c71025eb4592eeac54d0e1c99cbf47bb31d3c92a Mon Sep 17 00:00:00 2001 From: XGZhang <46363693+XGZhang11@users.noreply.github.com> Date: Thu, 26 Aug 2021 10:30:45 +0800 Subject: [PATCH] fix the bug of channel-wise quantization for ernie (#34948) --- .../ir/quant_conv2d_dequant_fuse_pass.cc | 4 + ...fake_channel_wise_dequantize_max_abs.pbtxt | 4 + paddle/fluid/operators/fake_dequantize_op.cc | 81 ++++++++++++++----- paddle/fluid/operators/fake_dequantize_op.cu | 17 ++-- paddle/fluid/operators/fake_dequantize_op.h | 12 +-- .../slim/quantization/quantization_pass.py | 7 +- 6 files changed, 92 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 354db8acf87..5958728946c 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -115,6 +115,10 @@ QuantDequantFusePass::QuantDequantFusePass() { .AddAttr("quant_axis") .IsIntIn({0, 1}) .IsOptional() + .End() + .AddAttr("x_num_col_dims") + .IsType() + .IsOptional() .End(); AddOpCompat(OpCompat("conv2d")) .AddInput("Input") diff --git a/paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt b/paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt index ec80ffaaf32..c32c170ce65 100644 --- a/paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt +++ b/paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt @@ -17,4 +17,8 @@ def { name: "quant_axis" type: INT } + attrs { + name: "x_num_col_dims" + type: INT + } } diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc index b70fe78e1a5..14ae6beb4e4 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cc +++ b/paddle/fluid/operators/fake_dequantize_op.cc @@ -39,7 +39,7 @@ struct ChannelDequantizeFunctor { void operator()(const platform::CPUDeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor** scales, const int scale_num, T max_range, const int quant_axis, - framework::Tensor* out) { + const int x_num_col_dims, framework::Tensor* out) { if (scale_num == 1) { // Dequant op is before quantized op // Dequantize the weight of quantized op @@ -81,23 +81,50 @@ struct ChannelDequantizeFunctor { } else if (scale_num == 2) { // Dequant op is after quantized op // Dequantize the output tensor of quantized op - int batch_size = in->dims()[0]; - int channel = in->dims()[1]; - const T* scale_one = scales[0]->data(); - const T* scale_two = scales[1]->data(); - for (int i = 0; i < batch_size; i++) { - framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize( - framework::slice_ddim(in->dims(), 1, in->dims().size())); - framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize( - framework::slice_ddim(out->dims(), 1, out->dims().size())); - for (int j = 0; j < channel; j++) { - T s = scale_one[j]; - framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1); - framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1); - auto in_e = framework::EigenVector::Flatten(one_channel_in); - auto out_e = framework::EigenVector::Flatten(one_channel_out); - auto& dev = *dev_ctx.eigen_device(); - out_e.device(dev) = in_e * s * scale_two[0] / max_range; + if (x_num_col_dims > 1) { + auto in_dims = in->dims(); + const int64_t channel = in_dims[x_num_col_dims]; + const T* scale_one = scales[0]->data(); + const T* scale_two = scales[1]->data(); + int64_t out_iter = 1; + for (int i = 0; i < x_num_col_dims; i++) { + out_iter *= in_dims[i]; + } + int64_t step_i = in->numel() / out_iter; + int64_t step_j = in->numel() / (out_iter * channel); + auto* in_data = in->data(); + auto* out_data = out->mutable_data(dev_ctx.GetPlace()); + for (int64_t i = 0; i < out_iter; i++) { + for (int64_t j = 0; j < channel; j++) { + auto* cur_in = in_data + i * step_i + j * step_j; + auto* cur_out = out_data + i * step_i + j * step_j; + T s = scale_one[j]; + for (int64_t k = 0; k < step_j; k++) { + *cur_out = (*cur_in) * s * scale_two[0] / max_range; + ++cur_in; + ++cur_out; + } + } + } + } else { + int batch_size = in->dims()[0]; + int channel = in->dims()[1]; + const T* scale_one = scales[0]->data(); + const T* scale_two = scales[1]->data(); + for (int i = 0; i < batch_size; i++) { + framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize( + framework::slice_ddim(in->dims(), 1, in->dims().size())); + framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize( + framework::slice_ddim(out->dims(), 1, out->dims().size())); + for (int j = 0; j < channel; j++) { + T s = scale_one[j]; + framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1); + framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1); + auto in_e = framework::EigenVector::Flatten(one_channel_in); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = in_e * s * scale_two[0] / max_range; + } } } } @@ -199,7 +226,16 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker "the received is %d", quant_axis)); }); - + AddAttr("x_num_col_dims", + "The x_num_col_dims of mul. Only used for mul or matmul.") + .SetDefault(1) + .AddCustomChecker([](const int& x_num_col_dims) { + PADDLE_ENFORCE_EQ(x_num_col_dims == 0, false, + platform::errors::InvalidArgument( + "'x_num_col_dims' should be larger than 0, but " + "the received is %d", + x_num_col_dims)); + }); AddComment(R"DOC( FakeChannelWiseDequantizeMaxAbsOp operator. @@ -245,4 +281,9 @@ REGISTER_OP_VERSION(fake_channel_wise_dequantize_max_abs) R"ROC(add new attributes [quant_axis] for applying per-channel " "dequantization to conv2d_tranpose and mul ops.)ROC", paddle::framework::compatible::OpVersionDesc().NewAttr( - "quant_axis", "The axis for dequantization.", 0)); + "quant_axis", "The axis for dequantization.", 0)) + .AddCheckpoint( + R"ROC(add new attributes [x_num_col_dims] for applying per-channel " + "dequantization to mul ops.)ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "x_num_col_dims", "The x_num_col_dims for dequantization.", 1)); diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index a89c430c7ab..c88a8fe196e 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -77,9 +77,9 @@ __global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale, template __global__ void DequantizeTwoScale(const T* in, const T* scale_one, const T* scale_two, T max_range, int num, - int batch_size, int channel, T* out) { + int iter_size, int channel, T* out) { int tid = threadIdx.x; - int channel_size = num / (batch_size * channel); + int channel_size = num / (iter_size * channel); int scale_index = blockIdx.x % channel; const T* in_c = in + blockIdx.x * channel_size; T* out_c = out + blockIdx.x * channel_size; @@ -93,7 +93,7 @@ struct ChannelDequantizeFunctor { void operator()(const platform::CUDADeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor** scales, const int scale_num, T max_range, const int quant_axis, - framework::Tensor* out) { + const int x_num_col_dims, framework::Tensor* out) { auto in_dims = in->dims(); const T* in_data = in->data(); T* out_data = out->mutable_data(dev_ctx.GetPlace()); @@ -116,14 +116,17 @@ struct ChannelDequantizeFunctor { } else if (scale_num == 2) { // Not need to consider quant_axis int num = in->numel(); - int batch_size = in->dims()[0]; - int channel = in->dims()[1]; + int iter_size = 1; + for (int i = 0; i < x_num_col_dims; i++) { + iter_size *= in->dims()[i]; + } + int channel = in->dims()[x_num_col_dims]; const T* scale_one = scales[0]->data(); const T* scale_two = scales[1]->data(); int block = 1024; - int grid = batch_size * channel; + int grid = iter_size * channel; DequantizeTwoScale<<>>( - in_data, scale_one, scale_two, max_range, num, batch_size, channel, + in_data, scale_one, scale_two, max_range, num, iter_size, channel, out_data); } } diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h index 6ddb12771fd..4485edcafba 100644 --- a/paddle/fluid/operators/fake_dequantize_op.h +++ b/paddle/fluid/operators/fake_dequantize_op.h @@ -33,7 +33,8 @@ template struct ChannelDequantizeFunctor { void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor** scales, const int scale_num, - T max_range, const int quant_axis, framework::Tensor* out); + T max_range, const int quant_axis, const int x_num_col_dims, + framework::Tensor* out); }; template @@ -64,6 +65,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { auto quant_bits = ctx.Attr>("quant_bits"); auto quant_axis = ctx.Attr("quant_axis"); + auto x_num_col_dims = ctx.Attr("x_num_col_dims"); int max_range = 1; auto& dev_ctx = ctx.template device_context(); @@ -80,11 +82,11 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { max_range *= (std::pow(2, quant_bits[0] - 1) - 1); } else if (scale_num == 2) { PADDLE_ENFORCE_EQ( - scales[0]->numel(), in->dims()[1], + scales[0]->numel(), in->dims()[x_num_col_dims], platform::errors::PreconditionNotMet( "The number of first scale values must be the same with " - "second dimension value of Input(X) when the `Scales` has two " - "elements, but %ld != %ld here.", + "corresponding dimension value of Input(X) when the `Scales` " + "has two elements, but %ld != %ld here.", scales[0]->numel(), in->dims()[1])); PADDLE_ENFORCE_EQ(scales[1]->numel(), 1, platform::errors::PreconditionNotMet( @@ -96,7 +98,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { } ChannelDequantizeFunctor()( dev_ctx, in, scales.data(), scale_num, static_cast(max_range), - quant_axis, out); + quant_axis, x_num_col_dims, out); } }; diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 9917730daa5..c2d7a9bb4d5 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1273,12 +1273,17 @@ class QuantizationFreezePass(object): var_type=output_var_node.type(), shape=output_var_node.shape(), var_dtype=output_var_node.dtype()) + if op_node.op().has_attr("x_num_col_dims"): + x_num_col_dims = op_node.op().attr("x_num_col_dims") + else: + x_num_col_dims = 1 dequant_op_node = graph.create_op_node( op_type='fake_channel_wise_dequantize_max_abs', attrs={ 'quant_bits': [self._weight_bits, self._activation_bits], 'quant_axis': quant_axis, - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward, + 'x_num_col_dims': x_num_col_dims }, inputs={ 'X': output_var_node, -- GitLab