提交 092ab674 编写于 作者: L LielinJiang 提交者: whs

fix save load bug (#47)

上级 d95da67b
......@@ -10,11 +10,11 @@ __all__ = ["save_model", "load_model"]
_logger = get_logger(__name__, level=logging.INFO)
PARAMS_FILE = "__params__"
SHAPES_FILE = "__shapes__"
_PARAMS_FILE = "__params__"
_SHAPES_FILE = "__shapes__"
def save_model(graph, dirname):
def save_model(exe, graph, dirname):
"""
Save weights of model and information of shapes into filesystem.
......@@ -24,24 +24,24 @@ def save_model(graph, dirname):
"""
assert graph is not None and dirname is not None
graph = GraphWrapper(graph) if isinstance(graph, Program) else graph
exe = fluid.Executor(fluid.CPUPlace())
fluid.io.save_params(
executor=exe,
dirname=dirname,
main_program=graph.program,
filename=PARAMS_FILE)
weights_file = os.path.join(dirname, PARAMS_FILE)
filename=_PARAMS_FILE)
weights_file = os.path.join(dirname, _PARAMS_FILE)
_logger.info("Save model weights into {}".format(weights_file))
shapes = {}
for var in graph.all_parameters():
shapes[var.name()] = var.shape()
SHAPES_FILE = os.path.join(dirname, SHAPES_FILE)
SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE)
with open(SHAPES_FILE, "w") as f:
json.dump(shapes, f)
_logger.info("Save shapes of weights into {}".format(SHAPES_FILE))
def load_model(graph, dirname):
def load_model(exe, graph, dirname):
"""
Load weights of model and information of shapes from filesystem.
......@@ -51,9 +51,8 @@ def load_model(graph, dirname):
"""
assert graph is not None and dirname is not None
graph = GraphWrapper(graph) if isinstance(graph, Program) else graph
exe = fluid.Executor(fluid.CPUPlace())
SHAPES_FILE = os.path.join(dirname, SHAPES_FILE)
SHAPES_FILE = os.path.join(dirname, _SHAPES_FILE)
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
with open(SHAPES_FILE, "r") as f:
shapes = json.load(f)
......@@ -62,13 +61,12 @@ def load_model(graph, dirname):
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
exe = fluid.Executor(fluid.CPUPlace())
fluid.io.load_params(
executor=exe,
dirname=dirname,
main_program=graph.program,
filename=PARAMS_FILE)
filename=_PARAMS_FILE)
graph.update_groups_of_conv()
graph.infer_shape()
_logger.info("Load weights from {}".format(
os.path.join(dirname, PARAMS_FILE)))
os.path.join(dirname, _PARAMS_FILE)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册