提交 092ab674 编写于 作者: L LielinJiang 提交者: whs

fix save load bug (#47)

上级 d95da67b
...@@ -10,11 +10,11 @@ __all__ = ["save_model", "load_model"] ...@@ -10,11 +10,11 @@ __all__ = ["save_model", "load_model"]
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
PARAMS_FILE = "__params__" _PARAMS_FILE = "__params__"
SHAPES_FILE = "__shapes__" _SHAPES_FILE = "__shapes__"
def save_model(graph, dirname): def save_model(exe, graph, dirname):
""" """
Save weights of model and information of shapes into filesystem. Save weights of model and information of shapes into filesystem.
...@@ -24,24 +24,24 @@ def save_model(graph, dirname): ...@@ -24,24 +24,24 @@ def save_model(graph, dirname):
""" """
assert graph is not None and dirname is not None assert graph is not None and dirname is not None
graph = GraphWrapper(graph) if isinstance(graph, Program) else graph graph = GraphWrapper(graph) if isinstance(graph, Program) else graph
exe = fluid.Executor(fluid.CPUPlace())
fluid.io.save_params( fluid.io.save_params(
executor=exe, executor=exe,
dirname=dirname, dirname=dirname,
main_program=graph.program, main_program=graph.program,
filename=PARAMS_FILE) filename=_PARAMS_FILE)
weights_file = os.path.join(dirname, PARAMS_FILE) weights_file = os.path.join(dirname, _PARAMS_FILE)
_logger.info("Save model weights into {}".format(weights_file)) _logger.info("Save model weights into {}".format(weights_file))
shapes = {} shapes = {}
for var in graph.all_parameters(): for var in graph.all_parameters():
shapes[var.name()] = var.shape() shapes[var.name()] = var.shape()
SHAPES_FILE = os.path.join(dirname, SHAPES_FILE) SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE)
with open(SHAPES_FILE, "w") as f: with open(SHAPES_FILE, "w") as f:
json.dump(shapes, f) json.dump(shapes, f)
_logger.info("Save shapes of weights into {}".format(SHAPES_FILE)) _logger.info("Save shapes of weights into {}".format(SHAPES_FILE))
def load_model(graph, dirname): def load_model(exe, graph, dirname):
""" """
Load weights of model and information of shapes from filesystem. Load weights of model and information of shapes from filesystem.
...@@ -51,9 +51,8 @@ def load_model(graph, dirname): ...@@ -51,9 +51,8 @@ def load_model(graph, dirname):
""" """
assert graph is not None and dirname is not None assert graph is not None and dirname is not None
graph = GraphWrapper(graph) if isinstance(graph, Program) else graph graph = GraphWrapper(graph) if isinstance(graph, Program) else graph
exe = fluid.Executor(fluid.CPUPlace())
SHAPES_FILE = os.path.join(dirname, SHAPES_FILE) SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE)
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE)) _logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
with open(SHAPES_FILE, "r") as f: with open(SHAPES_FILE, "r") as f:
shapes = json.load(f) shapes = json.load(f)
...@@ -62,13 +61,12 @@ def load_model(graph, dirname): ...@@ -62,13 +61,12 @@ def load_model(graph, dirname):
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE)) _logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
exe = fluid.Executor(fluid.CPUPlace())
fluid.io.load_params( fluid.io.load_params(
executor=exe, executor=exe,
dirname=dirname, dirname=dirname,
main_program=graph.program, main_program=graph.program,
filename=PARAMS_FILE) filename=_PARAMS_FILE)
graph.update_groups_of_conv() graph.update_groups_of_conv()
graph.infer_shape() graph.infer_shape()
_logger.info("Load weights from {}".format( _logger.info("Load weights from {}".format(
os.path.join(dirname, PARAMS_FILE))) os.path.join(dirname, _PARAMS_FILE)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册