提交 89696c7b 编写于 作者: G gaotingquan 提交者: Wei Shengyu

fix error that sync bn should not be used on cpu

上级 2e41b5de
...@@ -39,7 +39,11 @@ def build_model(config, mode="train"): ...@@ -39,7 +39,11 @@ def build_model(config, mode="train"):
mod = importlib.import_module(__name__) mod = importlib.import_module(__name__)
arch = getattr(mod, model_type)(**arch_config) arch = getattr(mod, model_type)(**arch_config)
if use_sync_bn: if use_sync_bn:
if config["Global"]["device"] == "gpu":
arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch) arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch)
else:
msg = "SyncBatchNorm can only be used on GPU device. The releated setting has been ignored."
logger.warning(msg)
if isinstance(arch, TheseusLayer): if isinstance(arch, TheseusLayer):
prune_model(config, arch) prune_model(config, arch)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册