From f6fffd8ce9394c7ee9cab97fa08778953ab906dc Mon Sep 17 00:00:00 2001 From: wuzewu Date: Tue, 10 Nov 2020 13:54:34 +0800 Subject: [PATCH] Fix model compatibility issues --- paddlehub/compat/module/module_v1.py | 7 ++++++- paddlehub/compat/task/base_task.py | 20 +++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/paddlehub/compat/module/module_v1.py b/paddlehub/compat/module/module_v1.py index d5f6ff4e..694db191 100644 --- a/paddlehub/compat/module/module_v1.py +++ b/paddlehub/compat/module/module_v1.py @@ -211,7 +211,12 @@ class ModuleV1(object): '''Load the infomation of Module object defined in the specified directory.''' desc_file = os.path.join(directory, 'module_desc.pb') desc = module_v1_utils.convert_module_desc(desc_file) - return desc.module_info + + # The naming of some old versions of Module is not standardized, which format of uppercase + # letters. This will cause the path of these modules to be incorrect after installation. + module_info = desc.module_info + module_info.name = module_info.name.lower() + return module_info def assets_path(self): return os.path.join(self.directory, 'assets') diff --git a/paddlehub/compat/task/base_task.py b/paddlehub/compat/task/base_task.py index 38a837b0..4258f4de 100644 --- a/paddlehub/compat/task/base_task.py +++ b/paddlehub/compat/task/base_task.py @@ -117,6 +117,8 @@ class BaseTask(object): self._base_data_reader = data_reader self._base_feed_list = feed_list + self._compatible_mode = True if data_reader else False + @contextlib.contextmanager def phase_guard(self, phase: str): self.enter_phase(phase) @@ -308,15 +310,19 @@ class BaseTask(object): return wrapper - if self.is_predict_phase: - records = self._predict_data + if self._compatible_mode: + self.env.generator = self._base_data_reader.data_generator( + batch_size=self.config.batch_size, phase=self.phase, data=self._predict_data, return_list=True) else: - if self.is_train_phase: - shuffle = True + if self.is_predict_phase: + records = self._predict_data else: - shuffle = False - records = self.dataset.get_records(phase=self.phase, shuffle=shuffle) - self.env.generator = data_generator(records) + if self.is_train_phase: + shuffle = True + else: + shuffle = False + records = self.dataset.get_records(phase=self.phase, shuffle=shuffle) + self.env.generator = data_generator(records) return self.env.generator -- GitLab