Unverified commit c71025eb, authored by XGZhang and committed by GitHub

fix the bug of channel-wise quantization for ernie (#34948)

Parent 0efda9d9
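This commit generalizes fake_channel_wise_dequantize_max_abs: with two scales, the functor used to hard-code the channel axis as dims[1], which is only correct for mul ops whose x_num_col_dims is 1. ERNIE quantizes mul/matmul ops that keep several leading dimensions, so the op now carries an x_num_col_dims attribute (default 1, preserving old behavior) and applies the per-channel scales along dims[x_num_col_dims]. A minimal standalone C++ sketch of the generalized indexing; the function name and raw-pointer interface are illustrative, not the framework types used in the diff below:

// Hypothetical sketch (not from this commit): dequantize a tensor whose
// channel axis is dims[x_num_col_dims], with one scale per channel plus a
// second whole-tensor scale, mirroring the CPU functor changed below.
#include <cstdint>
#include <vector>

void channel_dequant(const float* in, float* out, int64_t numel,
                     const std::vector<int64_t>& dims, int x_num_col_dims,
                     const float* scale_one, float scale_two, float max_range) {
  int64_t channel = dims[x_num_col_dims];
  int64_t out_iter = 1;  // product of the dims before the channel axis
  for (int i = 0; i < x_num_col_dims; ++i) out_iter *= dims[i];
  const int64_t step_i = numel / out_iter;              // one outer slice
  const int64_t step_j = numel / (out_iter * channel);  // one channel run
  for (int64_t i = 0; i < out_iter; ++i)
    for (int64_t j = 0; j < channel; ++j)
      for (int64_t k = 0; k < step_j; ++k) {
        const int64_t idx = i * step_i + j * step_j + k;
        out[idx] = in[idx] * scale_one[j] * scale_two / max_range;
      }
}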
@@ -115,6 +115,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
       .AddAttr("quant_axis")
       .IsIntIn({0, 1})
       .IsOptional()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsType<int>()
+      .IsOptional()
       .End();
   AddOpCompat(OpCompat("conv2d"))
       .AddInput("Input")
...
@@ -17,4 +17,8 @@ def {
     name: "quant_axis"
     type: INT
   }
+  attrs {
+    name: "x_num_col_dims"
+    type: INT
+  }
 }
@@ -39,7 +39,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& dev_ctx,
                   const framework::Tensor* in, const framework::Tensor** scales,
                   const int scale_num, T max_range, const int quant_axis,
-                  framework::Tensor* out) {
+                  const int x_num_col_dims, framework::Tensor* out) {
     if (scale_num == 1) {
       // Dequant op is before quantized op
       // Dequantize the weight of quantized op
@@ -81,23 +81,50 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
     } else if (scale_num == 2) {
       // Dequant op is after quantized op
       // Dequantize the output tensor of quantized op
-      int batch_size = in->dims()[0];
-      int channel = in->dims()[1];
-      const T* scale_one = scales[0]->data<T>();
-      const T* scale_two = scales[1]->data<T>();
-      for (int i = 0; i < batch_size; i++) {
-        framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
-            framework::slice_ddim(in->dims(), 1, in->dims().size()));
-        framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
-            framework::slice_ddim(out->dims(), 1, out->dims().size()));
-        for (int j = 0; j < channel; j++) {
-          T s = scale_one[j];
-          framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
-          framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
-          auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
-          auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
-          auto& dev = *dev_ctx.eigen_device();
-          out_e.device(dev) = in_e * s * scale_two[0] / max_range;
-        }
-      }
+      if (x_num_col_dims > 1) {
+        auto in_dims = in->dims();
+        const int64_t channel = in_dims[x_num_col_dims];
+        const T* scale_one = scales[0]->data<T>();
+        const T* scale_two = scales[1]->data<T>();
+        int64_t out_iter = 1;
+        for (int i = 0; i < x_num_col_dims; i++) {
+          out_iter *= in_dims[i];
+        }
+        int64_t step_i = in->numel() / out_iter;
+        int64_t step_j = in->numel() / (out_iter * channel);
+        auto* in_data = in->data<T>();
+        auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
+        for (int64_t i = 0; i < out_iter; i++) {
+          for (int64_t j = 0; j < channel; j++) {
+            auto* cur_in = in_data + i * step_i + j * step_j;
+            auto* cur_out = out_data + i * step_i + j * step_j;
+            T s = scale_one[j];
+            for (int64_t k = 0; k < step_j; k++) {
+              *cur_out = (*cur_in) * s * scale_two[0] / max_range;
+              ++cur_in;
+              ++cur_out;
+            }
+          }
+        }
+      } else {
+        int batch_size = in->dims()[0];
+        int channel = in->dims()[1];
+        const T* scale_one = scales[0]->data<T>();
+        const T* scale_two = scales[1]->data<T>();
+        for (int i = 0; i < batch_size; i++) {
+          framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
+              framework::slice_ddim(in->dims(), 1, in->dims().size()));
+          framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
+              framework::slice_ddim(out->dims(), 1, out->dims().size()));
+          for (int j = 0; j < channel; j++) {
+            T s = scale_one[j];
+            framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
+            framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
+            auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
+            auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
+            auto& dev = *dev_ctx.eigen_device();
+            out_e.device(dev) = in_e * s * scale_two[0] / max_range;
+          }
+        }
+      }
     }
   }
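For concreteness, the strides the new branch computes for a hypothetical mul output of shape {8, 128, 768} dequantized with x_num_col_dims = 2; the shape and numbers are illustrative, not from this commit:

// Worked example (hypothetical shape) of the stride arithmetic above.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t dims[] = {8, 128, 768};
  const int x_num_col_dims = 2;
  const int64_t numel = 8 * 128 * 768;
  const int64_t channel = dims[x_num_col_dims];         // 768 scales expected
  const int64_t out_iter = dims[0] * dims[1];           // 1024 outer slices
  const int64_t step_i = numel / out_iter;              // 768 elements/slice
  const int64_t step_j = numel / (out_iter * channel);  // 1 element/channel
  assert(step_i == 768 && step_j == 1);
  // Element (i, j) is scaled by scale_one[j]: every position along the last
  // axis gets its own scale, whereas the old code indexed dims[1] (128),
  // the wrong axis for this op.
  return 0;
}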
@@ -199,7 +226,16 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker
                               "the received is %d",
                               quant_axis));
         });
+    AddAttr<int>("x_num_col_dims",
+                 "The x_num_col_dims of mul. Only used for mul or matmul.")
+        .SetDefault(1)
+        .AddCustomChecker([](const int& x_num_col_dims) {
+          PADDLE_ENFORCE_EQ(x_num_col_dims == 0, false,
+                            platform::errors::InvalidArgument(
+                                "'x_num_col_dims' should be larger than 0, but "
+                                "the received is %d",
+                                x_num_col_dims));
+        });
     AddComment(R"DOC(
 FakeChannelWiseDequantizeMaxAbsOp operator.
@@ -245,4 +281,9 @@ REGISTER_OP_VERSION(fake_channel_wise_dequantize_max_abs)
         R"ROC(add new attributes [quant_axis] for applying per-channel "
         "dequantization to conv2d_tranpose and mul ops.)ROC",
         paddle::framework::compatible::OpVersionDesc().NewAttr(
-            "quant_axis", "The axis for dequantization.", 0));
+            "quant_axis", "The axis for dequantization.", 0))
+    .AddCheckpoint(
+        R"ROC(add new attributes [x_num_col_dims] for applying per-channel "
+        "dequantization to mul ops.)ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "x_num_col_dims", "The x_num_col_dims for dequantization.", 1));
@@ -77,9 +77,9 @@ __global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
 template <typename T>
 __global__ void DequantizeTwoScale(const T* in, const T* scale_one,
                                    const T* scale_two, T max_range, int num,
-                                   int batch_size, int channel, T* out) {
+                                   int iter_size, int channel, T* out) {
   int tid = threadIdx.x;
-  int channel_size = num / (batch_size * channel);
+  int channel_size = num / (iter_size * channel);
   int scale_index = blockIdx.x % channel;
   const T* in_c = in + blockIdx.x * channel_size;
   T* out_c = out + blockIdx.x * channel_size;
@@ -93,7 +93,7 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& dev_ctx,
                   const framework::Tensor* in, const framework::Tensor** scales,
                   const int scale_num, T max_range, const int quant_axis,
-                  framework::Tensor* out) {
+                  const int x_num_col_dims, framework::Tensor* out) {
     auto in_dims = in->dims();
     const T* in_data = in->data<T>();
     T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
@@ -116,14 +116,17 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
     } else if (scale_num == 2) {
       // Not need to consider quant_axis
       int num = in->numel();
-      int batch_size = in->dims()[0];
-      int channel = in->dims()[1];
+      int iter_size = 1;
+      for (int i = 0; i < x_num_col_dims; i++) {
+        iter_size *= in->dims()[i];
+      }
+      int channel = in->dims()[x_num_col_dims];
       const T* scale_one = scales[0]->data<T>();
       const T* scale_two = scales[1]->data<T>();
       int block = 1024;
-      int grid = batch_size * channel;
+      int grid = iter_size * channel;
       DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
-          in_data, scale_one, scale_two, max_range, num, batch_size, channel,
+          in_data, scale_one, scale_two, max_range, num, iter_size, channel,
           out_data);
     }
   }
...
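The launch configuration above assigns one block per (outer, channel) pair, so each block owns one contiguous run of channel_size elements that all share one scale. A host-side C++ sketch of that mapping, illustrative only and not the kernel itself:

// Sequential re-statement of the DequantizeTwoScale block mapping:
// grid = iter_size * channel blocks; block b covers the b-th contiguous run
// of channel_size elements, all scaled by scale_one[b % channel].
#include <cstdint>

void dequant_two_scale_cpu(const float* in, const float* scale_one,
                           float scale_two, float max_range, int64_t num,
                           int64_t iter_size, int64_t channel, float* out) {
  const int64_t channel_size = num / (iter_size * channel);
  for (int64_t b = 0; b < iter_size * channel; ++b) {  // one CUDA block each
    const float s = scale_one[b % channel];            // scale_index
    const float* in_c = in + b * channel_size;
    float* out_c = out + b * channel_size;
    for (int64_t k = 0; k < channel_size; ++k)  // the threads' strided loop
      out_c[k] = in_c[k] * s * scale_two / max_range;
  }
}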
@@ -33,7 +33,8 @@ template <typename DeviceContext, typename T>
 struct ChannelDequantizeFunctor {
   void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
                   const framework::Tensor** scales, const int scale_num,
-                  T max_range, const int quant_axis, framework::Tensor* out);
+                  T max_range, const int quant_axis, const int x_num_col_dims,
+                  framework::Tensor* out);
 };

 template <typename DeviceContext, typename T>
@@ -64,6 +65,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
     auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
     auto quant_axis = ctx.Attr<int>("quant_axis");
+    auto x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
     int max_range = 1;

     auto& dev_ctx = ctx.template device_context<DeviceContext>();
@@ -80,11 +82,11 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
       max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
     } else if (scale_num == 2) {
       PADDLE_ENFORCE_EQ(
-          scales[0]->numel(), in->dims()[1],
+          scales[0]->numel(), in->dims()[x_num_col_dims],
           platform::errors::PreconditionNotMet(
               "The number of first scale values must be the same with "
-              "second dimension value of Input(X) when the `Scales` has two "
-              "elements, but %ld != %ld here.",
+              "corresponding dimension value of Input(X) when the `Scales` "
+              "has two elements, but %ld != %ld here.",
               scales[0]->numel(), in->dims()[1]));
       PADDLE_ENFORCE_EQ(scales[1]->numel(), 1,
                         platform::errors::PreconditionNotMet(
@@ -96,7 +98,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
     }
     ChannelDequantizeFunctor<DeviceContext, T>()(
         dev_ctx, in, scales.data(), scale_num, static_cast<T>(max_range),
-        quant_axis, out);
+        quant_axis, x_num_col_dims, out);
   }
 };
...
@@ -1273,12 +1273,17 @@ class QuantizationFreezePass(object):
                 var_type=output_var_node.type(),
                 shape=output_var_node.shape(),
                 var_dtype=output_var_node.dtype())
+            if op_node.op().has_attr("x_num_col_dims"):
+                x_num_col_dims = op_node.op().attr("x_num_col_dims")
+            else:
+                x_num_col_dims = 1
             dequant_op_node = graph.create_op_node(
                 op_type='fake_channel_wise_dequantize_max_abs',
                 attrs={
                     'quant_bits': [self._weight_bits, self._activation_bits],
                     'quant_axis': quant_axis,
-                    'op_role': core.op_proto_and_checker_maker.OpRole.Forward
+                    'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
+                    'x_num_col_dims': x_num_col_dims
                 },
                 inputs={
                     'X': output_var_node,
...