diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc index 68c7227e5a7123e1e751dd55e243ee481bf36540..4a8937ba1c7ef9827ecc9bf575d9893c95a3b22b 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cc +++ b/paddle/fluid/operators/fake_dequantize_op.cc @@ -33,8 +33,51 @@ struct DequantizeFunctor { } }; +template +struct ChannelDequantizeFunctor { + void operator()(const platform::CPUDeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor** scales, + const int scale_num, T max_range, framework::Tensor* out) { + if (scale_num == 1) { + const int channel = in->dims()[0]; + const T* scale_factor = scales[0]->data(); + for (int i = 0; i < channel; i++) { + T s = scale_factor[i]; + framework::Tensor one_channel_in = in->Slice(i, i + 1); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto in_e = framework::EigenVector::Flatten(one_channel_in); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = (s / max_range) * in_e; + } + } else if (scale_num == 2) { + int batch_size = in->dims()[0]; + int channel = in->dims()[1]; + const T* scale_one = scales[0]->data(); + const T* scale_two = scales[1]->data(); + for (int i = 0; i < batch_size; i++) { + framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize( + framework::slice_ddim(in->dims(), 1, in->dims().size())); + framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize( + framework::slice_ddim(out->dims(), 1, out->dims().size())); + for (int j = 0; j < channel; j++) { + T s = scale_one[j]; + framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1); + framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1); + auto in_e = framework::EigenVector::Flatten(one_channel_in); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = (s * scale_two[0] / max_range) * in_e; + } + } + } + } +}; + template struct DequantizeFunctor; template struct DequantizeFunctor; +template struct ChannelDequantizeFunctor; +template struct ChannelDequantizeFunctor; class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel { public: diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index 35dcc69279d0119e75c4c5072e7817c839b9e819..02f9dc827d68cbb58447ed1557ff4bf310b2c017 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -44,8 +44,66 @@ struct DequantizeFunctor { } }; +template +__global__ void DequantizeOneScale(const T* in, const T* scale, T max_range, + int num, int channel, T* out) { + int tid = threadIdx.x; + int channel_size = num / channel; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + for (int i = tid; i < channel_size; i += blockDim.x) { + out_c[i] = in_c[i] * scale[blockIdx.x] / max_range; + } +} + +template +__global__ void DequantizeTwoScale(const T* in, const T* scale_one, + const T* scale_two, T max_range, int num, + int batch_size, int channel, T* out) { + int tid = threadIdx.x; + int channel_size = num / (batch_size * channel); + int scale_index = blockIdx.x % channel; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + for (int i = tid; i < channel_size; i += blockDim.x) { + out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range; + } +} + +template +struct ChannelDequantizeFunctor { + void operator()(const platform::CUDADeviceContext& dev_ctx, + const framework::Tensor* in, const framework::Tensor** scales, + const int scale_num, T max_range, framework::Tensor* out) { + const T* in_data = in->data(); + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + if (scale_num == 1) { + int num = in->numel(); + int channel = in->dims()[0]; + const T* scale_factor = scales[0]->data(); + int block = 1024; + int grid = channel; + DequantizeOneScale<<>>( + in_data, scale_factor, max_range, num, channel, out_data); + } else if (scale_num == 2) { + int num = in->numel(); + int batch_size = in->dims()[0]; + int channel = in->dims()[1]; + const T* scale_one = scales[0]->data(); + const T* scale_two = scales[1]->data(); + int block = 1024; + int grid = batch_size * channel; + DequantizeTwoScale<<>>( + in_data, scale_one, scale_two, max_range, num, batch_size, channel, + out_data); + } + } +}; + template struct DequantizeFunctor; template struct DequantizeFunctor; +template struct ChannelDequantizeFunctor; +template struct ChannelDequantizeFunctor; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h index 1a504bf035a5bcc91dbe2296a2a1ead1ea44e08a..ed9a0a4d65fab5ce1ef48835c332fade978d2bae 100644 --- a/paddle/fluid/operators/fake_dequantize_op.h +++ b/paddle/fluid/operators/fake_dequantize_op.h @@ -29,6 +29,13 @@ struct DequantizeFunctor { framework::Tensor* out); }; +template +struct ChannelDequantizeFunctor { + void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, + const framework::Tensor** scales, const int scale_num, + T max_range, framework::Tensor* out); +}; + template class FakeDequantizeMaxAbsKernel : public framework::OpKernel { public: @@ -56,50 +63,32 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); auto quant_bits = ctx.Attr>("quant_bits"); - int max_range = std::pow(2, quant_bits[0] - 1) - 1; + int max_range = 1; auto& dev_ctx = ctx.template device_context(); out->mutable_data(dev_ctx.GetPlace()); - - auto dequant = DequantizeFunctor(); - if (scales.size() == 1) { + int scale_num = scales.size(); + if (scale_num == 1) { PADDLE_ENFORCE_EQ( scales[0]->numel(), in->dims()[0], "The number of first scale values must be the same with " "first dimension value of Input(X) when the `Scales` has only one " "element."); - for (int64_t i = 0; i < in->dims()[0]; i++) { - framework::Tensor one_channel_in = in->Slice(i, i + 1); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1); - dequant(dev_ctx, &one_channel_in, &one_channel_scale, - static_cast(max_range), &one_channel_out); - } - } else if (scales.size() == 2) { + max_range *= (std::pow(2, quant_bits[0] - 1) - 1); + } else if (scale_num == 2) { PADDLE_ENFORCE_EQ( scales[0]->numel(), in->dims()[1], "The number of first scale values must be the same with " "second dimension value of Input(X) when the `Scales` has two " "elements."); - for (int64_t i = 0; i < in->dims()[0]; i++) { - framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize( - framework::slice_ddim(in->dims(), 1, in->dims().size())); - framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize( - framework::slice_ddim(out->dims(), 1, out->dims().size())); - for (int64_t j = 0; j < in->dims()[1]; j++) { - framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1); - framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1); - framework::Tensor one_channel_scale = scales[0]->Slice(j, j + 1); - dequant(dev_ctx, &one_channel_in, &one_channel_scale, - static_cast(max_range), &one_channel_out); - } - } PADDLE_ENFORCE_EQ( scales[1]->numel(), 1, "The second scale tensor should only have one value at now."); - max_range = std::pow(2, quant_bits[1] - 1) - 1; - dequant(dev_ctx, out, scales[1], static_cast(max_range), out); + max_range *= (std::pow(2, quant_bits[0] - 1) - 1) * + (std::pow(2, quant_bits[1] - 1) - 1); } + ChannelDequantizeFunctor()( + dev_ctx, in, scales.data(), scale_num, static_cast(max_range), out); } }; diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index f9f28f60c0fb6e86d90361fbc97a273125392c01..054ef4658cc0c4448d49870849017d3191d57db9 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -37,6 +37,21 @@ struct FindAbsMaxFunctor { template struct FindAbsMaxFunctor; +template +struct FindChannelAbsMaxFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const T* in, + const int num, const int channel, T* out) { + const int channel_size = num / channel; + for (int i = 0; i < channel; i++) { + auto* start = in + i * channel_size; + auto* end = in + (i + 1) * channel_size; + out[i] = std::abs(*(std::max_element(start, end, Compare()))); + } + } +}; + +template struct FindChannelAbsMaxFunctor; + template struct ClipAndFakeQuantFunctor { void operator()(const platform::CPUDeviceContext& ctx, @@ -53,6 +68,36 @@ struct ClipAndFakeQuantFunctor { template struct ClipAndFakeQuantFunctor; +template +struct ChannelClipAndFakeQuantFunctor { + void operator()(const platform::CPUDeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, const int channel, + framework::Tensor* out) { + auto* scale_data = scale.data(); + auto* in_data = in.data(); + auto* out_data = out->mutable_data(ctx.GetPlace()); + const int channel_size = in.numel() / channel; + platform::Transform trans; + for (int i = 0; i < channel; i++) { + T s = scale_data[i]; + auto* start = in_data + i * channel_size; + auto* end = in_data + (i + 1) * channel_size; + trans(ctx, start, end, out_data + i * channel_size, + ClipFunctor(-s, s)); + } + for (int i = 0; i < channel; i++) { + T s = scale_data[i]; + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round(); + } + } +}; + +template struct ChannelClipAndFakeQuantFunctor; + template struct FindRangeAbsMaxFunctor { void operator()(const platform::CPUDeviceContext& ctx, diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 3707f6772eac0d568c170d60c17d431e254d0b6b..33bd275e5cc507ec700b3694cd8b1df9672ec512 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -74,6 +74,45 @@ struct FindAbsMaxFunctor { template struct FindAbsMaxFunctor; +template +__global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c, + T* out) { + int tid = threadIdx.x; + int channel_size = n / c; + const T* in_c = in + blockIdx.x * channel_size; + extern __shared__ T shared_max_data[]; + shared_max_data[tid] = T(0); + for (int i = tid; i < channel_size; i += blockDim.x) { + T tmp = fabs(in_c[i]); + if (tmp > shared_max_data[tid]) { + shared_max_data[tid] = tmp; + } + } + __syncthreads(); + for (int i = blockDim.x / 2; i > 0; i >>= 1) { + if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) { + shared_max_data[tid] = shared_max_data[tid + i]; + } + __syncthreads(); + } + if (tid == 0) { + out[blockIdx.x] = shared_max_data[0]; + } +} + +template +struct FindChannelAbsMaxFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const T* in, + const int num, const int channel, T* out) { + int block = 1024; + int grid = channel; + FindChannelAbsMaxKernel<<>>( + in, num, channel, out); + } +}; + +template struct FindChannelAbsMaxFunctor; + template __global__ void ClipAndQuantKernel(const T* in, const T* scale, const int bin_cnt, const int n, T* out) { @@ -82,14 +121,76 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, T s = scale[0]; for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - T x = in[bid]; + T x = in[i]; T v = x > s ? s : x; v = v < -s ? -s : v; v = bin_cnt / s * v; - out[bid] = round(v); + out[i] = round(v); } } +template +struct ClipAndFakeQuantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, framework::Tensor* out) { + int num = in.numel(); + int block = 1024; + int grid = (block - 1 + num) / block; + + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + ClipAndQuantKernel<<>>( + in_data, scale_data, bin_cnt, num, out_data); + } +}; + +template struct ClipAndFakeQuantFunctor; + +template +__global__ void ChannelClipAndQuantKernel(const T* in, const T* scale, + const int bin_cnt, const int n, + const int c, T* out) { + int tid = threadIdx.x; + + int channel_size = n / c; + const T* in_c = in + blockIdx.x * channel_size; + T* out_c = out + blockIdx.x * channel_size; + + T s = scale[blockIdx.x]; + for (int i = tid; i < channel_size; i += blockDim.x) { + T x = in_c[i]; + T v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt / s * v; + out_c[i] = round(v); + } +} + +template +struct ChannelClipAndFakeQuantFunctor { + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in, const framework::Tensor& scale, + const int bin_cnt, const int channel, + framework::Tensor* out) { + int num = in.numel(); + int block = 1024; + int grid = channel; + + const T* in_data = in.data(); + const T* scale_data = scale.data(); + T* out_data = out->mutable_data(ctx.GetPlace()); + + ChannelClipAndQuantKernel<<>>( + in_data, scale_data, bin_cnt, num, channel, out_data); + } +}; + +template struct ChannelClipAndFakeQuantFunctor; + template __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, const T* last_scale, @@ -182,26 +283,6 @@ struct FindMovingAverageAbsMaxFunctor { template struct FindMovingAverageAbsMaxFunctor; -template -struct ClipAndFakeQuantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { - int num = in.numel(); - int block = 1024; - int grid = (block - 1 + num) / block; - - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); - - ClipAndQuantKernel<<>>( - in_data, scale_data, bin_cnt, num, out_data); - } -}; - -template struct ClipAndFakeQuantFunctor; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 2616cd996f05bc1d229c95d88b7bebf2c1c7a34a..5ab38b086df7f9df33996ec83b5ec07047c204ba 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -42,6 +42,19 @@ struct FindRangeAbsMaxFunctor { framework::Tensor* scales_arr, framework::Tensor* out_scale); }; +template +struct FindChannelAbsMaxFunctor { + void operator()(const DeviceContext& ctx, const T* in, const int num, + const int channel, T* out); +}; + +template +struct ChannelClipAndFakeQuantFunctor { + void operator()(const DeviceContext& ctx, const framework::Tensor& in, + const framework::Tensor& scale, const int bin_cnt, + const int channel, framework::Tensor* out); +}; + template struct FindMovingAverageAbsMaxFunctor { void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum, @@ -86,21 +99,10 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel { int bin_cnt = std::pow(2, bit_length - 1) - 1; auto& dev_ctx = context.template device_context(); - auto find_abs_max = FindAbsMaxFunctor(); - for (int64_t i = 0; i < in->dims()[0]; i++) { - framework::Tensor one_channel = in->Slice(i, i + 1); - const T* one_channel_data = one_channel.data(); - find_abs_max(dev_ctx, one_channel_data, one_channel.numel(), - &out_scale_data[i]); - } - auto clip_quant = ClipAndFakeQuantFunctor(); - for (int64_t i = 0; i < in->dims()[0]; i++) { - framework::Tensor one_channel_in = in->Slice(i, i + 1); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - framework::Tensor one_channel_scale = out_scale->Slice(i, i + 1); - clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt, - &one_channel_out); - } + FindChannelAbsMaxFunctor()( + dev_ctx, in->data(), in->numel(), in->dims()[0], out_scale_data); + ChannelClipAndFakeQuantFunctor()( + dev_ctx, *in, *out_scale, bin_cnt, in->dims()[0], out); } }; diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 03ffd2795d089435dd26c7e8e35c37df0dc13934..5dcef506711b78c2aef30d16719f8766359ae8f3 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -576,8 +576,6 @@ class QuantizationFreezePass(object): elif self._weight_quantize_type == 'channel_wise_abs_max': param = self._load_var(input_arg_name) if len(param.shape) == 4: # conv2d or depthwise_conv2d - print('DEBUG**************************: %s' % - input_arg_name) scale_v = [] for i in range(param.shape[0]): scale_v.append(np.max(np.abs(param[i]))) diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py index cda7ecbd8251d4a5818adeae6cef6654fbb84c7a..c7feca0b82606cdba9a05fb6de821aa6d347d4e6 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py @@ -127,7 +127,7 @@ class TestQuantizationTransformPass(unittest.TestCase): arg_name.endswith('.quantized.dequantized')) self.assertTrue(arg_name in quantized_ops) - def linear_fc_quant(self, quant_type, for_ci=False): + def linear_fc_quant(self, activation_quant_type, for_ci=False): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -140,14 +140,15 @@ class TestQuantizationTransformPass(unittest.TestCase): transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), place=place, - activation_quantize_type=quant_type) + activation_quantize_type=activation_quant_type) transform_pass.apply(graph) if not for_ci: marked_nodes = set() for op in graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes) + graph.draw('.', 'quantize_fc_' + activation_quant_type, + marked_nodes) program = graph.to_program() self.check_program(transform_pass, program) val_graph = IrGraph(core.Graph(program.desc), for_test=False) @@ -156,7 +157,8 @@ class TestQuantizationTransformPass(unittest.TestCase): for op in val_graph.all_op_nodes(): if op.name().find('quantize') > -1: val_marked_nodes.add(op) - val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes) + val_graph.draw('.', 'val_fc_' + activation_quant_type, + val_marked_nodes) def test_linear_fc_quant_abs_max(self): self.linear_fc_quant('abs_max', for_ci=True) @@ -167,7 +169,7 @@ class TestQuantizationTransformPass(unittest.TestCase): def test_linear_fc_quant_moving_average_abs_max(self): self.linear_fc_quant('moving_average_abs_max', for_ci=True) - def residual_block_quant(self, quant_type, for_ci=False): + def residual_block_quant(self, activation_quant_type, for_ci=False): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -180,14 +182,15 @@ class TestQuantizationTransformPass(unittest.TestCase): transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), place=place, - activation_quantize_type=quant_type) + activation_quantize_type=activation_quant_type) transform_pass.apply(graph) if not for_ci: marked_nodes = set() for op in graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes) + graph.draw('.', 'quantize_residual_' + activation_quant_type, + marked_nodes) program = graph.to_program() self.check_program(transform_pass, program) val_graph = IrGraph(core.Graph(program.desc), for_test=False) @@ -196,7 +199,8 @@ class TestQuantizationTransformPass(unittest.TestCase): for op in val_graph.all_op_nodes(): if op.name().find('quantize') > -1: val_marked_nodes.add(op) - val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes) + val_graph.draw('.', 'val_residual_' + activation_quant_type, + val_marked_nodes) def test_residual_block_abs_max(self): self.residual_block_quant('abs_max', for_ci=True) @@ -209,7 +213,12 @@ class TestQuantizationTransformPass(unittest.TestCase): class TestQuantizationFreezePass(unittest.TestCase): - def freeze_graph(self, use_cuda, seed, quant_type, for_ci=False): + def freeze_graph(self, + use_cuda, + seed, + activation_quant_type, + weight_quant_type='abs_max', + for_ci=False): def build_program(main, startup, is_test): main.random_seed = seed startup.random_seed = seed @@ -245,10 +254,10 @@ class TestQuantizationFreezePass(unittest.TestCase): transform_pass = QuantizationTransformPass( scope=scope, place=place, - activation_quantize_type=quant_type, - weight_quantize_type='channel_wise_abs_max') + activation_quantize_type=activation_quant_type, + weight_quantize_type=weight_quant_type) #transform_pass = QuantizationTransformPass( - # scope=scope, place=place, activation_quantize_type=quant_type) + # scope=scope, place=place, activation_quantize_type=activation_quant_type) transform_pass.apply(main_graph) transform_pass.apply(test_graph) dev_name = '_gpu_' if use_cuda else '_cpu_' @@ -257,12 +266,14 @@ class TestQuantizationFreezePass(unittest.TestCase): for op in main_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes) + main_graph.draw('.', 'main' + dev_name + activation_quant_type + '_' + + weight_quant_type, marked_nodes) marked_nodes = set() for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes) + test_graph.draw('.', 'test' + dev_name + activation_quant_type + '_' + + weight_quant_type, marked_nodes) build_strategy = fluid.BuildStrategy() build_strategy.memory_optimize = False @@ -287,8 +298,9 @@ class TestQuantizationFreezePass(unittest.TestCase): feed=feeder.feed(data), fetch_list=[loss]) if not for_ci: - print('{}: {}'.format('loss' + dev_name + quant_type, - loss_v)) + print('{}: {}'.format('loss' + dev_name + + activation_quant_type + '_' + + weight_quant_type, loss_v)) test_data = next(test_reader()) with fluid.program_guard(quantized_test_program): @@ -302,9 +314,7 @@ class TestQuantizationFreezePass(unittest.TestCase): # Freeze graph for inference, but the weight of fc/conv is still float type. freeze_pass = QuantizationFreezePass( - scope=scope, - place=place, - weight_quantize_type='channel_wise_abs_max') + scope=scope, place=place, weight_quantize_type=weight_quant_type) #freeze_pass = QuantizationFreezePass(scope=scope, place=place) freeze_pass.apply(test_graph) if not for_ci: @@ -312,7 +322,8 @@ class TestQuantizationFreezePass(unittest.TestCase): for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - test_graph.draw('.', 'test_freeze' + dev_name + quant_type, + test_graph.draw('.', 'test_freeze' + dev_name + + activation_quant_type + '_' + weight_quant_type, marked_nodes) server_program = test_graph.to_program() @@ -322,18 +333,20 @@ class TestQuantizationFreezePass(unittest.TestCase): fetch_list=[loss]) self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3) if not for_ci: - print('{}: {}'.format('test_loss1' + dev_name + quant_type, - test_loss1)) - print('{}: {}'.format('test_loss2' + dev_name + quant_type, - test_loss2)) + print( + '{}: {}'.format('test_loss1' + dev_name + activation_quant_type + + '_' + weight_quant_type, test_loss1)) + print( + '{}: {}'.format('test_loss2' + dev_name + activation_quant_type + + '_' + weight_quant_type, test_loss2)) w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor()) # Maybe failed, this is due to the calculation precision # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant)) if not for_ci: - print('{}: {}'.format('w_freeze' + dev_name + quant_type, - np.sum(w_freeze))) - print('{}: {}'.format('w_quant' + dev_name + quant_type, - np.sum(w_quant))) + print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type + + '_' + weight_quant_type, np.sum(w_freeze))) + print('{}: {}'.format('w_quant' + dev_name + activation_quant_type + + '_' + weight_quant_type, np.sum(w_quant))) # Convert parameter to 8-bit. convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place) @@ -343,26 +356,28 @@ class TestQuantizationFreezePass(unittest.TestCase): for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - test_graph.draw('.', 'test_int8' + dev_name + quant_type, - marked_nodes) + test_graph.draw('.', 'test_int8' + dev_name + activation_quant_type + + '_' + weight_quant_type, marked_nodes) server_program_int8 = test_graph.to_program() # Save the 8-bit parameter and model file. with fluid.scope_guard(scope): - fluid.io.save_inference_model('server_int8' + dev_name + quant_type, - ['image', 'label'], [loss], exe, - server_program_int8) + fluid.io.save_inference_model( + 'server_int8' + dev_name + activation_quant_type + '_' + + weight_quant_type, ['image', 'label'], [loss], exe, + server_program_int8) # Test whether the 8-bit parameter and model file can be loaded successfully. [infer, feed, fetch] = fluid.io.load_inference_model( - 'server_int8' + dev_name + quant_type, exe) + 'server_int8' + dev_name + activation_quant_type + '_' + + weight_quant_type, exe) # Check the loaded 8-bit weight. w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor()) self.assertEqual(w_8bit.dtype, np.int8) self.assertEqual(np.sum(w_8bit), np.sum(w_freeze)) if not for_ci: - print('{}: {}'.format('w_8bit' + dev_name + quant_type, - np.sum(w_8bit))) - print('{}: {}'.format('w_freeze' + dev_name + quant_type, - np.sum(w_freeze))) + print('{}: {}'.format('w_8bit' + dev_name + activation_quant_type + + '_' + weight_quant_type, np.sum(w_8bit))) + print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type + + '_' + weight_quant_type, np.sum(w_freeze))) mobile_pass = TransformForMobilePass() mobile_pass.apply(test_graph) @@ -371,45 +386,103 @@ class TestQuantizationFreezePass(unittest.TestCase): for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) - test_graph.draw('.', 'test_mobile' + dev_name + quant_type, + test_graph.draw('.', 'test_mobile' + dev_name + + activation_quant_type + '_' + weight_quant_type, marked_nodes) mobile_program = test_graph.to_program() with fluid.scope_guard(scope): - fluid.io.save_inference_model('mobile_int8' + dev_name + quant_type, - ['image', 'label'], [loss], exe, - mobile_program) + fluid.io.save_inference_model( + 'mobile_int8' + dev_name + activation_quant_type + '_' + + weight_quant_type, ['image', 'label'], [loss], exe, + mobile_program) def test_freeze_graph_cuda_dynamic(self): if fluid.core.is_compiled_with_cuda(): with fluid.unique_name.guard(): self.freeze_graph( - True, seed=1, quant_type='abs_max', for_ci=False) + True, + seed=1, + activation_quant_type='abs_max', + weight_quant_type='abs_max', + for_ci=True) + with fluid.unique_name.guard(): + self.freeze_graph( + True, + seed=1, + activation_quant_type='abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) def test_freeze_graph_cpu_dynamic(self): with fluid.unique_name.guard(): - self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=False) + self.freeze_graph( + False, + seed=2, + activation_quant_type='abs_max', + weight_quant_type='abs_max', + for_ci=True) + self.freeze_graph( + False, + seed=2, + activation_quant_type='abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) def test_freeze_graph_cuda_static(self): if fluid.core.is_compiled_with_cuda(): with fluid.unique_name.guard(): self.freeze_graph( - True, seed=1, quant_type='range_abs_max', for_ci=False) + True, + seed=1, + activation_quant_type='range_abs_max', + weight_quant_type='abs_max', + for_ci=True) + self.freeze_graph( + True, + seed=1, + activation_quant_type='moving_average_abs_max', + weight_quant_type='abs_max', + for_ci=True) self.freeze_graph( True, seed=1, - quant_type='moving_average_abs_max', - for_ci=False) + activation_quant_type='range_abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) + self.freeze_graph( + True, + seed=1, + activation_quant_type='moving_average_abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) def test_freeze_graph_cpu_static(self): with fluid.unique_name.guard(): self.freeze_graph( - False, seed=2, quant_type='range_abs_max', for_ci=False) + False, + seed=2, + activation_quant_type='range_abs_max', + weight_quant_type='abs_max', + for_ci=True) + self.freeze_graph( + False, + seed=2, + activation_quant_type='moving_average_abs_max', + weight_quant_type='abs_max', + for_ci=True) + self.freeze_graph( + False, + seed=2, + activation_quant_type='range_abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) self.freeze_graph( False, seed=2, - quant_type='moving_average_abs_max', - for_ci=False) + activation_quant_type='moving_average_abs_max', + weight_quant_type='channel_wise_abs_max', + for_ci=True) if __name__ == '__main__':