未验证 提交 769c28f5 编写于 作者: C ceci3 提交者: GitHub

replace eval dataloader with train dataloader if eval_dataloader is None (#1163)

* replace eval dataloader with train dataloader if eval_dataloader is None

* update
上级 6a161828
......@@ -96,9 +96,9 @@ class AutoCompression:
If set to None, will choose a strategy automatically. Default: None.
target_speedup(float, optional): target speedup ratio by the way of auto compress. Default: None.
eval_callback(function, optional): eval function, define by yourself to return the metric of the inference program, can be used to judge the metric of compressed model. The documents of how to write eval function is `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/auto-compression/custom_function.rst`_ . ``eval_callback`` and ``eval_dataloader`` cannot be None at the same time. Dafault: None.
eval_dataloader(paddle.io.Dataloader, optional): The
Generator or Dataloader provides eval data, and it could
return a batch every time. ``eval_callback`` and ``eval_dataloader`` cannot be None at the same time. Dafault: None.
eval_dataloader(paddle.io.Dataloader, optional): The Generator or Dataloader provides eval data, and it could
return a batch every time. If eval_dataloader is None, will take first 5000 sample from train_dataloader
as eval_dataloader, and the metric of eval_dataloader for reference only. Dafault: None.
deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'.
"""
self.model_dir = model_dir
......@@ -116,7 +116,10 @@ class AutoCompression:
self.train_dataloader = train_dataloader
self.target_speedup = target_speedup
self.eval_function = eval_callback
self.eval_dataloader = eval_dataloader if eval_dataloader is not None else train_dataloader
if eval_dataloader is None:
eval_dataloader = self._get_eval_dataloader(train_dataloader)
self.eval_dataloader = eval_dataloader
paddle.enable_static()
......@@ -152,6 +155,17 @@ class AutoCompression:
self.train_config = create_train_config(self.strategy_config,
self.model_type)
def _get_eval_dataloader(self, train_dataloader):
def _gen():
len_loader = len(list(train_dataloader()))
### max eval_dataloader is 5000 if use train_dataloader as eval_dataloader
slice_len = min(5000, len_loader)
ret = list(itertools.islice(train_dataloader(), slice_len))
for i in ret:
yield i
return _gen
def _prepare_envs(self):
devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册