diff --git a/PaddleSlim/utility.py b/PaddleSlim/utility.py index b084813590c45030ea5a0acfa6512fa0d7c4bb70..90a5ffe75fa1bbb752c0bcfc25ea08e9f966f3be 100644 --- a/PaddleSlim/utility.py +++ b/PaddleSlim/utility.py @@ -20,11 +20,15 @@ import distutils.util import os import numpy as np import six +import logging import paddle.fluid as fluid import paddle.compat as cpt from paddle.fluid import core from paddle.fluid.framework import Program +logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) def print_arguments(args): """Print argparse's arguments. @@ -125,6 +129,13 @@ def load_persistable_nodes(executor, dirname, graph): def _exist(var): return os.path.exists(os.path.join(dirname, var.name)) + def _load_var(name, scope): + return np.array(scope.find_var(name).get_tensor()) + + def _store_var(name, array, scope, place): + tensor = scope.find_var(name).get_tensor() + tensor.set(array, place) + for node in persistable_nodes: var_desc = node.var() if var_desc.type() == core.VarDesc.VarType.RAW or \ @@ -139,4 +150,6 @@ def load_persistable_nodes(executor, dirname, graph): persistable=var_desc.persistable()) if _exist(var): var_list.append(var) + else: + _logger.info("Cannot find the var %s!!!" %(node.name())) fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)