未验证 提交 a10fc884 编写于 作者: G Guanghua Yu 提交者: GitHub

fix act model_dir rstrip (#1201)

上级 23b5a731
......@@ -116,7 +116,7 @@ class AutoCompression:
as eval_dataloader, and the metric of eval_dataloader for reference only. Dafault: None.
deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'.
"""
self.model_dir = model_dir
self.model_dir = model_dir.rstrip('/')
if model_filename == 'None':
model_filename = None
......@@ -144,7 +144,7 @@ class AutoCompression:
self.train_config = extract_train_config(config)
# prepare dataloader
self.feed_vars = get_feed_vars(model_dir, model_filename,
self.feed_vars = get_feed_vars(self.model_dir, model_filename,
params_filename)
self.train_dataloader = wrap_dataloader(train_dataloader,
self.feed_vars)
......@@ -158,7 +158,7 @@ class AutoCompression:
paddle.enable_static()
self._exe, self._places = self._prepare_envs()
self.model_type = self._get_model_type(self._exe, model_dir,
self.model_type = self._get_model_type(self._exe, self.model_dir,
model_filename, params_filename)
if self.train_config is not None and self.train_config.use_fleet:
......@@ -171,7 +171,7 @@ class AutoCompression:
infer_shape_model = self.create_tmp_dir(
self.final_dir, prefix="infer_shape_model_")
self._infer_shape(model_dir, self.model_filename,
self._infer_shape(self.model_dir, self.model_filename,
self.params_filename, input_shapes,
infer_shape_model)
self.model_dir = infer_shape_model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册