未验证 提交 5a70c22a 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix config dependence & fix python 3.8/3.9 install (#2987)

* fix config dependence & fix python 3.8/3.9 install
上级 aacb294d
...@@ -217,7 +217,7 @@ class VisualDLWriter(Callback): ...@@ -217,7 +217,7 @@ class VisualDLWriter(Callback):
logger.error('visualdl not found, plaese install visualdl. ' logger.error('visualdl not found, plaese install visualdl. '
'for example: `pip install visualdl`.') 'for example: `pip install visualdl`.')
raise e raise e
self.vdl_writer = LogWriter(model.cfg.vdl_log_dir) self.vdl_writer = LogWriter(model.cfg.get('vdl_log_dir', 'vdl_log_dir/scalar'))
self.vdl_loss_step = 0 self.vdl_loss_step = 0
self.vdl_mAP_step = 0 self.vdl_mAP_step = 0
self.vdl_image_step = 0 self.vdl_image_step = 0
......
...@@ -116,7 +116,7 @@ class Trainer(object): ...@@ -116,7 +116,7 @@ class Trainer(object):
def _init_callbacks(self): def _init_callbacks(self):
if self.mode == 'train': if self.mode == 'train':
self._callbacks = [LogPrinter(self), Checkpointer(self)] self._callbacks = [LogPrinter(self), Checkpointer(self)]
if 'use_vdl' in self.cfg and self.cfg.use_vdl: if self.cfg.get('use_vdl', False):
self._callbacks.append(VisualDLWriter(self)) self._callbacks.append(VisualDLWriter(self))
self._compose_callback = ComposeCallback(self._callbacks) self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'eval': elif self.mode == 'eval':
...@@ -124,7 +124,7 @@ class Trainer(object): ...@@ -124,7 +124,7 @@ class Trainer(object):
if self.cfg.metric == 'WiderFace': if self.cfg.metric == 'WiderFace':
self._callbacks.append(WiferFaceEval(self)) self._callbacks.append(WiferFaceEval(self))
self._compose_callback = ComposeCallback(self._callbacks) self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'test' and 'use_vdl' in self.cfg and self.cfg.use_vdl: elif self.mode == 'test' and self.cfg.get('use_vdl', False):
self._callbacks = [VisualDLWriter(self)] self._callbacks = [VisualDLWriter(self)]
self._compose_callback = ComposeCallback(self._callbacks) self._compose_callback = ComposeCallback(self._callbacks)
else: else:
...@@ -141,8 +141,7 @@ class Trainer(object): ...@@ -141,8 +141,7 @@ class Trainer(object):
bias = self.cfg['bias'] if 'bias' in self.cfg else 0 bias = self.cfg['bias'] if 'bias' in self.cfg else 0
output_eval = self.cfg['output_eval'] \ output_eval = self.cfg['output_eval'] \
if 'output_eval' in self.cfg else None if 'output_eval' in self.cfg else None
save_prediction_only = self.cfg['save_prediction_only'] \ save_prediction_only = self.cfg.get('save_prediction_only', False)
if 'save_prediction_only' in self.cfg else False
# pass clsid2catid info to metric instance to avoid multiple loading # pass clsid2catid info to metric instance to avoid multiple loading
# annotation file # annotation file
...@@ -253,7 +252,7 @@ class Trainer(object): ...@@ -253,7 +252,7 @@ class Trainer(object):
self._reset_metrics() self._reset_metrics()
model = self.model model = self.model
if self.cfg.fleet: if self.cfg.get('fleet', False):
model = fleet.distributed_model(model) model = fleet.distributed_model(model)
self.optimizer = fleet.distributed_optimizer( self.optimizer = fleet.distributed_optimizer(
self.optimizer).user_defined_optimizer self.optimizer).user_defined_optimizer
...@@ -264,7 +263,7 @@ class Trainer(object): ...@@ -264,7 +263,7 @@ class Trainer(object):
self.model, find_unused_parameters=find_unused_parameters) self.model, find_unused_parameters=find_unused_parameters)
# initial fp16 # initial fp16
if self.cfg.fp16: if self.cfg.get('fp16', False):
scaler = amp.GradScaler( scaler = amp.GradScaler(
enable=self.cfg.use_gpu, init_loss_scaling=1024) enable=self.cfg.use_gpu, init_loss_scaling=1024)
...@@ -292,7 +291,7 @@ class Trainer(object): ...@@ -292,7 +291,7 @@ class Trainer(object):
self.status['step_id'] = step_id self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status) self._compose_callback.on_step_begin(self.status)
if self.cfg.fp16: if self.cfg.get('fp16', False):
with amp.auto_cast(enable=self.cfg.use_gpu): with amp.auto_cast(enable=self.cfg.use_gpu):
# model forward # model forward
outputs = model(data) outputs = model(data)
......
tqdm tqdm
typeguard ; python_version >= '3.4' typeguard ; python_version >= '3.4'
visualdl>=2.1.0 visualdl>=2.1.0 ; python_version <= '3.7'
opencv-python opencv-python
PyYAML PyYAML
shapely shapely
...@@ -14,4 +14,4 @@ lap ...@@ -14,4 +14,4 @@ lap
sklearn sklearn
cython_bbox cython_bbox
motmetrics motmetrics
openpyxl openpyxl
\ No newline at end of file
...@@ -87,6 +87,8 @@ if __name__ == "__main__": ...@@ -87,6 +87,8 @@ if __name__ == "__main__":
'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Topic :: Utilities' 'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
'Programming Language :: Python :: 3.8', 'Topic :: Utilities'
'Programming Language :: Python :: 3.9', 'Topic :: Utilities'
], ],
license='Apache License 2.0', license='Apache License 2.0',
ext_modules=[]) ext_modules=[])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册