提交 133458b1 编写于 作者: D dengkaipeng

use load_params

上级 fc0e6988
data
checkpoints
output*
*.py
*.swp
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#limitations under the License. #limitations under the License.
import os import os
import logging
try: try:
from configparser import ConfigParser from configparser import ConfigParser
except: except:
...@@ -25,6 +26,8 @@ from .utils import download, AttrDict ...@@ -25,6 +26,8 @@ from .utils import download, AttrDict
WEIGHT_DIR = os.path.expanduser("~/.paddle/weights") WEIGHT_DIR = os.path.expanduser("~/.paddle/weights")
logger = logging.getLogger(__name__)
class NotImplementError(Exception): class NotImplementError(Exception):
"Error: model function not implement" "Error: model function not implement"
...@@ -163,14 +166,13 @@ class ModelBase(object): ...@@ -163,14 +166,13 @@ class ModelBase(object):
"get model weight default path and download url" "get model weight default path and download url"
raise NotImplementError(self, self.weights_info) raise NotImplementError(self, self.weights_info)
def get_weights(self, logger=None): def get_weights(self):
"get model weight file path, download weight from Paddle if not exist" "get model weight file path, download weight from Paddle if not exist"
path, url = self.weights_info() path, url = self.weights_info()
path = os.path.join(WEIGHT_DIR, path) path = os.path.join(WEIGHT_DIR, path)
if os.path.exists(path): if os.path.exists(path):
return path return path
if logger:
logger.info("Download weights of {} from {}".format(self.name, url)) logger.info("Download weights of {} from {}".format(self.name, url))
download(url, path) download(url, path)
return path return path
...@@ -186,7 +188,7 @@ class ModelBase(object): ...@@ -186,7 +188,7 @@ class ModelBase(object):
"get pretrain base model directory" "get pretrain base model directory"
return (None, None) return (None, None)
def get_pretrain_weights(self, logger=None): def get_pretrain_weights(self):
"get model weight file path, download weight from Paddle if not exist" "get model weight file path, download weight from Paddle if not exist"
path, url = self.pretrain_info() path, url = self.pretrain_info()
if not path: if not path:
...@@ -196,22 +198,15 @@ class ModelBase(object): ...@@ -196,22 +198,15 @@ class ModelBase(object):
if os.path.exists(path): if os.path.exists(path):
return path return path
if logger:
logger.info("Download pretrain weights of {} from {}".format( logger.info("Download pretrain weights of {} from {}".format(
self.name, url)) self.name, url))
utils.download(url, path) utils.download(url, path)
return path return path
def load_pretrained_params(self, exe, pretrained_base, prog, place): def load_pretrain_params(self, exe, pretrain, prog):
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(pretrained_base, var.name)) return os.path.exists(os.path.join(pretrained_base, var.name))
fluid.io.load_params(exe, pretrain, main_program=prog)
inference_program = prog.clone(for_test=True)
fluid.io.load_vars(
exe,
pretrained_base,
predicate=if_exist,
main_program=inference_program)
def get_config_from_sec(self, sec, item, default=None): def get_config_from_sec(self, sec, item, default=None):
cfg_item = self._config.get_config_from_sec(sec.upper(), cfg_item = self._config.get_config_from_sec(sec.upper(),
......
...@@ -153,15 +153,8 @@ class STNET(ModelBase): ...@@ -153,15 +153,8 @@ class STNET(ModelBase):
def create_metrics_args(self): def create_metrics_args(self):
return {} return {}
def load_pretrained_params(self, exe, pretrain_base, prog, place): def load_pretrain_params(self, exe, pretrain, prog):
def is_parameter(var): fluid.io.load_params(exe, pretrain, main_program=prog)
if isinstance(var, fluid.framework.Parameter):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
inference_program = prog.clone(for_test=True)
vars = filter(is_parameter, inference_program.list_vars())
fluid.io.load_vars(exe, pretrain_base, vars=vars)
param_tensor = fluid.global_scope().find_var( param_tensor = fluid.global_scope().find_var(
"conv1_weights").get_tensor() "conv1_weights").get_tensor()
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import models import models
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -152,7 +152,7 @@ def train(train_model, valid_model, args): ...@@ -152,7 +152,7 @@ def train(train_model, valid_model, args):
"Given pretrain weight dir {} not exist.".format(args.pretrain) "Given pretrain weight dir {} not exist.".format(args.pretrain)
pretrain = args.pretrain or train_model.get_pretrain_weights() pretrain = args.pretrain or train_model.get_pretrain_weights()
if pretrain: if pretrain:
train_model.load_pretrained_params(exe, pretrain, train_prog, place) train_model.load_pretrain_params(exe, pretrain, train_prog)
if args.no_parallel: if args.no_parallel:
train_exe = exe train_exe = exe
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册