未验证 提交 995b5067 编写于 作者: G Guanghua Yu 提交者: GitHub

fix fuse conv-bn after reparameterize (#7189)

上级 4e2546e3
......@@ -1002,6 +1002,10 @@ class Trainer(object):
if hasattr(layer, 'convert_to_deploy'):
layer.convert_to_deploy()
if hasattr(self.cfg, 'export') and 'fuse_conv_bn' in self.cfg[
'export'] and self.cfg['export']['fuse_conv_bn']:
self.model = fuse_conv_bn(self.model)
export_post_process = self.cfg['export'].get(
'post_process', False) if hasattr(self.cfg, 'export') else True
export_nms = self.cfg['export'].get('nms', False) if hasattr(
......@@ -1073,10 +1077,6 @@ class Trainer(object):
def export(self, output_dir='output_inference'):
self.model.eval()
if hasattr(self.cfg, 'export') and 'fuse_conv_bn' in self.cfg[
'export'] and self.cfg['export']['fuse_conv_bn']:
self.model = fuse_conv_bn(self.model)
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
save_dir = os.path.join(output_dir, model_name)
if not os.path.exists(save_dir):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册