提交 f6fffd8c 编写于 作者: W wuzewu

Fix model compatibility issues

上级 601c81d2
......@@ -211,7 +211,12 @@ class ModuleV1(object):
'''Load the infomation of Module object defined in the specified directory.'''
desc_file = os.path.join(directory, 'module_desc.pb')
desc = module_v1_utils.convert_module_desc(desc_file)
return desc.module_info
# The naming of some old versions of Module is not standardized, which format of uppercase
# letters. This will cause the path of these modules to be incorrect after installation.
module_info = desc.module_info
module_info.name = module_info.name.lower()
return module_info
def assets_path(self):
return os.path.join(self.directory, 'assets')
......
......@@ -117,6 +117,8 @@ class BaseTask(object):
self._base_data_reader = data_reader
self._base_feed_list = feed_list
self._compatible_mode = True if data_reader else False
@contextlib.contextmanager
def phase_guard(self, phase: str):
self.enter_phase(phase)
......@@ -308,15 +310,19 @@ class BaseTask(object):
return wrapper
if self.is_predict_phase:
records = self._predict_data
if self._compatible_mode:
self.env.generator = self._base_data_reader.data_generator(
batch_size=self.config.batch_size, phase=self.phase, data=self._predict_data, return_list=True)
else:
if self.is_train_phase:
shuffle = True
if self.is_predict_phase:
records = self._predict_data
else:
shuffle = False
records = self.dataset.get_records(phase=self.phase, shuffle=shuffle)
self.env.generator = data_generator(records)
if self.is_train_phase:
shuffle = True
else:
shuffle = False
records = self.dataset.get_records(phase=self.phase, shuffle=shuffle)
self.env.generator = data_generator(records)
return self.env.generator
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册