Unverified commit c71025eb authored by XGZhang, committed by GitHub

fix the bug of channel-wise quantization for ernie (#34948)

Parent 0efda9d9
......@@ -115,6 +115,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
.AddAttr("quant_axis")
.IsIntIn({0, 1})
.IsOptional()
.End()
.AddAttr("x_num_col_dims")
.IsType<int>()
.IsOptional()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
......
......@@ -17,4 +17,8 @@ def {
name: "quant_axis"
type: INT
}
attrs {
name: "x_num_col_dims"
type: INT
}
}
......@@ -39,7 +39,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, const int quant_axis,
framework::Tensor* out) {
const int x_num_col_dims, framework::Tensor* out) {
if (scale_num == 1) {
// Dequant op is before quantized op
// Dequantize the weight of quantized op
......@@ -81,23 +81,50 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
} else if (scale_num == 2) {
// Dequant op is after quantized op
// Dequantize the output tensor of quantized op
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
for (int i = 0; i < batch_size; i++) {
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
framework::slice_ddim(in->dims(), 1, in->dims().size()));
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
framework::slice_ddim(out->dims(), 1, out->dims().size()));
for (int j = 0; j < channel; j++) {
T s = scale_one[j];
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s * scale_two[0] / max_range;
if (x_num_col_dims > 1) {
auto in_dims = in->dims();
const int64_t channel = in_dims[x_num_col_dims];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
int64_t out_iter = 1;
for (int i = 0; i < x_num_col_dims; i++) {
out_iter *= in_dims[i];
}
int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel);
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) {
auto* cur_in = in_data + i * step_i + j * step_j;
auto* cur_out = out_data + i * step_i + j * step_j;
T s = scale_one[j];
for (int64_t k = 0; k < step_j; k++) {
*cur_out = (*cur_in) * s * scale_two[0] / max_range;
++cur_in;
++cur_out;
}
}
}
} else {
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
for (int i = 0; i < batch_size; i++) {
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
framework::slice_ddim(in->dims(), 1, in->dims().size()));
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
framework::slice_ddim(out->dims(), 1, out->dims().size()));
for (int j = 0; j < channel; j++) {
T s = scale_one[j];
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s * scale_two[0] / max_range;
}
}
}
}
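As a reading aid outside the diff: the new x_num_col_dims > 1 branch treats the product of the leading x_num_col_dims dimensions as an outer iteration count and dims()[x_num_col_dims] as the channel axis, so a rank-3 mul output of shape [batch, seq_len, out_dim] with x_num_col_dims = 2 is dequantized with one scale per out_dim channel. A minimal standalone C++ sketch of the same indexing (a hypothetical helper, assuming a contiguous float tensor; not part of the commit):

// --- illustrative sketch, not part of the commit ---
#include <cstdint>
#include <vector>

// scale_one holds one value per channel; scale_two is the single output scale.
void ChannelDequantRef(const std::vector<float>& in,
                       const std::vector<int64_t>& dims, int x_num_col_dims,
                       const std::vector<float>& scale_one, float scale_two,
                       float max_range, std::vector<float>* out) {
  int64_t numel = 1;
  for (int64_t d : dims) numel *= d;
  int64_t out_iter = 1;  // product of the leading x_num_col_dims dimensions
  for (int i = 0; i < x_num_col_dims; ++i) out_iter *= dims[i];
  const int64_t channel = dims[x_num_col_dims];
  const int64_t step_i = numel / out_iter;              // elements per outer slice
  const int64_t step_j = numel / (out_iter * channel);  // elements per channel slice
  out->resize(numel);
  for (int64_t i = 0; i < out_iter; ++i) {
    for (int64_t j = 0; j < channel; ++j) {
      const float s = scale_one[j];
      const int64_t base = i * step_i + j * step_j;
      for (int64_t k = 0; k < step_j; ++k) {
        (*out)[base + k] = in[base + k] * s * scale_two / max_range;
      }
    }
  }
}
// --- end sketch ---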
......@@ -199,7 +226,16 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker
"the received is %d",
quant_axis));
});
AddAttr<int>("x_num_col_dims",
"The x_num_col_dims of mul. Only used for mul or matmul.")
.SetDefault(1)
.AddCustomChecker([](const int& x_num_col_dims) {
PADDLE_ENFORCE_EQ(x_num_col_dims == 0, false,
platform::errors::InvalidArgument(
"'x_num_col_dims' should be larger than 0, but "
"the received is %d",
x_num_col_dims));
});
AddComment(R"DOC(
FakeChannelWiseDequantizeMaxAbsOp operator.
......@@ -245,4 +281,9 @@ REGISTER_OP_VERSION(fake_channel_wise_dequantize_max_abs)
R"ROC(add new attributes [quant_axis] for applying per-channel "
"dequantization to conv2d_tranpose and mul ops.)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"quant_axis", "The axis for dequantization.", 0));
"quant_axis", "The axis for dequantization.", 0))
.AddCheckpoint(
R"ROC(add new attributes [x_num_col_dims] for applying per-channel "
"dequantization to mul ops.)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"x_num_col_dims", "The x_num_col_dims for dequantization.", 1));
......@@ -77,9 +77,9 @@ __global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
template <typename T>
__global__ void DequantizeTwoScale(const T* in, const T* scale_one,
const T* scale_two, T max_range, int num,
int batch_size, int channel, T* out) {
int iter_size, int channel, T* out) {
int tid = threadIdx.x;
int channel_size = num / (batch_size * channel);
int channel_size = num / (iter_size * channel);
int scale_index = blockIdx.x % channel;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
......@@ -93,7 +93,7 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, const int quant_axis,
framework::Tensor* out) {
const int x_num_col_dims, framework::Tensor* out) {
auto in_dims = in->dims();
const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
......@@ -116,14 +116,17 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
} else if (scale_num == 2) {
// No need to consider quant_axis
int num = in->numel();
int batch_size = in->dims()[0];
int channel = in->dims()[1];
int iter_size = 1;
for (int i = 0; i < x_num_col_dims; i++) {
iter_size *= in->dims()[i];
}
int channel = in->dims()[x_num_col_dims];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
int block = 1024;
int grid = batch_size * channel;
int grid = iter_size * channel;
DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_one, scale_two, max_range, num, batch_size, channel,
in_data, scale_one, scale_two, max_range, num, iter_size, channel,
out_data);
}
}
......
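For the CUDA path, the grid now has one block per (outer-slice, channel) pair, and each block handles num / (iter_size * channel) contiguous elements. A small host-side sketch of that arithmetic under an assumed shape (illustrative only, no kernel launch; not part of the commit):

// --- illustrative sketch, not part of the commit ---
#include <cassert>
#include <cstdint>

int main() {
  // Assumed mul-output shape [batch, seq_len, out_dim] with x_num_col_dims = 2.
  const int64_t dims[] = {8, 128, 768};
  const int x_num_col_dims = 2;

  int64_t num = 1;
  for (int64_t d : dims) num *= d;

  int64_t iter_size = 1;  // product of the leading x_num_col_dims dims
  for (int i = 0; i < x_num_col_dims; ++i) iter_size *= dims[i];
  const int64_t channel = dims[x_num_col_dims];

  const int64_t grid = iter_size * channel;  // one block per channel slice
  const int64_t channel_size = num / grid;   // elements handled by each block
  assert(grid * channel_size == num);        // the slices tile the tensor exactly
  return 0;
}
// --- end sketch ---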
......@@ -33,7 +33,8 @@ template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor** scales, const int scale_num,
T max_range, const int quant_axis, framework::Tensor* out);
T max_range, const int quant_axis, const int x_num_col_dims,
framework::Tensor* out);
};
template <typename DeviceContext, typename T>
......@@ -64,6 +65,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
auto quant_axis = ctx.Attr<int>("quant_axis");
auto x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
int max_range = 1;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
......@@ -80,11 +82,11 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
} else if (scale_num == 2) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[1],
scales[0]->numel(), in->dims()[x_num_col_dims],
platform::errors::PreconditionNotMet(
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements, but %ld != %ld here.",
"corresponding dimension value of Input(X) when the `Scales` "
"has two elements, but %ld != %ld here.",
scales[0]->numel(), in->dims()[x_num_col_dims]));
PADDLE_ENFORCE_EQ(scales[1]->numel(), 1,
platform::errors::PreconditionNotMet(
......@@ -96,7 +98,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
}
ChannelDequantizeFunctor<DeviceContext, T>()(
dev_ctx, in, scales.data(), scale_num, static_cast<T>(max_range),
quant_axis, out);
quant_axis, x_num_col_dims, out);
}
};
......
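To make the updated precondition concrete (again an illustrative aside, not part of the diff): with two scale inputs, the first scale tensor must carry one value per element of the dimension selected by x_num_col_dims, and the second must be a single value.

// --- illustrative sketch, not part of the commit ---
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Assumed quantized-mul output shape [batch, seq_len, out_dim], x_num_col_dims = 2.
  const std::vector<int64_t> in_dims = {8, 128, 768};
  const int x_num_col_dims = 2;

  std::vector<float> scale_one(in_dims[x_num_col_dims]);  // one scale per channel
  std::vector<float> scale_two(1);                        // single output scale

  // Mirrors the PADDLE_ENFORCE_EQ checks in the kernel above.
  assert(static_cast<int64_t>(scale_one.size()) == in_dims[x_num_col_dims]);
  assert(scale_two.size() == 1u);
  return 0;
}
// --- end sketch ---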
......@@ -1273,12 +1273,17 @@ class QuantizationFreezePass(object):
var_type=output_var_node.type(),
shape=output_var_node.shape(),
var_dtype=output_var_node.dtype())
if op_node.op().has_attr("x_num_col_dims"):
x_num_col_dims = op_node.op().attr("x_num_col_dims")
else:
x_num_col_dims = 1
dequant_op_node = graph.create_op_node(
op_type='fake_channel_wise_dequantize_max_abs',
attrs={
'quant_bits': [self._weight_bits, self._activation_bits],
'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role': core.op_proto_and_checker_maker.OpRole.Forward,
'x_num_col_dims': x_num_col_dims
},
inputs={
'X': output_var_node,
......