未验证 提交 6cfe8bcf 编写于 作者: Z Zhen Wang 提交者: GitHub

Merge pull request #2092 from wzzju/checkpoint_utility

Add some infomation when loading.
......@@ -20,11 +20,15 @@ import distutils.util
import os
import numpy as np
import six
import logging
import paddle.fluid as fluid
import paddle.compat as cpt
from paddle.fluid import core
from paddle.fluid.framework import Program
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
def print_arguments(args):
"""Print argparse's arguments.
......@@ -125,6 +129,13 @@ def load_persistable_nodes(executor, dirname, graph):
def _exist(var):
return os.path.exists(os.path.join(dirname, var.name))
def _load_var(name, scope):
return np.array(scope.find_var(name).get_tensor())
def _store_var(name, array, scope, place):
tensor = scope.find_var(name).get_tensor()
tensor.set(array, place)
for node in persistable_nodes:
var_desc = node.var()
if var_desc.type() == core.VarDesc.VarType.RAW or \
......@@ -139,4 +150,6 @@ def load_persistable_nodes(executor, dirname, graph):
persistable=var_desc.persistable())
if _exist(var):
var_list.append(var)
else:
_logger.info("Cannot find the var %s!!!" %(node.name()))
fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册