diff --git a/paddleslim/prune/io.py b/paddleslim/prune/io.py index a791cb20b36d5d99e8b9032854afaea911b96b24..5dcd781c1a658ee757441207a1801174784fdfcc 100644 --- a/paddleslim/prune/io.py +++ b/paddleslim/prune/io.py @@ -10,11 +10,11 @@ __all__ = ["save_model", "load_model"] _logger = get_logger(__name__, level=logging.INFO) -PARAMS_FILE = "__params__" -SHAPES_FILE = "__shapes__" +_PARAMS_FILE = "__params__" +_SHAPES_FILE = "__shapes__" -def save_model(graph, dirname): +def save_model(exe, graph, dirname): """ Save weights of model and information of shapes into filesystem. @@ -24,24 +24,24 @@ def save_model(graph, dirname): """ assert graph is not None and dirname is not None graph = GraphWrapper(graph) if isinstance(graph, Program) else graph - exe = fluid.Executor(fluid.CPUPlace()) + fluid.io.save_params( executor=exe, dirname=dirname, main_program=graph.program, - filename=PARAMS_FILE) - weights_file = os.path.join(dirname, PARAMS_FILE) + filename=_PARAMS_FILE) + weights_file = os.path.join(dirname, _PARAMS_FILE) _logger.info("Save model weights into {}".format(weights_file)) shapes = {} for var in graph.all_parameters(): shapes[var.name()] = var.shape() - SHAPES_FILE = os.path.join(dirname, SHAPES_FILE) + SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE) with open(SHAPES_FILE, "w") as f: json.dump(shapes, f) _logger.info("Save shapes of weights into {}".format(SHAPES_FILE)) -def load_model(graph, dirname): +def load_model(exe, graph, dirname): """ Load weights of model and information of shapes from filesystem. @@ -51,9 +51,8 @@ def load_model(graph, dirname): """ assert graph is not None and dirname is not None graph = GraphWrapper(graph) if isinstance(graph, Program) else graph - exe = fluid.Executor(fluid.CPUPlace()) - SHAPES_FILE = os.path.join(dirname, SHAPES_FILE) + SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE) _logger.info("Load shapes of weights from {}".format(SHAPES_FILE)) with open(SHAPES_FILE, "r") as f: shapes = json.load(f) @@ -62,13 +61,12 @@ def load_model(graph, dirname): _logger.info("Load shapes of weights from {}".format(SHAPES_FILE)) - exe = fluid.Executor(fluid.CPUPlace()) fluid.io.load_params( executor=exe, dirname=dirname, main_program=graph.program, - filename=PARAMS_FILE) + filename=_PARAMS_FILE) graph.update_groups_of_conv() graph.infer_shape() _logger.info("Load weights from {}".format( - os.path.join(dirname, PARAMS_FILE))) + os.path.join(dirname, _PARAMS_FILE)))