提交 133458b1 编写于 作者: D dengkaipeng

use load_params

上级 fc0e6988
data
checkpoints
output*
*.py
*.swp
......@@ -13,6 +13,7 @@
#limitations under the License.
import os
import logging
try:
from configparser import ConfigParser
except:
......@@ -25,6 +26,8 @@ from .utils import download, AttrDict
WEIGHT_DIR = os.path.expanduser("~/.paddle/weights")
logger = logging.getLogger(__name__)
class NotImplementError(Exception):
"Error: model function not implement"
......@@ -163,15 +166,14 @@ class ModelBase(object):
"get model weight default path and download url"
raise NotImplementError(self, self.weights_info)
def get_weights(self, logger=None):
def get_weights(self):
"get model weight file path, download weight from Paddle if not exist"
path, url = self.weights_info()
path = os.path.join(WEIGHT_DIR, path)
if os.path.exists(path):
return path
if logger:
logger.info("Download weights of {} from {}".format(self.name, url))
logger.info("Download weights of {} from {}".format(self.name, url))
download(url, path)
return path
......@@ -186,7 +188,7 @@ class ModelBase(object):
"get pretrain base model directory"
return (None, None)
def get_pretrain_weights(self, logger=None):
def get_pretrain_weights(self):
"get model weight file path, download weight from Paddle if not exist"
path, url = self.pretrain_info()
if not path:
......@@ -196,22 +198,15 @@ class ModelBase(object):
if os.path.exists(path):
return path
if logger:
logger.info("Download pretrain weights of {} from {}".format(
logger.info("Download pretrain weights of {} from {}".format(
self.name, url))
utils.download(url, path)
return path
def load_pretrained_params(self, exe, pretrained_base, prog, place):
def load_pretrain_params(self, exe, pretrain, prog):
def if_exist(var):
return os.path.exists(os.path.join(pretrained_base, var.name))
inference_program = prog.clone(for_test=True)
fluid.io.load_vars(
exe,
pretrained_base,
predicate=if_exist,
main_program=inference_program)
fluid.io.load_params(exe, pretrain, main_program=prog)
def get_config_from_sec(self, sec, item, default=None):
cfg_item = self._config.get_config_from_sec(sec.upper(),
......
......@@ -153,15 +153,8 @@ class STNET(ModelBase):
def create_metrics_args(self):
return {}
def load_pretrained_params(self, exe, pretrain_base, prog, place):
def is_parameter(var):
if isinstance(var, fluid.framework.Parameter):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
inference_program = prog.clone(for_test=True)
vars = filter(is_parameter, inference_program.list_vars())
fluid.io.load_vars(exe, pretrain_base, vars=vars)
def load_pretrain_params(self, exe, pretrain, prog):
fluid.io.load_params(exe, pretrain, main_program=prog)
param_tensor = fluid.global_scope().find_var(
"conv1_weights").get_tensor()
......
......@@ -21,6 +21,7 @@ import numpy as np
import paddle.fluid as fluid
import models
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
......
......@@ -152,7 +152,7 @@ def train(train_model, valid_model, args):
"Given pretrain weight dir {} not exist.".format(args.pretrain)
pretrain = args.pretrain or train_model.get_pretrain_weights()
if pretrain:
train_model.load_pretrained_params(exe, pretrain, train_prog, place)
train_model.load_pretrain_params(exe, pretrain, train_prog)
if args.no_parallel:
train_exe = exe
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册