# prune_io.py: save and load model weights together with the shapes of the parameters.
import os
import paddle.fluid as fluid
from paddle.fluid import Program
from ..core import GraphWrapper
from ..common import get_logger
import json
import logging

__all__ = ["save_model", "load_model"]

_logger = get_logger(__name__, level=logging.INFO)

# Name of the single combined file (inside the model directory) that holds
# all parameter values written by fluid.io.save_params.
_PARAMS_FILE = "__params__"
# Name of the JSON file (inside the model directory) mapping each parameter
# name to its shape.
_SHAPES_FILE = "__shapes__"


def save_model(exe, graph, dirname):
    """
    Save weights of model and information of shapes into filesystem.

    All parameter values of ``graph`` are written into one combined file
    named ``_PARAMS_FILE`` under ``dirname``. A JSON object mapping every
    parameter name to its current shape is written to ``_SHAPES_FILE`` in
    the same directory, so that ``load_model`` can restore those shapes
    before it loads the weights back.

    Args:
      - exe(Executor): The executor passed to ``fluid.io.save_params``.
      - graph(Program|Graph): The graph to be saved. A ``Program`` is
        wrapped in a ``GraphWrapper`` first.
      - dirname(str): The directory that the model is saved into.
    """
    assert graph is not None and dirname is not None
    graph = GraphWrapper(graph) if isinstance(graph, Program) else graph

    fluid.io.save_params(
        executor=exe,
        dirname=dirname,
        main_program=graph.program,
        filename=_PARAMS_FILE)
    weights_file = os.path.join(dirname, _PARAMS_FILE)
    _logger.info("Save model weights into {}".format(weights_file))

    # Record the shape of every parameter. load_model applies these shapes
    # to the graph before reading the weights, so the declared variable
    # shapes agree with what was saved.
    shapes = {var.name(): var.shape() for var in graph.all_parameters()}
    shapes_file = os.path.join(dirname, _SHAPES_FILE)
    with open(shapes_file, "w") as f:
        json.dump(shapes, f)
    # Log only after the file is closed, i.e. after the write has completed.
    _logger.info("Save shapes of weights into {}".format(shapes_file))


def load_model(exe, graph, dirname):
    """
    Load weights of model and information of shapes from filesystem.

    The loading steps run in this order:

    1. Restore parameter shapes from the JSON file ``_SHAPES_FILE`` under
       ``dirname``.
    2. Load the parameter values from the combined file ``_PARAMS_FILE``.
    3. Call ``graph.update_groups_of_conv()`` and ``graph.infer_shape()``.

    Args:
      - exe(Executor): The executor passed to ``fluid.io.load_params``.
      - graph(Program|Graph): The graph to load the model into. A
        ``Program`` is wrapped in a ``GraphWrapper`` first.
      - dirname(str): The directory that the model was saved into.
    """
    assert graph is not None and dirname is not None
    graph = GraphWrapper(graph) if isinstance(graph, Program) else graph

    # Set the saved shapes on the graph's variables *before* loading the
    # weights, so the declared shapes match the stored tensors.
    shapes_file = os.path.join(dirname, _SHAPES_FILE)
    with open(shapes_file, "r") as f:
        shapes = json.load(f)
    for param, shape in shapes.items():
        # NOTE(review): assumes every saved name exists in `graph`; if
        # graph.var() can return None this raises AttributeError — confirm.
        graph.var(param).set_shape(shape)
    _logger.info("Load shapes of weights from {}".format(shapes_file))

    fluid.io.load_params(
        executor=exe,
        dirname=dirname,
        main_program=graph.program,
        filename=_PARAMS_FILE)
    # Presumably refreshes conv `groups` attributes and re-propagates shapes
    # so dependent ops agree with the restored parameter shapes — see
    # GraphWrapper for the exact semantics.
    graph.update_groups_of_conv()
    graph.infer_shape()
    _logger.info("Load weights from {}".format(
        os.path.join(dirname, _PARAMS_FILE)))