未验证 提交 5a70c22a 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix config dependence & fix python 3.8/3.9 install (#2987)

* fix config dependence & fix python 3.8/3.9 install
上级 aacb294d
......@@ -217,7 +217,7 @@ class VisualDLWriter(Callback):
logger.error('visualdl not found, plaese install visualdl. '
'for example: `pip install visualdl`.')
raise e
self.vdl_writer = LogWriter(model.cfg.vdl_log_dir)
self.vdl_writer = LogWriter(model.cfg.get('vdl_log_dir', 'vdl_log_dir/scalar'))
self.vdl_loss_step = 0
self.vdl_mAP_step = 0
self.vdl_image_step = 0
......
......@@ -116,7 +116,7 @@ class Trainer(object):
def _init_callbacks(self):
if self.mode == 'train':
self._callbacks = [LogPrinter(self), Checkpointer(self)]
if 'use_vdl' in self.cfg and self.cfg.use_vdl:
if self.cfg.get('use_vdl', False):
self._callbacks.append(VisualDLWriter(self))
self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'eval':
......@@ -124,7 +124,7 @@ class Trainer(object):
if self.cfg.metric == 'WiderFace':
self._callbacks.append(WiferFaceEval(self))
self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'test' and 'use_vdl' in self.cfg and self.cfg.use_vdl:
elif self.mode == 'test' and self.cfg.get('use_vdl', False):
self._callbacks = [VisualDLWriter(self)]
self._compose_callback = ComposeCallback(self._callbacks)
else:
......@@ -141,8 +141,7 @@ class Trainer(object):
bias = self.cfg['bias'] if 'bias' in self.cfg else 0
output_eval = self.cfg['output_eval'] \
if 'output_eval' in self.cfg else None
save_prediction_only = self.cfg['save_prediction_only'] \
if 'save_prediction_only' in self.cfg else False
save_prediction_only = self.cfg.get('save_prediction_only', False)
# pass clsid2catid info to metric instance to avoid multiple loading
# annotation file
......@@ -253,7 +252,7 @@ class Trainer(object):
self._reset_metrics()
model = self.model
if self.cfg.fleet:
if self.cfg.get('fleet', False):
model = fleet.distributed_model(model)
self.optimizer = fleet.distributed_optimizer(
self.optimizer).user_defined_optimizer
......@@ -264,7 +263,7 @@ class Trainer(object):
self.model, find_unused_parameters=find_unused_parameters)
# initial fp16
if self.cfg.fp16:
if self.cfg.get('fp16', False):
scaler = amp.GradScaler(
enable=self.cfg.use_gpu, init_loss_scaling=1024)
......@@ -292,7 +291,7 @@ class Trainer(object):
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
if self.cfg.fp16:
if self.cfg.get('fp16', False):
with amp.auto_cast(enable=self.cfg.use_gpu):
# model forward
outputs = model(data)
......
tqdm
typeguard ; python_version >= '3.4'
visualdl>=2.1.0
visualdl>=2.1.0 ; python_version <= '3.7'
opencv-python
PyYAML
shapely
......@@ -14,4 +14,4 @@ lap
sklearn
cython_bbox
motmetrics
openpyxl
\ No newline at end of file
openpyxl
......@@ -87,6 +87,8 @@ if __name__ == "__main__":
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
'Programming Language :: Python :: 3.8', 'Topic :: Utilities'
'Programming Language :: Python :: 3.9', 'Topic :: Utilities'
],
license='Apache License 2.0',
ext_modules=[])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册