未验证 提交 11a059c9 编写于 作者: G Guanghua Yu 提交者: GitHub

fix dygraph to static model error, test=dygraph (#2037)

上级 141234c8
......@@ -99,7 +99,8 @@ def _dump_infer_config(config, path, image_shape, model):
'Architecture: {} is not supported for exporting model now'.format(
infer_arch))
os._exit(0)
if getattr(model.__dict__, 'mask_post_process', None):
if 'mask_post_process' in model.__dict__ and model.__dict__[
'mask_post_process']:
infer_cfg['mask_resolution'] = model.mask_post_process.mask_resolution
infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[
'label_list'], image_shape = _parse_reader(
......
......@@ -297,6 +297,7 @@ class Trainer(object):
return os.path.join(output_dir, "{}".format(name)) + ext
def export(self, output_dir='output_inference'):
self.model.eval()
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
save_dir = os.path.join(output_dir, model_name)
if not os.path.exists(save_dir):
......
......@@ -30,7 +30,7 @@ import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.engine import Detector
from ppdet.engine import Trainer
from ppdet.utils.logger import setup_logger
logger = setup_logger('export_model')
......@@ -49,13 +49,13 @@ def parse_args():
def run(FLAGS, cfg):
# build detector
detector = Detector(cfg, mode='test')
trainer = Trainer(cfg, mode='test')
# load weights
detector.load_weights(cfg.weights, 'resume')
trainer.load_weights(cfg.weights, 'resume')
# export model
detector.export(FLAGS.output_dir)
trainer.export(FLAGS.output_dir)
def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册