From 769c28f501f1f028eeafca1b22503e61b1dbdd0f Mon Sep 17 00:00:00 2001 From: ceci3 Date: Fri, 10 Jun 2022 14:39:21 +0800 Subject: [PATCH] replace eval dataloader with train dataloader if eval_dataloader is None (#1163) * replace eval dataloader with train dataloader if eval_dataloader is None * update --- paddleslim/auto_compression/compressor.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index bbe75188..aee42274 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -96,9 +96,9 @@ class AutoCompression: If set to None, will choose a strategy automatically. Default: None. target_speedup(float, optional): target speedup ratio by the way of auto compress. Default: None. eval_callback(function, optional): eval function, define by yourself to return the metric of the inference program, can be used to judge the metric of compressed model. The documents of how to write eval function is `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/auto-compression/custom_function.rst`_ . ``eval_callback`` and ``eval_dataloader`` cannot be None at the same time. Dafault: None. - eval_dataloader(paddle.io.Dataloader, optional): The - Generator or Dataloader provides eval data, and it could - return a batch every time. ``eval_callback`` and ``eval_dataloader`` cannot be None at the same time. Dafault: None. + eval_dataloader(paddle.io.Dataloader, optional): The Generator or Dataloader that provides eval data and + returns a batch every time. If eval_dataloader is None, at most the first 5000 batches of train_dataloader + are used as eval_dataloader, and the resulting metric is for reference only. Default: None. deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'. 
""" self.model_dir = model_dir @@ -116,7 +116,10 @@ class AutoCompression: self.train_dataloader = train_dataloader self.target_speedup = target_speedup self.eval_function = eval_callback - self.eval_dataloader = eval_dataloader if eval_dataloader is not None else train_dataloader + + if eval_dataloader is None: + eval_dataloader = self._get_eval_dataloader(train_dataloader) + self.eval_dataloader = eval_dataloader paddle.enable_static() @@ -152,6 +155,17 @@ class AutoCompression: self.train_config = create_train_config(self.strategy_config, self.model_type) + def _get_eval_dataloader(self, train_dataloader): + def _gen(): + len_loader = len(list(train_dataloader())) + ### max eval_dataloader is 5000 if use train_dataloader as eval_dataloader + slice_len = min(5000, len_loader) + ret = list(itertools.islice(train_dataloader(), slice_len)) + for i in ret: + yield i + + return _gen + def _prepare_envs(self): devices = paddle.device.get_device().split(':')[0] places = paddle.device._convert_to_place(devices) -- GitLab