From d28162b97fd2d224968c18c9da735900ef280e7c Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Fri, 18 Sep 2020 10:19:07 +0800 Subject: [PATCH] Remove save_quantized_model in ImperativeQuantAware. (#27240) --- .../slim/quantization/imperative/qat.py | 83 ++----------------- .../contrib/slim/tests/test_imperative_qat.py | 36 ++++---- 2 files changed, 27 insertions(+), 92 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 8d399c9290..7b27629363 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -67,6 +67,7 @@ class ImperativeQuantAware(object): Examples: .. code-block:: python + import paddle from paddle.fluid.contrib.slim.quantization \ import ImperativeQuantAware from paddle.vision.models \ @@ -86,13 +87,12 @@ class ImperativeQuantAware(object): # ... # Save quant model for the inference. - imperative_qat.save_quantized_model( - dirname="./resnet50_qat", - model=model, - input_shape=[(3, 224, 224)], - input_dtype=['float32'], - feed=[0], - fetch=[0]) + paddle.jit.save( + layer=model, + model_path="./resnet50_qat", + input_spec=[ + paddle.static.InputSpec( + shape=[None, 3, 224, 224], dtype='float32')]) """ super(ImperativeQuantAware, self).__init__() self._weight_bits = weight_bits @@ -148,75 +148,6 @@ class ImperativeQuantAware(object): quant_layer = self._get_quantized_counterpart(layer) setattr(obj, target, quant_layer) - def save_quantized_model(self, - dirname, - model, - input_shape, - input_dtype, - feed, - fetch, - append_batch_size=True): - """ - Save the quantized model for the inference. - - Args: - dirname (str): the directory to save the quantized model. - model(fluid.dygraph.Layer): the quantized model to be saved. - input_shape(list[tuple(int)]): The shape value for each input, - e.g. [(3, 224, 224)]. - input_dtype(list[str]): The dtype value for each input, - e.g. ['float32']. - feed(list[int]): the indices of the input variables of the - imperative functions which will be saved as input variables in - inference model. - fetch(list[int]): the indices of the returned variable of the - imperative functions which will be saved as output variables in - inference model. - append_batch_size(bool, optional): - If true, it prepends an extra axis to the input_shape, meanwhile, - the input_shape shouldn't contain the batch size dimension. - Otherwise, it just uses the input_shape. Default True. - Returns: - None - """ - assert isinstance( - input_shape, list), "The parameter `input_shape` shoubld be a list." - assert isinstance( - input_dtype, list), "The parameter `input_dtype` shoubld be a list." - assert isinstance(feed, list), "The parameter `feed` shoubld be a list." - assert isinstance(fetch, - list), "The parameter `fetch` shoubld be a list." - assert len(input_shape) == len( - input_dtype - ), "The length of input_shape should be equal to input_dtype's." - assert len(input_dtype) == len( - feed), "The length of input_shape should be equal to feed's." - - with dygraph.guard(): - model.eval() - input_vars = [] - for i, (shape, dtype) in enumerate(zip(input_shape, input_dtype)): - if append_batch_size: - shape = [None] + list(shape) - # Note(Aurelius84): need a elegant way to name this. - in_spec = paddle.static.InputSpec(shape, dtype, 'feed_%d' % i) - input_vars.append(in_spec) - # use `declarative` to convert dygraph into static program - model.forward = dygraph.jit.declarative( - model.forward, input_spec=input_vars) - outputs = model.forward.concrete_program.outputs - input_spec = [input_vars[i] for i in feed] - configs = dygraph.jit.SaveLoadConfig() - configs.separate_params = True - if not isinstance(outputs, (tuple, list)): - outputs = [outputs] - configs.output_spec = [outputs[i] for i in fetch] - dygraph.jit.save( - layer=model, - model_path=dirname, - input_spec=input_spec, - configs=configs) - def _get_quantized_counterpart(self, layer): quant_layers = tuple(self._quant_layers_map.values()) quantized_counterpart = tuple('Quantized' + k diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py index 79b0bbd6a4..f076d274b6 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py @@ -221,7 +221,7 @@ class TestImperativeQat(unittest.TestCase): model_dict = lenet.state_dict() fluid.save_dygraph(model_dict, "save_temp") - # test the correctness of `save_quantized_model` + # test the correctness of `paddle.jit.save` data = next(test_reader()) test_data = np.array([x[0].reshape(1, 28, 28) for x in data]).astype('float32') @@ -231,13 +231,14 @@ class TestImperativeQat(unittest.TestCase): # save inference quantized model path = "./mnist_infer_model" - imperative_qat.save_quantized_model( - dirname=path, - model=lenet, - input_shape=[(1, 28, 28)], - input_dtype=['float32'], - feed=[0], - fetch=[0]) + paddle.jit.save( + layer=lenet, + model_path=path, + input_spec=[ + paddle.static.InputSpec( + shape=[None, 1, 28, 28], dtype='float32') + ]) + if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) else: @@ -245,7 +246,10 @@ class TestImperativeQat(unittest.TestCase): exe = fluid.Executor(place) [inference_program, feed_target_names, fetch_targets] = ( fluid.io.load_inference_model( - dirname=path, executor=exe)) + dirname=path, + executor=exe, + model_filename="__model__", + params_filename="__variables__")) after_save, = exe.run(inference_program, feed={feed_target_names[0]: test_data}, fetch_list=fetch_targets) @@ -332,13 +336,13 @@ class TestImperativeQat(unittest.TestCase): if batch_id % 100 == 0: _logger.info('{}: {}'.format('loss', avg_loss.numpy())) - imperative_qat.save_quantized_model( - dirname="./dynamic_mnist", - model=lenet, - input_shape=[(1, 28, 28)], - input_dtype=['float32'], - feed=[0], - fetch=[0]) + paddle.jit.save( + layer=lenet, + model_path="./dynamic_mnist", + input_spec=[ + paddle.static.InputSpec( + shape=[None, 1, 28, 28], dtype='float32') + ]) # static graph train _logger.info( -- GitLab