提交 f6fffd8c 编写于 作者: W wuzewu

Fix model compatibility issues

上级 601c81d2
...@@ -211,7 +211,12 @@ class ModuleV1(object): ...@@ -211,7 +211,12 @@ class ModuleV1(object):
'''Load the infomation of Module object defined in the specified directory.''' '''Load the infomation of Module object defined in the specified directory.'''
desc_file = os.path.join(directory, 'module_desc.pb') desc_file = os.path.join(directory, 'module_desc.pb')
desc = module_v1_utils.convert_module_desc(desc_file) desc = module_v1_utils.convert_module_desc(desc_file)
return desc.module_info
# The naming of some old versions of Module is not standardized, which format of uppercase
# letters. This will cause the path of these modules to be incorrect after installation.
module_info = desc.module_info
module_info.name = module_info.name.lower()
return module_info
def assets_path(self): def assets_path(self):
return os.path.join(self.directory, 'assets') return os.path.join(self.directory, 'assets')
......
...@@ -117,6 +117,8 @@ class BaseTask(object): ...@@ -117,6 +117,8 @@ class BaseTask(object):
self._base_data_reader = data_reader self._base_data_reader = data_reader
self._base_feed_list = feed_list self._base_feed_list = feed_list
self._compatible_mode = True if data_reader else False
@contextlib.contextmanager @contextlib.contextmanager
def phase_guard(self, phase: str): def phase_guard(self, phase: str):
self.enter_phase(phase) self.enter_phase(phase)
...@@ -308,6 +310,10 @@ class BaseTask(object): ...@@ -308,6 +310,10 @@ class BaseTask(object):
return wrapper return wrapper
if self._compatible_mode:
self.env.generator = self._base_data_reader.data_generator(
batch_size=self.config.batch_size, phase=self.phase, data=self._predict_data, return_list=True)
else:
if self.is_predict_phase: if self.is_predict_phase:
records = self._predict_data records = self._predict_data
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册