# export_model.py (1.4 KB)
# Committed by zhoujun
import os
import sys

# Absolute path of the directory that contains this script.
__dir__ = os.path.dirname(os.path.abspath(__file__))
# Extend the module search path with this directory (lowest priority) and its
# parent directory (highest priority). This runs before the `models` / `utils`
# imports below — presumably so those packages resolve when the script is
# launched directly; confirm against the repository layout.
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

import argparse

import paddle
from paddle.jit import to_static

from models import build_model
from utils import Config, ArgsParser


def init_args():
    """
    Parse the command-line arguments using the project's ``ArgsParser``.

    :return: whatever ``ArgsParser().parse_args()`` returns; the script entry
        point reads ``config_file`` and ``opt`` from it
    """
    return ArgsParser().parse_args()


def load_checkpoint(model, checkpoint_path):
    """
    Restore a model's weights from a checkpoint file.

    The file is read with ``paddle.load`` and its ``'state_dict'`` entry is
    handed to ``model.set_state_dict``; a confirmation line is then printed.

    :param model: object exposing ``set_state_dict``; it is updated in place
    :param checkpoint_path: Checkpoint path to be loaded
    """
    weights = paddle.load(checkpoint_path)['state_dict']
    model.set_state_dict(weights)
    message = 'load checkpoint from {}'.format(checkpoint_path)
    print(message)

def main(config):
    """
    Build the model, load its checkpoint and export it with ``paddle.jit.save``.

    :param config: mapping providing ``config['arch']`` (passed to
        ``build_model``), ``config['trainer']['resume_checkpoint']`` (the
        checkpoint to load) and ``config['trainer']['output_dir']`` (the export
        destination directory)
    """
    net = build_model(config['arch'])
    load_checkpoint(net, config['trainer']['resume_checkpoint'])
    # Switch to evaluation mode before conversion.
    net.eval()

    # The export target is "<output_dir>/inference".
    export_path = os.path.join(config["trainer"]["output_dir"], "inference")

    # Single float32 input of shape [None, 3, -1, -1]; presumably
    # (batch, channels, height, width) with only the channel count fixed —
    # TODO confirm against the model's expected input layout.
    image_spec = paddle.static.InputSpec(
        shape=[None, 3, -1, -1], dtype="float32")
    static_net = to_static(net, input_spec=[image_spec])

    paddle.jit.save(static_net, export_path)
    print("inference model is saved to {}".format(export_path))


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, load the config file, apply the
    # command-line overrides held in ``args.opt``, then run the export.
    args = init_args()
    # Validate explicitly instead of using `assert`: assertions are removed
    # when Python runs with -O, and a bare AssertionError gives the user no
    # hint about which path was missing.
    if not os.path.exists(args.config_file):
        raise FileNotFoundError(
            "config file not found: {}".format(args.config_file))
    config = Config(args.config_file)
    config.merge_dict(args.opt)
    main(config.cfg)