diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index 60e4ac8cbcfd8cc8f1d14363538fe1e118b953cd..9d3e0806ac79d838765ca5a4bbf61d0f67ab6ed5 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -368,3 +368,7 @@ REGISTER_PASS(conv_transpose_bn_fuse_pass, paddle::framework::ir::ConvTransposeBNFusePass); REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass, paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass); +REGISTER_PASS(depthwise_conv_bn_fuse_pass, + paddle::framework::ir::DepthwiseConvBNFusePass); +REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass, + paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass); diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h index fcdbcf299c504c00b3027207bc2f4ac019d48ffc..57a9f69ca15af2759874a1e2a0b58399de652693 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h @@ -56,6 +56,16 @@ class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass { std::string conv_type() const { return "conv2d_transpose"; } }; +class DepthwiseConvBNFusePass : public ConvBNFusePass { + public: + std::string conv_type() const { return "depthwise_conv2d"; } +}; + +class DepthwiseConvEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass { + public: + std::string conv_type() const { return "depthwise_conv2d"; } +}; + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc index 0d2b951ee1c544151e99af8216db7809e2a77852..9b0328b0945ba9b57cb9ab27233656e3b0af4f5f 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cc +++ b/paddle/fluid/operators/fake_dequantize_op.cc @@ -37,20 +37,49 @@ template struct ChannelDequantizeFunctor { void operator()(const platform::CPUDeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor** scales, - const int scale_num, T max_range, framework::Tensor* out) { + const int scale_num, T max_range, const int quant_axis, + framework::Tensor* out) { if (scale_num == 1) { - const int channel = in->dims()[0]; + // Dequant op is before quantized op + // Dequantize the weight of quantized op + auto in_dims = in->dims(); + const int64_t channel = in_dims[quant_axis]; const T* scale_factor = scales[0]->data(); - for (int i = 0; i < channel; i++) { - T s = scale_factor[i]; - framework::Tensor one_channel_in = in->Slice(i, i + 1); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - auto in_e = framework::EigenVector::Flatten(one_channel_in); - auto out_e = framework::EigenVector::Flatten(one_channel_out); - auto& dev = *dev_ctx.eigen_device(); - out_e.device(dev) = in_e * s / max_range; + if (quant_axis == 0) { + for (int64_t i = 0; i < channel; i++) { + T s = scale_factor[i]; + framework::Tensor one_channel_in = in->Slice(i, i + 1); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto in_e = framework::EigenVector::Flatten(one_channel_in); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = in_e * s / max_range; + } + } else if (quant_axis == 1) { + int64_t out_iter = 1; + for (int i = 0; i < quant_axis; i++) { + out_iter *= in_dims[i]; + } + int64_t step_i = in->numel() / out_iter; + int64_t step_j = in->numel() / (out_iter * channel); + auto* in_data = in->data(); + auto* out_data 
= out->mutable_data(dev_ctx.GetPlace()); + for (int64_t i = 0; i < out_iter; i++) { + for (int64_t j = 0; j < channel; j++) { + auto* cur_in = in_data + i * step_i + j * step_j; + auto* cur_out = out_data + i * step_i + j * step_j; + T s = scale_factor[j]; + for (int64_t k = 0; k < step_j; k++) { + *cur_out = (*cur_in) * s / max_range; + ++cur_in; + ++cur_out; + } + } + } } } else if (scale_num == 2) { + // Dequant op is after quantized op + // Dequantize the output tensor of quantized op int batch_size = in->dims()[0]; int channel = in->dims()[1]; const T* scale_one = scales[0]->data(); @@ -157,6 +186,18 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker "Quantization bit numbers in quantization stage. " "The size of `quant_bits` should be equal to the size of `Scales`.") .SetDefault({8}); + AddAttr("quant_axis", + "(int, default 0) The axis for quantization. " + "For conv2d, depthwise_conv2d, conv2d_transpose " + "and mul, the quant_axis is equal to the cout axis.") + .SetDefault(0) + .AddCustomChecker([](const int& quant_axis) { + PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument( + "'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + }); AddComment(R"DOC( FakeChannelWiseDequantizeMaxAbsOp operator. diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index 02f9dc827d68cbb58447ed1557ff4bf310b2c017..54a92b055a39d49ea061250b066957f933fb975e 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -45,8 +45,9 @@ struct DequantizeFunctor { }; template -__global__ void DequantizeOneScale(const T* in, const T* scale, T max_range, - int num, int channel, T* out) { +__global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale, + T max_range, int num, int channel, + T* out) { int tid = threadIdx.x; int channel_size = num / channel; const T* in_c = in + blockIdx.x * channel_size; @@ -56,6 +57,23 @@ __global__ void DequantizeOneScale(const T* in, const T* scale, T max_range, } } +template +__global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale, + T max_range, const int num, + const int cin, const int cout, + T* out) { + int cout_wh_size = num / cin; + int wh_size = cout_wh_size / cout; + + T s = scale[blockIdx.x]; + const T* in_current = in + threadIdx.x * cout_wh_size + blockIdx.x * wh_size; + T* out_current = out + threadIdx.x * cout_wh_size + blockIdx.x * wh_size; + + for (int i = 0; i < wh_size; i++) { + out_current[i] = in_current[i] * s / max_range; + } +} + template __global__ void DequantizeTwoScale(const T* in, const T* scale_one, const T* scale_two, T max_range, int num, @@ -74,18 +92,29 @@ template struct ChannelDequantizeFunctor { void operator()(const platform::CUDADeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor** scales, - const int scale_num, T max_range, framework::Tensor* out) { + const int scale_num, T max_range, const int quant_axis, + framework::Tensor* out) { + auto in_dims = in->dims(); const T* in_data = in->data(); T* out_data = out->mutable_data(dev_ctx.GetPlace()); if (scale_num == 1) { int num = in->numel(); - int channel = in->dims()[0]; const T* scale_factor = scales[0]->data(); - int block = 1024; - int grid = channel; - DequantizeOneScale<<>>( - in_data, scale_factor, max_range, num, channel, out_data); + if (quant_axis == 0) { + int grid = in_dims[0]; + int block = 1024; + DequantizeOneScaleQuantAxis0<<>>( + in_data, scale_factor, 
max_range, num, in_dims[0], out_data); + } else if (quant_axis == 1) { + // Dequantize weight of Cin * Cout * W * H + int grid = in_dims[1]; + int block = in_dims[0]; + DequantizeOneScaleQuantAxis1<<>>( + in_data, scale_factor, max_range, num, in_dims[0], in_dims[1], + out_data); + } } else if (scale_num == 2) { + // Not need to consider quant_axis int num = in->numel(); int batch_size = in->dims()[0]; int channel = in->dims()[1]; diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h index 500960098f5ce5e66af5690138c15cc0eaa80d83..6ddb12771fd5176dbe27642adcb2ac82e4d7bfbf 100644 --- a/paddle/fluid/operators/fake_dequantize_op.h +++ b/paddle/fluid/operators/fake_dequantize_op.h @@ -33,7 +33,7 @@ template struct ChannelDequantizeFunctor { void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor** scales, const int scale_num, - T max_range, framework::Tensor* out); + T max_range, const int quant_axis, framework::Tensor* out); }; template @@ -63,6 +63,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); auto quant_bits = ctx.Attr>("quant_bits"); + auto quant_axis = ctx.Attr("quant_axis"); int max_range = 1; auto& dev_ctx = ctx.template device_context(); @@ -70,12 +71,12 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { int scale_num = scales.size(); if (scale_num == 1) { PADDLE_ENFORCE_EQ( - scales[0]->numel(), in->dims()[0], + scales[0]->numel(), in->dims()[quant_axis], platform::errors::PreconditionNotMet( "The number of first scale values must be the same with " - "first dimension value of Input(X) when the `Scales` has only " - "one element, but %ld != %ld here.", - scales[0]->numel(), in->dims()[0])); + "quant_axis dimension value of Input(X) when the `Scales` has " + "only one element, but %ld != %ld here.", + scales[0]->numel(), in->dims()[quant_axis])); max_range *= (std::pow(2, quant_bits[0] - 1) - 1); } else if (scale_num == 2) { PADDLE_ENFORCE_EQ( @@ -94,7 +95,8 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { (std::pow(2, quant_bits[1] - 1) - 1); } ChannelDequantizeFunctor()( - dev_ctx, in, scales.data(), scale_num, static_cast(max_range), out); + dev_ctx, in, scales.data(), scale_num, static_cast(max_range), + quant_axis, out); } }; diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 358f122c8359fa60f2c99492db8851c8a5fc5293..04ac4a35208a54361a4f434e68095e9519ee12e9 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/fake_quantize_op.h" +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/clip_op.h" @@ -39,13 +40,41 @@ template struct FindAbsMaxFunctor; template struct FindChannelAbsMaxFunctor { - void operator()(const platform::CPUDeviceContext& ctx, const T* in, - const int num, const int channel, T* out) { - const int channel_size = num / channel; - for (int i = 0; i < channel; i++) { - auto* start = in + i * channel_size; - auto* end = in + (i + 1) * channel_size; - out[i] = std::abs(*(std::max_element(start, end, Compare()))); + void operator()(const platform::CPUDeviceContext& ctx, + const framework::Tensor& in_tensor, const int quant_axis, + T* out_abs_max) { + // At present, channelwise quantization supports conv2d, depthwise_conv2d + // conv2d_transpose and mul + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + auto* in_data = in_tensor.data(); + auto in_dims = in_tensor.dims(); + const int64_t channel = in_dims[quant_axis]; + if (quant_axis == 0) { + const int64_t channel_size = in_tensor.numel() / channel; + for (int64_t i = 0; i < channel; i++) { + auto* start = in_data + i * channel_size; + auto* end = in_data + (i + 1) * channel_size; + out_abs_max[i] = + std::abs(*(std::max_element(start, end, Compare()))); + } + } else if (quant_axis == 1) { + for (int64_t i = 0; i < channel; i++) { + out_abs_max[i] = 0; + } + const int64_t step_i = in_tensor.numel() / in_dims[0]; + const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]); + for (int64_t i = 0; i < in_dims[0]; i++) { + for (int64_t j = 0; j < in_dims[1]; j++) { + auto* start = in_data + i * step_i + j * step_j; + auto* end = in_data + i * step_i + (j + 1) * step_j; + T abs_max = std::abs(*(std::max_element(start, end, Compare()))); + out_abs_max[j] = std::max(out_abs_max[j], abs_max); + } + } } } }; @@ -92,26 +121,53 @@ template struct ChannelClipAndFakeQuantFunctor { void operator()(const platform::CPUDeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int channel, + const int bin_cnt, const int quant_axis, framework::Tensor* out) { + // At present, channelwise quantization supports conv2d, depthwise_conv2d + // conv2d_transpose and mul + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); auto* scale_data = scale.data(); auto* in_data = in.data(); auto* out_data = out->mutable_data(ctx.GetPlace()); - const int channel_size = in.numel() / channel; + auto in_dims = in.dims(); + const int64_t channel = in_dims[quant_axis]; platform::Transform trans; - for (int i = 0; i < channel; i++) { - T s = scale_data[i]; - auto* start = in_data + i * channel_size; - auto* end = in_data + (i + 1) * channel_size; - trans(ctx, start, end, out_data + i * channel_size, - ClipFunctor(-s, s)); - } - for (int i = 0; i < channel; i++) { - T s = scale_data[i]; - T inv_s = inverse(s); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - auto out_e = framework::EigenVector::Flatten(one_channel_out); - out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + if (quant_axis == 0) { + const int64_t channel_size = in.numel() / channel; + for (int64_t i = 0; i < channel; i++) { + T s = scale_data[i]; + auto* start = in_data + i * channel_size; + 
auto* end = in_data + (i + 1) * channel_size; + trans(ctx, start, end, out_data + i * channel_size, + ClipFunctor(-s, s)); + } + for (int64_t i = 0; i < channel; i++) { + T s = scale_data[i]; + T inv_s = inverse(s); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + } + } else if (quant_axis == 1) { + const int64_t step_i = in.numel() / in_dims[0]; + const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]); + for (int i = 0; i < in_dims[0]; i++) { + for (int j = 0; j < in_dims[1]; j++) { + T s = scale_data[j]; + T inv_s = inverse(s); + auto* start = in_data + i * step_i + j * step_j; + auto* end = in_data + i * step_i + (j + 1) * step_j; + auto* cur_out_data = out_data + i * step_i + j * step_j; + trans(ctx, start, end, cur_out_data, ClipFunctor(-s, s)); + for (int k = 0; k < step_j; k++) { + cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]); + } + } + } } } }; @@ -247,8 +303,9 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { "FakeChannelWiseQuantizeAbsMax"); OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", "FakeChannelWiseQuantizeAbsMax"); + int quant_axis = ctx->Attrs().Get("quant_axis"); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]}); + ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]}); ctx->ShareLoD("X", /*->*/ "Out"); } @@ -269,6 +326,18 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker "(Tensor) Output of quantized low level tensor, " "but also saved as float data type."); AddOutput("OutScale", "(Tensor) Current channel wise scale"); + AddAttr("quant_axis", + "(int, default 0) The axis for quantization. 
" + "For conv2d, depthwise_conv2d, conv2d_transpose " + "and mul, the quant_axis is equal to the cout axis.") + .SetDefault(0) + .AddCustomChecker([](const int& quant_axis) { + PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument( + "'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + }); AddAttr("bit_length", "(int, default 8)") .SetDefault(8) .AddCustomChecker([](const int& bit_length) { diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 75a55fa821f0af664ad18cc20c90cd2f3d61d5d0..6ff3c7ec632f236fe4ae6c6504537df3b8a46b7a 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -75,8 +75,8 @@ struct FindAbsMaxFunctor { template struct FindAbsMaxFunctor; template -__global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c, - T* out) { +__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, + const int c, T* out) { int tid = threadIdx.x; int channel_size = n / c; const T* in_c = in + blockIdx.x * channel_size; @@ -100,14 +100,69 @@ __global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c, } } +template +__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, + const int cin, const int cout, + T* out) { + extern __shared__ T shared_max_data[]; + int cout_wh_size = n / cin; + int wh_size = n / (cin * cout); + + int tid = threadIdx.x; + int bid = blockIdx.x; + const T* in_current = in + tid * cout_wh_size + bid * wh_size; + shared_max_data[tid] = T(0); + for (int i = 0; i < wh_size; i++) { + T tmp = fabs(in_current[i]); + if (tmp > shared_max_data[tid]) { + shared_max_data[tid] = tmp; + } + } + __syncthreads(); + + int len = blockDim.x; + for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) { + if (tid < i && tid + i < len && + shared_max_data[tid] < shared_max_data[tid + i]) { + shared_max_data[tid] = shared_max_data[tid + i]; + } + if (i == 1) { + i = 0; // break the loop + } + __syncthreads(); + } + if (tid == 0) { + out[bid] = shared_max_data[0]; + } +} + template struct FindChannelAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, const T* in, - const int num, const int channel, T* out) { - int block = 1024; - int grid = channel; - FindChannelAbsMaxKernel<<>>( - in, num, channel, out); + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in_tensor, const int quant_axis, + T* out_abs_max) { + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + const int num = in_tensor.numel(); + auto in_dims = in_tensor.dims(); + int channel = in_dims[quant_axis]; + const T* in_data = in_tensor.data(); + if (quant_axis == 0) { + int grid = channel; + int block = 1024; + FindChannelAbsMaxKernelQuantAxis0< + T><<>>( + in_data, num, channel, out_abs_max); + } else if (quant_axis == 1) { + int grid = in_dims[1]; + int block = in_dims[0]; + FindChannelAbsMaxKernelQuantAxis1< + T><<>>( + in_data, num, in_dims[0], in_dims[1], out_abs_max); + } } }; @@ -189,10 +244,12 @@ struct ClipAndFakeQuantDequantFunctor { template struct ClipAndFakeQuantDequantFunctor; +// ChannelClipAndQuantKernel for quant_axis is 0 template -__global__ void ChannelClipAndQuantKernel(const T* in, const T* scale, - const int bin_cnt, const int n, - const int c, T* out) { +__global__ void 
ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, + const int bin_cnt, + const int n, const int c, + T* out) { int tid = threadIdx.x; int channel_size = n / c; @@ -211,22 +268,57 @@ __global__ void ChannelClipAndQuantKernel(const T* in, const T* scale, } } +// ChannelClipAndQuantKernel for quant_axis is 1 +template +__global__ void ChannelClipAndQuantKernelQuantAxis1(const T* in, const T* scale, + const int bin_cnt, + const int n, const int cin, + const int cout, T* out) { + T s = scale[blockIdx.x % cout]; + T inv_s = inverse(s); + + int wh_size = n / (cin * cout); + const T* in_c = in + blockIdx.x * wh_size; + T* out_c = out + blockIdx.x * wh_size; + + for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { + T x = in_c[i]; + T v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt * inv_s * v; + out_c[i] = round(v); + } +} + template struct ChannelClipAndFakeQuantFunctor { void operator()(const platform::CUDADeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int channel, + const int bin_cnt, const int quant_axis, framework::Tensor* out) { - int num = in.numel(); - int block = 1024; - int grid = channel; + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + int num = in.numel(); + auto in_dims = in.dims(); const T* in_data = in.data(); const T* scale_data = scale.data(); T* out_data = out->mutable_data(ctx.GetPlace()); - ChannelClipAndQuantKernel<<>>( - in_data, scale_data, bin_cnt, num, channel, out_data); + if (quant_axis == 0) { + int grid = in_dims[0]; + int block = 1024; + ChannelClipAndQuantKernelQuantAxis0<<>>( + in_data, scale_data, bin_cnt, num, in_dims[0], out_data); + } else if (quant_axis == 1) { + int grid = in_dims[0] * in_dims[1]; + int block = 1024; + ChannelClipAndQuantKernelQuantAxis1<<>>( + in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data); + } } }; diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 4136217fb0c5f600971c1c04f803b65de9bbecb4..5c6e0b1f6e26d84462a18da910b412f03b93285d 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -61,15 +61,15 @@ struct FindRangeAbsMaxFunctor { template struct FindChannelAbsMaxFunctor { - void operator()(const DeviceContext& ctx, const T* in, const int num, - const int channel, T* out); + void operator()(const DeviceContext& ctx, const framework::Tensor& in_tensor, + const int quant_axis, T* out_abs_max); }; template struct ChannelClipAndFakeQuantFunctor { void operator()(const DeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, const int bin_cnt, - const int channel, framework::Tensor* out); + const int quant_axis, framework::Tensor* out); }; template @@ -144,12 +144,13 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel { int bit_length = context.Attr("bit_length"); int bin_cnt = std::pow(2, bit_length - 1) - 1; + int quant_axis = context.Attr("quant_axis"); auto& dev_ctx = context.template device_context(); - FindChannelAbsMaxFunctor()( - dev_ctx, in->data(), in->numel(), in->dims()[0], out_scale_data); + FindChannelAbsMaxFunctor()(dev_ctx, *in, quant_axis, + out_scale_data); ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, in->dims()[0], out); + dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); } }; diff --git 
a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 3097e1d82a9cb5e096efa3913ea6a06bff557c94..244a621611060b87805846f1ea748615bcdde19a 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -29,6 +29,7 @@ from .quantization_pass import _out_scale_op_list from .quantization_pass import _get_op_input_var_names from .quantization_pass import _get_op_output_var_names from .quantization_pass import _get_output_name_index +from .quantization_pass import _channelwise_quant_axis1_ops __all__ = ['PostTrainingQuantization', 'WeightQuantization'] @@ -316,6 +317,7 @@ class PostTrainingQuantization(object): self._out_scale_op_list = _out_scale_op_list self._quantized_weight_var_name = set() self._quantized_act_var_name = set() + self.weight_op_pairs = {} self._sampling_data = {} self._quantized_var_kl_threshold = {} self._quantized_var_min = {} @@ -436,6 +438,8 @@ class PostTrainingQuantization(object): graph = IrGraph(core.Graph(self._program.desc), for_test=True) graph = _remove_ctrl_vars(graph) graph = _apply_pass(self._scope, graph, 'conv_bn_fuse_pass') + graph = _apply_pass(self._scope, graph, 'depthwise_conv_bn_fuse_pass') + graph = _apply_pass(self._scope, graph, 'conv_transpose_bn_fuse_pass') self._program = graph.to_program() def _collect_target_varnames(self): @@ -446,10 +450,11 @@ class PostTrainingQuantization(object): # TODO(juncaipeng), consider the name_scope of skip_quant _logger.info("Collect quantized variable names ...") - def collect_var_name(var_name_list, persistable_var_names): + def collect_var_name(var_name_list, persistable_var_names, op_type): for var_name in var_name_list: if var_name in persistable_var_names: self._quantized_weight_var_name.add(var_name) + self.weight_op_pairs[var_name] = op_type else: self._quantized_act_var_name.add(var_name) @@ -462,13 +467,15 @@ class PostTrainingQuantization(object): # For quantized ops, sample inputs and outputs if op_type in self._quantizable_op_type: collect_var_name( - _get_op_input_var_names(op), persistable_var_names) + _get_op_input_var_names(op), persistable_var_names, op_type) collect_var_name( - _get_op_output_var_names(op), persistable_var_names) + _get_op_output_var_names(op), persistable_var_names, + op_type) # For other op, only sample output scale elif op_type in self._out_scale_op_list: collect_var_name( - _get_op_output_var_names(op), persistable_var_names) + _get_op_output_var_names(op), persistable_var_names, + op_type) def _set_activation_persistable(self): ''' @@ -492,45 +499,75 @@ class PostTrainingQuantization(object): Sample the input threshold(min, max, or abs_max) in every iterations. ''' assert self._algo in ["abs_max", "min_max"], \ - "The algo should be abs_max or min_max to sample min max value." - + "The algo should be abs_max or min_max for _sample_threshold." 
if self._algo == "abs_max": - # Only calculate abs_max value for weight for once - if self._quantized_var_abs_max == {}: - for var_name in self._quantized_weight_var_name: - var_tensor = _load_variable_data(self._scope, var_name) - abs_max_per_channel = [] - for i in range(var_tensor.shape[0]): - abs_max_per_channel.append( - float(np.max(np.abs(var_tensor[i])))) - self._quantized_var_abs_max[var_name] = abs_max_per_channel - for var_name in self._quantized_act_var_name: - var_tensor = _load_variable_data(self._scope, var_name) - abs_max_value = float(np.max(np.abs(var_tensor))) - if (var_name not in self._quantized_var_abs_max) or \ - (abs_max_value > self._quantized_var_abs_max[var_name]): - self._quantized_var_abs_max[var_name] = abs_max_value + self._sample_threshold_abs_max() elif self._algo == "min_max": - if self._quantized_var_min == {} and self._quantized_var_max == {}: - for var_name in self._quantized_weight_var_name: - var_tensor = _load_variable_data(self._scope, var_name) - min_per_channel = [] - max_per_channle = [] - for i in range(var_tensor.shape[0]): - min_per_channel.append(float(np.min(var_tensor[i]))) - max_per_channle.append(float(np.max(var_tensor[i]))) - self._quantized_var_min[var_name] = min_per_channel - self._quantized_var_max[var_name] = max_per_channle - for var_name in self._quantized_act_var_name: + self._sample_threshold_min_max() + + def _sample_threshold_abs_max(self): + assert self._algo == "abs_max", \ + "The algo should be abs_max for _sample_threshold_abs_max." + # Only calculate abs_max value for weight for once + if self._quantized_var_abs_max == {}: + for var_name in self._quantized_weight_var_name: + var_tensor = _load_variable_data(self._scope, var_name) + if self._weight_quantize_type == "abs_max": + abs_max_value = float(np.max(np.abs(var_tensor))) + elif self._weight_quantize_type == "channel_wise_abs_max": + abs_max_value = [] + if self.weight_op_pairs[ + var_name] in _channelwise_quant_axis1_ops: + for i in range(var_tensor.shape[1]): + abs_max_value.append( + float(np.max(np.abs(var_tensor[:, i])))) + else: + for i in range(var_tensor.shape[0]): + abs_max_value.append( + float(np.max(np.abs(var_tensor[i])))) + self._quantized_var_abs_max[var_name] = abs_max_value + + for var_name in self._quantized_act_var_name: + var_tensor = _load_variable_data(self._scope, var_name) + abs_max_value = float(np.max(np.abs(var_tensor))) + if (var_name not in self._quantized_var_abs_max) or \ + (abs_max_value > self._quantized_var_abs_max[var_name]): + self._quantized_var_abs_max[var_name] = abs_max_value + + def _sample_threshold_min_max(self): + assert self._algo == "min_max", \ + "The algo should be min_max for _sample_threshold_min_max." 
+ if self._quantized_var_min == {} and self._quantized_var_max == {}: + for var_name in self._quantized_weight_var_name: var_tensor = _load_variable_data(self._scope, var_name) - min_value = float(np.min(var_tensor)) - max_value = float(np.max(var_tensor)) - if (var_name not in self._quantized_var_min) or \ - (min_value < self._quantized_var_min[var_name]): - self._quantized_var_min[var_name] = min_value - if (var_name not in self._quantized_var_max) or \ - (max_value > self._quantized_var_max[var_name]): - self._quantized_var_max[var_name] = max_value + if self._weight_quantize_type == "abs_max": + min_value = float(np.min(var_tensor)) + max_value = float(np.max(var_tensor)) + elif self._weight_quantize_type == "channel_wise_abs_max": + min_value = [] + max_value = [] + if self.weight_op_pairs[ + var_name] in _channelwise_quant_axis1_ops: + for i in range(var_tensor.shape[1]): + min_value.append(float(np.min(var_tensor[:, i]))) + max_value.append(float(np.max(var_tensor[:, i]))) + else: + for i in range(var_tensor.shape[0]): + min_value.append(float(np.min(var_tensor[i]))) + max_value.append(float(np.max(var_tensor[i]))) + self._quantized_var_min[var_name] = min_value + self._quantized_var_max[var_name] = max_value + + for var_name in self._quantized_act_var_name: + var_tensor = _load_variable_data(self._scope, var_name) + min_value = float(np.min(var_tensor)) + max_value = float(np.max(var_tensor)) + if (var_name not in self._quantized_var_min) or \ + (min_value < self._quantized_var_min[var_name]): + self._quantized_var_min[var_name] = min_value + if (var_name not in self._quantized_var_max) or \ + (max_value > self._quantized_var_max[var_name]): + self._quantized_var_max[var_name] = max_value def _save_input_threhold(self): ''' @@ -554,11 +591,6 @@ class PostTrainingQuantization(object): applied in every iteration. ''' assert self._algo == "KL", "The algo should be KL to sample data." 
- for var_name in self._quantized_weight_var_name: - if var_name not in self._sampling_data: - var_tensor = _load_variable_data(self._scope, var_name) - self._sampling_data[var_name] = var_tensor - if self._is_use_cache_file: for var_name in self._quantized_act_var_name: var_tensor = _load_variable_data(self._scope, var_name) @@ -584,15 +616,20 @@ class PostTrainingQuantization(object): # Abs_max threshold for weights for var_name in self._quantized_weight_var_name: - weight_data = self._sampling_data[var_name] - weight_threshold = None + weight_data = _load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": - weight_threshold = np.max(np.abs(weight_data)) + weight_threshold = float(np.max(np.abs(weight_data))) elif self._weight_quantize_type == "channel_wise_abs_max": weight_threshold = [] - for i in range(weight_data.shape[0]): - abs_max_value = np.max(np.abs(weight_data[i])) - weight_threshold.append(abs_max_value) + if self.weight_op_pairs[ + var_name] in _channelwise_quant_axis1_ops: + for i in range(weight_data.shape[1]): + weight_threshold.append( + float(np.max(np.abs(weight_data[:, i])))) + else: + for i in range(weight_data.shape[0]): + weight_threshold.append( + float(np.max(np.abs(weight_data[i])))) self._quantized_var_kl_threshold[var_name] = weight_threshold # KL threshold for activations diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 8851bcc6440d405f7484257b44760802feb0d8fb..0eef94896287af833c7d8d9e2a480627c61b3004 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -111,6 +111,10 @@ _op_real_in_out_name = { "scale": [["X"], ["Out"]], } +_conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose'] + +_channelwise_quant_axis1_ops = ['conv2d_transpose', 'mul'] + def _get_op_input_var_names(op): """ """ @@ -185,10 +189,24 @@ def _is_input_all_not_persistable(graph, op_node): return is_input_all_not_persistable +def _check_grandchild_op_node(op_node, grandchild_op_name): + ''' + Check whether the fake_quant node has a grandchild op node named + grandchild_op_name. + ''' + for out1_var_node in op_node.outputs: + for out1_op_node in out1_var_node.outputs: + for out2_var_node in out1_op_node.outputs: + for out2_op_node in out2_var_node.outputs: + if out2_op_node.name() == grandchild_op_name: + return True + return False + + class QuantizationTransformPass(object): """ - Quantize the ops that have weights. Add quant and dequant ops for the quantized - ops's inputs. + Quantize the ops that have weights. Add quant and dequant ops for + the quantized ops's inputs. """ _supported_quantizable_op_type = [ 'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul' @@ -311,8 +329,8 @@ class QuantizationTransformPass(object): if weight_quantize_type not in quant_type: raise ValueError( "Unknown weight_quantize_type: '%s'. It can only be " - "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'." - % (str(weight_quantize_type))) + "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' " + "or 'moving_average_abs_max'." 
% (str(weight_quantize_type))) self._activation_quantize_type = activation_quantize_type self._weight_quantize_type = weight_quantize_type @@ -323,7 +341,6 @@ class QuantizationTransformPass(object): for op in self._quantizable_ops: assert op in QuantizationTransformPass._supported_quantizable_op_type, \ op + " is not supported for quantization." - self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._quantizable_grad_ops = [ '%s_grad' % (op) for op in self._quantizable_ops ] @@ -356,10 +373,12 @@ class QuantizationTransformPass(object): user_skipped = False if isinstance(self._skip_pattern, list): user_skipped = op_node.op().has_attr("op_namescope") and \ - any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern) + any(pattern in op_node.op().attr("op_namescope") \ + for pattern in self._skip_pattern) elif isinstance(self._skip_pattern, str): user_skipped = op_node.op().has_attr("op_namescope") and \ - op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 + op_node.op().attr("op_namescope").find( + self._skip_pattern) != -1 if user_skipped: op_node.op()._set_attr("skip_quant", True) @@ -373,15 +392,11 @@ class QuantizationTransformPass(object): if var_node.name() in dequantized_vars: dequant_var_node = dequantized_vars[var_node.name()] else: - name = var_node.name() if name in processed_vars: continue - - if var_node.name() in persistable_vars: - is_weight = True - else: - is_weight = False + is_weight = True if var_node.name() in persistable_vars \ + else False # if var node is weight and weight_preprocess_func is not None, # will insert weight preprocess func @@ -415,20 +430,14 @@ class QuantizationTransformPass(object): else self._activation_bits quant_type = self._weight_quantize_type if is_weight \ else self._activation_quantize_type - if quant_type == 'channel_wise_abs_max': - assert is_weight, "'channel_wise_abs_max' can only be applied on weights." 
- if op.name() in self._conv_ops: - quant_var_node, scale_var_node = self._insert_channel_quant_op( - graph, var_node, name, quant_bits) - dequant_var_node = self._insert_channel_dequant_op( - graph, quant_var_node, [scale_var_node], - [quant_bits]) - else: - quant_var_node, scale_var_node = self._insert_quant_op( - graph, var_node, name, quant_bits, 'abs_max') - dequant_var_node = self._insert_dequant_op( - graph, quant_var_node, scale_var_node, - quant_bits) + if quant_type == 'channel_wise_abs_max': # Weight quantization + quant_axis = 1 if op.name() in \ + _channelwise_quant_axis1_ops else 0 + quant_var_node, scale_var_node = self._insert_channel_quant_op( + graph, var_node, name, quant_bits, quant_axis) + dequant_var_node = self._insert_channel_dequant_op( + graph, quant_var_node, [scale_var_node], + [quant_bits], quant_axis) else: quant_var_node, scale_var_node = self._insert_quant_op( graph, var_node, name, quant_bits, quant_type) @@ -529,11 +538,19 @@ class QuantizationTransformPass(object): var_type=var_node.type(), shape=var_node.shape(), var_dtype=var_node.dtype()) - scale_var_node = graph.create_var_node( + scale_var_node = graph.create_persistable_node( name=self._quantized_scale_name(name), var_type=var_node.type(), shape=[1], var_dtype=var_node.dtype()) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + _init_var_node( + scale_var_node, + np.zeros( + scale_var_node.shape(), dtype=data_type), + self._scope, + self._place) quant_op_node = graph.create_op_node( op_type='fake_quantize_abs_max', attrs={ @@ -706,7 +723,8 @@ class QuantizationTransformPass(object): return quant_var_node, scale_out_node - def _insert_channel_quant_op(self, graph, var_node, name, quant_bits): + def _insert_channel_quant_op(self, graph, var_node, name, quant_bits, + quant_axis): """ Insert fake_channel_wise_quantize_abs_max op in the graph. """ @@ -717,15 +735,24 @@ class QuantizationTransformPass(object): var_type=var_node.type(), shape=var_node.shape(), var_dtype=var_node.dtype()) - scale_var_node = graph.create_var_node( + scale_var_node = graph.create_persistable_node( name=self._quantized_scale_name(name), var_type=var_node.type(), - shape=[var_node.shape()[0]], + shape=[var_node.shape()[quant_axis]], var_dtype=var_node.dtype()) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + _init_var_node( + scale_var_node, + np.zeros( + scale_var_node.shape(), dtype=data_type), + self._scope, + self._place) quant_op_node = graph.create_op_node( op_type='fake_channel_wise_quantize_abs_max', attrs={ 'bit_length': quant_bits, + 'quant_axis': quant_axis, 'op_role': core.op_proto_and_checker_maker.OpRole.Forward }, inputs={'X': var_node}, @@ -763,7 +790,7 @@ class QuantizationTransformPass(object): return dequant_var_node def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes, - quant_bits): + quant_bits, quant_axis): """ Insert fake_channel_wise_dequantize_max_abs in the graph. 
""" @@ -778,6 +805,7 @@ class QuantizationTransformPass(object): op_type='fake_channel_wise_dequantize_max_abs', attrs={ 'quant_bits': quant_bits, + 'quant_axis': quant_axis, 'op_role': core.op_proto_and_checker_maker.OpRole.Forward }, inputs={'X': var_node, @@ -1036,7 +1064,6 @@ class QuantizationFreezePass(object): self._weight_bits = weight_bits self._activation_bits = activation_bits self._weight_quantize_type = weight_quantize_type - self._conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose'] self._fake_quant_op_names = _fake_quant_op_list self._fake_dequant_op_names = _fake_dequant_op_list self._op_input_rename_map = collections.OrderedDict() @@ -1063,34 +1090,37 @@ class QuantizationFreezePass(object): if input_arg_name in graph.out_node_mapping_table.keys(): input_arg_name = graph.out_node_mapping_table[ input_arg_name] - if input_arg_name in persistable_vars: - if self._weight_quantize_type == 'abs_max': - param = self._load_var(input_arg_name) - scale_v = np.max(np.abs(param)) - elif self._weight_quantize_type == 'channel_wise_abs_max': - param = self._load_var(input_arg_name) - if len(param.shape) == 4: # conv2d or depthwise_conv2d - scale_v = [] - for i in range(param.shape[0]): - scale_v.append(np.max(np.abs(param[i]))) - else: - scale_v = np.max(np.abs(param)) + if input_arg_name not in persistable_vars: + scale_v = graph._find_node_by_name( + op_node.outputs, op_node.output('OutScale')[0]) + self._quant_var_scale_map[input_arg_name] = scale_v + else: + # Obtain scale from OutScale var node + scale_v = self._load_var(op_node.output('OutScale')[0]) + assert scale_v.ndim in [ + 1, 2 + ], "the dim of scale_v should be 1 or 2" + if scale_v.ndim == 2: + scale_v = scale_v[0] + if scale_v.size == 1: + scale_v = scale_v[0] else: - scale_v = self._load_var( - op_node.output('OutScale')[0])[0] + scale_v = scale_v.tolist() self._quant_var_scale_map[input_arg_name] = scale_v - self._remove_fake_quant_and_dequant_op(graph, op_node) - # quantize weight and restore + # Quantize weight and restore param_v = self._load_var(input_arg_name) - quantized_param_v = self._quant(param_v, scale_v, - self._weight_bits) + if isinstance(scale_v, list) and \ + any(_check_grandchild_op_node(op_node, op) + for op in _channelwise_quant_axis1_ops): + quant_axis = 1 + else: + quant_axis = 0 + quantized_param_v = self._quant( + param_v, scale_v, self._weight_bits, quant_axis) self._restore_var(input_arg_name, quantized_param_v) - else: - scale_v = graph._find_node_by_name( - op_node.outputs, op_node.output('OutScale')[0]) - self._quant_var_scale_map[input_arg_name] = scale_v + self._remove_fake_quant_and_dequant_op(graph, op_node) - # Remove all fake dequant op +# Remove all fake dequant op ops = graph.all_op_nodes() for op_node in ops: op_name = op_node.name() @@ -1103,8 +1133,7 @@ class QuantizationFreezePass(object): op_node_desc = op_node.op() if op_node_desc.has_attr("quantization_type") and \ op_node_desc.attr("quantization_type") == "qat_with_weight": - if self._weight_quantize_type == 'channel_wise_abs_max' \ - and op_node.name() in self._conv_ops: + if self._weight_quantize_type == 'channel_wise_abs_max': self._insert_post_channel_dequant_op(graph, op_node) else: self._insert_post_dequant_op(graph, op_node) @@ -1295,10 +1324,15 @@ class QuantizationFreezePass(object): return isinstance(v, float) or isinstance(v, np.float32) \ or isinstance(v, np.float64) - def _quant(self, x, scale, num_bits): + def _quant(self, x, scale, num_bits, quant_axis): + assert quant_axis in [0, 1], 'quant_axis should 
be 0 or 1 for now.' if isinstance(scale, list): for i, s in enumerate(scale): - x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1)) + if quant_axis == 0: + x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1)) + else: + x[:, i] = np.round(x[:, i] / s * ( + (1 << (num_bits - 1)) - 1)) return x else: return np.round(x / scale * ((1 << (num_bits - 1)) - 1)) @@ -1468,6 +1502,10 @@ class OutScaleForTrainingPass(object): for op in target_ops: for output_var_name in _get_op_output_var_names(op): in_node = graph._find_node_by_name(op.outputs, output_var_name) + if in_node.dtype() not in \ + [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: + continue + scale_node = graph.create_persistable_node( name=self._scale_name(in_node.name()), var_type=core.VarDesc.VarType.LOD_TENSOR, @@ -1570,17 +1608,26 @@ class OutScaleForInferencePass(object): if op_node.name() in self._teller_set: var_names = _get_op_output_var_names(op_node) for var_name in var_names: - # For compatibility, we save output threshold by two methods. + in_node = graph._find_node_by_name(op_node.outputs, + var_name) + if in_node.dtype() not in \ + [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: + continue + scale_name = self._scale_name(var_name) - scale_v = np.array( - self._scope.find_var(scale_name).get_tensor())[0] - op_node.op()._set_attr("out_threshold", float(scale_v)) + scale_var = self._scope.find_var(scale_name) + assert scale_var is not None, \ + "Can not find {} variable in the scope".format(scale_name) + scale_value = np.array(scale_var.get_tensor())[0] + + # For compatibility, we save output threshold by two methods. + op_node.op()._set_attr("out_threshold", float(scale_value)) argname_index = _get_output_name_index(op_node, var_name) assert argname_index is not None, \ var_name + " is not the output of the op" op_node.op()._set_attr(argname_index[0] + str(argname_index[1]) \ - + "_threshold", float(scale_v)) + + "_threshold", float(scale_value)) graph.resolve_hazard() return graph diff --git a/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py b/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py index c9ea15bf6cde9af16810920f53a7d5e045a852e3..32292c8a47b50bc5e7eb2d7833823e586eea8909 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py +++ b/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py @@ -33,34 +33,29 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CPU_NUM"] = "1" -def residual_block(img, label, num=1): - def conv_bn_layer(input, - ch_out, - filter_size, - stride, - padding, - act='relu', - bias_attr=False): - tmp = fluid.layers.conv2d( - input=input, - filter_size=filter_size, - num_filters=ch_out, - stride=stride, - padding=padding, - use_cudnn=False, - act=None, - bias_attr=bias_attr) - return fluid.layers.batch_norm(input=tmp, act=act) - - hidden = img - for _ in six.moves.xrange(num): - conv = conv_bn_layer(hidden, 20, 3, 1, 1, act=None, bias_attr=True) - short = conv_bn_layer(hidden, 20, 1, 1, 0, act=None) - hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') - fc = fluid.layers.fc(input=hidden, size=10, act='softmax') - loss = fluid.layers.cross_entropy(input=fc, label=label) - loss = fluid.layers.mean(loss) - return loss +def conv_net(img, label): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + pool_type='max', + act="relu") + conv_pool_1 = fluid.layers.batch_norm(conv_pool_1) + 
conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=50, + pool_size=2, + pool_stride=2, + pool_type='avg', + act="relu") + hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu') + prediction = fluid.layers.fc(input=hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + avg_loss = fluid.layers.mean(loss) + return avg_loss def pact(x, name=None): @@ -102,7 +97,7 @@ class TestUserDefinedQuantization(unittest.TestCase): img.stop_gradient = False label = fluid.layers.data( name='label', shape=[1], dtype='int64') - loss = residual_block(img, label, 1) + loss = conv_net(img, label) if not is_test: opt = fluid.optimizer.SGD(learning_rate=0.0001) opt.minimize(loss) diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py index 0812b02b47db7fa2d43e1d3bbd0a3f7b59911326..b30e0a6775ea9901d8c2a3a56b2e80141fffd23c 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py @@ -31,45 +31,45 @@ def dequantize_max_abs(x, scale, max_range): return y -def channel_wise_quantize_max_abs(x, quant_bit=8, use_second_dim=False): +def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0): + assert quant_axis in [0, 1], "The quant_axis should be 0 or 1." scales = [] - if not use_second_dim: + y = x.copy() + max_range = math.pow(2, quant_bit - 1) - 1 + if quant_axis == 0: for i in range(x.shape[0]): - scales.append(np.max(np.abs(x[i])).astype("float32")) - y = x.copy() - max_range = math.pow(2, quant_bit - 1) - 1 - for i, scale in enumerate(scales): - y[i] = np.round(x[i] / scale * max_range) - else: - for i in range(x.shape[0]): - s = [] - for j in range(x.shape[1]): - s.append(np.max(np.abs(x[i][j])).astype("float32")) - scales.append(s) - scales = np.amax(np.array(scales), axis=0) - y = x.copy() - max_range = math.pow(2, quant_bit - 1) - 1 - for i in range(x.shape[0]): - for j, scale in enumerate(scales): - y[i][j] = np.round(x[i][j] / scale * max_range) + scale = np.max(np.abs(x[i])).astype("float32") + scales.append(scale) + y[i] = np.round(x[i] * max_range / scale) + elif quant_axis == 1: + for i in range(x.shape[1]): + scale = np.max(np.abs(x[:, i])).astype("float32") + scales.append(scale) + y[:, i] = np.round(x[:, i] * max_range / scale) return y, scales def channel_wise_dequantize_max_abs(x, scales, quant_bits, + quant_axis, activation_scale=None): - if activation_scale is None: - y = x.copy() - for i in range(x.shape[0]): - y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * x[i] + assert quant_axis in [0, 1], "The quant_axis should be 0 or 1." 
+ + if isinstance(quant_bits, list): + max_range = math.pow(2, quant_bits[0] - 1) - 1 else: - y = x.copy() + max_range = math.pow(2, quant_bits - 1) - 1 + y = x.copy() + if quant_axis == 0: for i in range(x.shape[0]): - for j in range(x.shape[1]): - y[i][j] = (scales[j] / - (math.pow(2, quant_bits[0] - 1) - 1)) * x[i][j] - y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1) + y[i] = x[i] * scales[i] / max_range + elif quant_axis == 1: + for i in range(x.shape[1]): + y[:, i] = x[:, i] * scales[i] / max_range + + if activation_scale is not None: + y = y * activation_scale / (math.pow(2, quant_bits[1] - 1) - 1) return y @@ -83,9 +83,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest): self.set_args() self.op_type = "fake_channel_wise_dequantize_max_abs" x = np.random.randn(4, 3, 64, 64).astype(self.data_type) - yq, scales = channel_wise_quantize_max_abs( - x, self.quant_bits[0], use_second_dim=True) - ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, + yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], 1) + ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, 1, self.activation_scale) self.inputs = { @@ -105,25 +104,39 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest): def set_args(self): self.quant_bits = [8] self.data_type = "float32" + self.quant_axis = 0 def setUp(self): self.set_args() self.op_type = "fake_channel_wise_dequantize_max_abs" x = np.random.randn(4, 3, 64, 64).astype(self.data_type) - yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0]) - ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits) + yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], + self.quant_axis) + ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, + self.quant_axis) self.inputs = { 'X': yq, 'Scales': [("scales0", np.array(scales).astype(self.data_type))] } - self.attrs = {'quant_bits': self.quant_bits} + self.attrs = { + 'quant_bits': self.quant_bits, + 'quant_axis': self.quant_axis + } self.outputs = {'Out': ydq} def test_check_output(self): self.check_output() +class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1( + TestFakeChannelWiseDequantizeMaxAbsOpOneScale): + def set_args(self): + self.quant_bits = [8] + self.data_type = "float32" + self.quant_axis = 1 + + class TestFakeDequantizeMaxAbsOp(OpTest): def set_args(self): self.num_bits = 8 diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index 1c8335e3bceab24cba9364a96f6907d2cf585fe0..7835fd3f53ddb7f9a95313c6cc5fc7b72ae6d664 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -72,28 +72,62 @@ class TestFakeQuantizeOp2(OpTest): class TestFakeChannelWiseQuantizeOp(OpTest): def setUp(self): + self.set_arg() + assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1." 
+ self.op_type = "fake_channel_wise_quantize_abs_max" - self.attrs = {'bit_length': 8} - self.inputs = { - 'X': np.random.random((4, 3, 64, 64)).astype("float32"), - } + self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis} + scales = [] - for i in range(self.inputs['X'].shape[0]): - scales.append(np.max(np.abs(self.inputs['X'][i])).astype("float32")) outputs = self.inputs['X'].copy() - for i, scale in enumerate(scales): - outputs[i] = np.round(outputs[i] / scale * ( - (1 << (self.attrs['bit_length'] - 1)) - 1)) + bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 + if self.quant_axis == 0: + for i in range(self.inputs['X'].shape[0]): + scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32") + scales.append(scale_v) + outputs[i] = np.round(outputs[i] / scale_v * bnt) + elif self.quant_axis == 1: + for i in range(self.inputs['X'].shape[1]): + scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype( + "float32") + scales.append(scale_v) + outputs[:, i] = np.round(outputs[:, i] / scale_v * bnt) self.outputs = { 'Out': outputs, 'OutScale': np.array(scales).astype("float32"), } + def set_arg(self): + self.quant_axis = 0 + self.inputs = { + 'X': np.random.random((20, 15, 6, 6)).astype("float32"), + } + def test_check_output(self): self.check_output() +class TestFakeChannelWiseQuantizeOp1(TestFakeChannelWiseQuantizeOp): + def set_arg(self): + self.quant_axis = 1 + self.inputs = { + 'X': np.random.random((15, 20, 5, 5)).astype("float32"), + } + + +class TestFakeChannelWiseQuantizeOp2(TestFakeChannelWiseQuantizeOp): + def set_arg(self): + self.quant_axis = 0 + self.inputs = {'X': np.random.random((30, 15)).astype("float32"), } + + +class TestFakeChannelWiseQuantizeOp3(TestFakeChannelWiseQuantizeOp): + def set_arg(self): + self.quant_axis = 1 + self.inputs = {'X': np.random.random((30, 15)).astype("float32"), } + + class TestFakeQuantizeRangeAbsMaxOp(OpTest): def setUp(self): self.op_type = "fake_quantize_range_abs_max"
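For reference, the per-channel abs-max math that this patch parameterizes with quant_axis can be sketched in NumPy, along the lines of the reference helpers in the updated unit tests. This is a minimal illustrative sketch, not Paddle API: the helper names below are made up, and every channel slice is assumed to contain at least one non-zero value.

import numpy as np

def channel_wise_quant(x, quant_bit=8, quant_axis=0):
    # One abs-max scale per slice along quant_axis; values are mapped onto
    # the integer grid [-bnt, bnt] but kept as floats (fake quantization).
    assert quant_axis in [0, 1], "quant_axis should be 0 or 1."
    bnt = (1 << (quant_bit - 1)) - 1
    reduce_axes = tuple(i for i in range(x.ndim) if i != quant_axis)
    scales = np.abs(x).max(axis=reduce_axes).astype("float32")
    bshape = [1] * x.ndim
    bshape[quant_axis] = -1
    y = np.round(x / scales.reshape(bshape) * bnt)
    return y, scales

def channel_wise_dequant(y, scales, quant_bit=8, quant_axis=0):
    # Inverse mapping, i.e. what fake_channel_wise_dequantize_max_abs does
    # when a single Scales input is given (weight dequantization).
    bnt = (1 << (quant_bit - 1)) - 1
    bshape = [1] * y.ndim
    bshape[quant_axis] = -1
    return y * scales.reshape(bshape) / bnt

# conv2d / depthwise_conv2d weights (Cout, Cin, H, W): quant_axis = 0.
w = np.random.randn(16, 8, 3, 3).astype("float32")
wq, s = channel_wise_quant(w, quant_bit=8, quant_axis=0)
w_rec = channel_wise_dequant(wq, s, quant_bit=8, quant_axis=0)
print(np.abs(w - w_rec).max())  # only rounding error remains

# conv2d_transpose weights (Cin, Cout, H, W) and mul weights (in, out) keep
# the output channels on axis 1, hence _channelwise_quant_axis1_ops above.
wt = np.random.randn(8, 16, 3, 3).astype("float32")
wtq, st = channel_wise_quant(wt, quant_bit=8, quant_axis=1)

The quant_axis convention in this sketch matches the attribute added to fake_channel_wise_quantize_abs_max and fake_channel_wise_dequantize_max_abs: it always points at the output-channel dimension of the weight, which is axis 0 for conv2d and depthwise_conv2d and axis 1 for conv2d_transpose and mul.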