未验证 提交 11a059c9 编写于 作者: G Guanghua Yu 提交者: GitHub

fix dygraph to static model error, test=dygraph (#2037)

上级 141234c8
...@@ -99,7 +99,8 @@ def _dump_infer_config(config, path, image_shape, model): ...@@ -99,7 +99,8 @@ def _dump_infer_config(config, path, image_shape, model):
'Architecture: {} is not supported for exporting model now'.format( 'Architecture: {} is not supported for exporting model now'.format(
infer_arch)) infer_arch))
os._exit(0) os._exit(0)
if getattr(model.__dict__, 'mask_post_process', None): if 'mask_post_process' in model.__dict__ and model.__dict__[
'mask_post_process']:
infer_cfg['mask_resolution'] = model.mask_post_process.mask_resolution infer_cfg['mask_resolution'] = model.mask_post_process.mask_resolution
infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[ infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[
'label_list'], image_shape = _parse_reader( 'label_list'], image_shape = _parse_reader(
......
...@@ -297,6 +297,7 @@ class Trainer(object): ...@@ -297,6 +297,7 @@ class Trainer(object):
return os.path.join(output_dir, "{}".format(name)) + ext return os.path.join(output_dir, "{}".format(name)) + ext
def export(self, output_dir='output_inference'): def export(self, output_dir='output_inference'):
self.model.eval()
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0] model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
save_dir = os.path.join(output_dir, model_name) save_dir = os.path.join(output_dir, model_name)
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
......
...@@ -30,7 +30,7 @@ import paddle ...@@ -30,7 +30,7 @@ import paddle
from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import load_config, merge_config
from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser from ppdet.utils.cli import ArgsParser
from ppdet.engine import Detector from ppdet.engine import Trainer
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
logger = setup_logger('export_model') logger = setup_logger('export_model')
...@@ -49,13 +49,13 @@ def parse_args(): ...@@ -49,13 +49,13 @@ def parse_args():
def run(FLAGS, cfg): def run(FLAGS, cfg):
# build detector # build detector
detector = Detector(cfg, mode='test') trainer = Trainer(cfg, mode='test')
# load weights # load weights
detector.load_weights(cfg.weights, 'resume') trainer.load_weights(cfg.weights, 'resume')
# export model # export model
detector.export(FLAGS.output_dir) trainer.export(FLAGS.output_dir)
def main(): def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册