You need to sign in or sign up before continuing.
未验证 提交 a10fc884 编写于 作者: G Guanghua Yu 提交者: GitHub

fix act model_dir rstrip (#1201)

上级 23b5a731
...@@ -116,7 +116,7 @@ class AutoCompression: ...@@ -116,7 +116,7 @@ class AutoCompression:
as eval_dataloader, and the metric of eval_dataloader for reference only. Dafault: None. as eval_dataloader, and the metric of eval_dataloader for reference only. Dafault: None.
deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'. deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'.
""" """
self.model_dir = model_dir self.model_dir = model_dir.rstrip('/')
if model_filename == 'None': if model_filename == 'None':
model_filename = None model_filename = None
...@@ -144,7 +144,7 @@ class AutoCompression: ...@@ -144,7 +144,7 @@ class AutoCompression:
self.train_config = extract_train_config(config) self.train_config = extract_train_config(config)
# prepare dataloader # prepare dataloader
self.feed_vars = get_feed_vars(model_dir, model_filename, self.feed_vars = get_feed_vars(self.model_dir, model_filename,
params_filename) params_filename)
self.train_dataloader = wrap_dataloader(train_dataloader, self.train_dataloader = wrap_dataloader(train_dataloader,
self.feed_vars) self.feed_vars)
...@@ -158,7 +158,7 @@ class AutoCompression: ...@@ -158,7 +158,7 @@ class AutoCompression:
paddle.enable_static() paddle.enable_static()
self._exe, self._places = self._prepare_envs() self._exe, self._places = self._prepare_envs()
self.model_type = self._get_model_type(self._exe, model_dir, self.model_type = self._get_model_type(self._exe, self.model_dir,
model_filename, params_filename) model_filename, params_filename)
if self.train_config is not None and self.train_config.use_fleet: if self.train_config is not None and self.train_config.use_fleet:
...@@ -171,7 +171,7 @@ class AutoCompression: ...@@ -171,7 +171,7 @@ class AutoCompression:
infer_shape_model = self.create_tmp_dir( infer_shape_model = self.create_tmp_dir(
self.final_dir, prefix="infer_shape_model_") self.final_dir, prefix="infer_shape_model_")
self._infer_shape(model_dir, self.model_filename, self._infer_shape(self.model_dir, self.model_filename,
self.params_filename, input_shapes, self.params_filename, input_shapes,
infer_shape_model) infer_shape_model)
self.model_dir = infer_shape_model self.model_dir = infer_shape_model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册