提交 76ab2323 编写于 作者: S SunGaofeng

Remove redundant code in models/nonlocal_model

上级 464bdc82
......@@ -19,6 +19,7 @@ import paddle.fluid as fluid
from ..model import ModelBase
import resnet_video
from .nonlocal_utils import load_params_from_file
import logging
logger = logging.getLogger(__name__)
......@@ -153,147 +154,3 @@ def get_learning_rate_decay_list(base_learning_rate, lr_decay, step_lists):
lr_values.append(base_learning_rate * decay_rate)
return lr_bounds, lr_values
def load_params_from_pkl_file(prog, pretrained_file, place):
    """Copy parameters stored in a pickle file into the tensors of ``prog``.

    The pickle is expected to hold a mapping from parameter name to array;
    when the mapping has exactly one key, the real mapping is read from its
    ``'blobs'`` entry. Only names present both in the program and in the
    file are loaded.

    For every common name:
      * names containing ``'pred'`` are reshaped to the tensor's shape when
        the shapes differ (element counts must match) and left untouched
        when the shapes already agree;
      * other arrays with the same shape as the tensor are copied as-is;
      * arrays whose first two and last two dimensions match the tensor's
        are stacked ``tensor.shape[2]`` times along axis 2 and divided by
        that count ("inflated") before being copied.

    Args:
        prog: program whose block-0 parameters should receive the values.
        pretrained_file: path of the pickle file.
        place: device place forwarded to ``tensor.set``.

    Raises:
        RuntimeError: if an array fits none of the cases above.
        AssertionError: if a 'pred' array has a different element count or
            an inflated array does not end up with the tensor's shape.
    """
    if os.path.exists(pretrained_file):
        param_name_list = [p.name for p in prog.block(0).all_parameters()]

        # NOTE(review): unpickling can execute arbitrary code; only use
        # pretrained files from a trusted source.
        # Binary mode plus a context manager: the handle is always closed
        # and binary pickles are read correctly.
        with open(pretrained_file, 'rb') as pkl_file:
            params_from_file = cPickle.load(pkl_file)
        if len(params_from_file) == 1:
            params_from_file = params_from_file['blobs']
        param_name_from_file = params_from_file.keys()

        common_names = get_common_names(param_name_list, param_name_from_file)

        logger.info('-------- loading params -----------')
        for name in common_names:
            t = fluid.global_scope().find_var(name).get_tensor()
            t_array = np.array(t)
            f_array = params_from_file[name]

            if 'pred' in name:
                assert np.prod(t_array.shape) == np.prod(
                    f_array.shape), "number of params should be the same"
                if t_array.shape == f_array.shape:
                    logger.info("pred param is the same {}".format(name))
                else:
                    re_f_array = np.reshape(f_array, t_array.shape)
                    t.set(re_f_array.astype('float32'), place)
                    logger.info("load pred param {}".format(name))
                continue

            if t_array.shape == f_array.shape:
                t.set(f_array.astype('float32'), place)
                logger.info("load param {}".format(name))
            elif (t_array.shape[:2] == f_array.shape[:2]) and (
                    t_array.shape[-2:] == f_array.shape[-2:]):
                # Replicate the file array along a new axis 2 and rescale so
                # the sum over that axis equals the original values.
                num_inflate = t_array.shape[2]
                stack_f_array = np.stack(
                    [f_array] * num_inflate, axis=2) / float(num_inflate)
                assert t_array.shape == stack_f_array.shape, \
                    "inflated shape should be the same with tensor {}".format(
                        name)
                t.set(stack_f_array.astype('float32'), place)
                logger.info("load inflated({}) param {}".format(num_inflate,
                                                                name))
            else:
                logger.info("Invalid case for name: {}".format(name))
                # A bare `raise` has no active exception here; raise an
                # explicit error that names the offending parameter.
                raise RuntimeError(
                    "Invalid case for name: {} (tensor shape {}, file shape {})".
                    format(name, t_array.shape, f_array.shape))
        logger.info("finished loading params from resnet pretrained model")
    else:
        # Mirror the directory-based loader: report the skipped load instead
        # of returning silently.
        logger.info("pretrained file {} does not exist, no params loaded".
                    format(pretrained_file))
def load_params_from_paddle_file(exe, prog, pretrained_file, place):
    """Load parameters saved in a directory into ``prog`` and fix up shapes.

    The shapes of the program's block-0 parameters are recorded first, then
    every parameter that has a file of the same name inside
    ``pretrained_file`` is loaded with ``fluid.io.load_vars``. Afterwards
    each parameter found both in the program and in the directory is
    compared with its recorded shape:

      * names containing ``'pred'`` are reshaped back to the recorded shape
        when the shapes differ (element counts must match);
      * tensors whose shape is unchanged are kept as loaded;
      * tensors whose first two and last two dimensions match the recorded
        shape are stacked ``recorded_shape[2]`` times along axis 2 and
        divided by that count ("inflated") before being written back.

    When ``pretrained_file`` is not a directory nothing is loaded and a
    message is logged.

    Args:
        exe: executor forwarded to ``fluid.io.load_vars``.
        prog: program whose parameters are loaded.
        pretrained_file: directory holding the saved parameters.
        place: device place forwarded to ``tensor.set``.

    Raises:
        RuntimeError: if a loaded tensor fits none of the cases above.
        AssertionError: if a 'pred' tensor has a different element count or
            an inflated tensor does not end up with the recorded shape.
    """
    if not os.path.isdir(pretrained_file):
        # The original format string had no placeholder, so the path was
        # silently dropped from the message.
        logger.info(
            "pretrained file {} is not in a directory, not suitable to load params".
            format(pretrained_file))
        return

    param_name_list = [p.name for p in prog.block(0).all_parameters()]

    # Record the shapes the program expects before loading overwrites them.
    param_shape = {}
    for name in param_name_list:
        param_tensor = fluid.global_scope().find_var(name).get_tensor()
        param_shape[name] = np.array(param_tensor).shape

    param_name_from_file = os.listdir(pretrained_file)
    common_names = get_common_names(param_name_list, param_name_from_file)

    logger.info('-------- loading params -----------')

    # load params from file
    def is_parameter(var):
        return isinstance(var, fluid.framework.Parameter) and \
            os.path.exists(os.path.join(pretrained_file, var.name))

    logger.info("Load pretrain weights from file {}".format(pretrained_file))
    # A concrete list (not a lazy filter object) that does not shadow the
    # builtin `vars`.
    vars_to_load = [var for var in prog.list_vars() if is_parameter(var)]
    fluid.io.load_vars(
        exe, pretrained_file, vars=vars_to_load, main_program=prog)

    # reset params if necessary
    for name in common_names:
        t = fluid.global_scope().find_var(name).get_tensor()
        t_array = np.array(t)
        origin_shape = param_shape[name]

        if 'pred' in name:
            assert np.prod(t_array.shape) == np.prod(
                origin_shape), "number of params should be the same"
            if t_array.shape == origin_shape:
                logger.info("pred param is the same {}".format(name))
            else:
                reshaped_t_array = np.reshape(t_array, origin_shape)
                t.set(reshaped_t_array.astype('float32'), place)
                logger.info("load pred param {}".format(name))
            continue

        if t_array.shape == origin_shape:
            logger.info("load param {}".format(name))
        elif (t_array.shape[:2] == origin_shape[:2]) and (
                t_array.shape[-2:] == origin_shape[-2:]):
            # Replicate the loaded tensor along a new axis 2 and rescale so
            # the sum over that axis equals the loaded values.
            num_inflate = origin_shape[2]
            stack_t_array = np.stack(
                [t_array] * num_inflate, axis=2) / float(num_inflate)
            assert origin_shape == stack_t_array.shape, \
                "inflated shape should be the same with tensor {}".format(
                    name)
            t.set(stack_t_array.astype('float32'), place)
            logger.info("load inflated({}) param {}".format(num_inflate,
                                                            name))
        else:
            logger.info("Invalid case for name: {}".format(name))
            # A bare `raise` has no active exception here; raise an explicit
            # error that names the offending parameter.
            raise RuntimeError(
                "Invalid case for name: {} (loaded shape {}, expected shape {})".
                format(name, t_array.shape, origin_shape))
    logger.info("finished loading params from resnet pretrained model")
def get_common_names(param_name_list, param_name_from_file):
    """Log how the program's parameter names overlap with those in a file.

    Three groups are written to the module logger, each under its own
    banner: names found on both sides, names found only in
    ``param_name_list`` and names found only in ``param_name_from_file``.

    Args:
        param_name_list: iterable of parameter names taken from the program.
        param_name_from_file: iterable of parameter names found in the
            pretrained file (for example dict keys or a directory listing).

    Returns:
        list: the names present on both sides, in ``param_name_list`` order.
    """
    # Materialise both inputs once: each is walked more than once below, and
    # the sets turn every membership test into an O(1) lookup instead of a
    # linear scan over a list.
    program_names = list(param_name_list)
    file_names = list(param_name_from_file)
    known_in_program = set(program_names)
    known_in_file = set(file_names)

    # name check and return common names both in param_name_list and file
    common_names = []
    paddle_only_names = []
    logger.info('-------- comon params -----------')
    for name in program_names:
        if name in known_in_file:
            common_names.append(name)
            logger.info(name)
        else:
            paddle_only_names.append(name)

    logger.info('-------- paddle only params ----------')
    for name in paddle_only_names:
        logger.info(name)

    logger.info('-------- file only params -----------')
    for name in file_names:
        if name not in known_in_program:
            logger.info(name)

    return common_names
def load_params_from_file(exe, prog, pretrained_file, place):
    """Load pretrained parameters, picking the loader from the path text.

    A path containing ``'.pkl'`` is handed to ``load_params_from_pkl_file``;
    every other path is handed to ``load_params_from_paddle_file``.

    Args:
        exe: executor, only needed by the Paddle-format loader.
        prog: program whose parameters are loaded.
        pretrained_file: path of the pretrained weights.
        place: device place forwarded to the selected loader.
    """
    logger.info('load params from {}'.format(pretrained_file))
    looks_like_pickle = '.pkl' in pretrained_file
    if not looks_like_pickle:
        load_params_from_paddle_file(exe, prog, pretrained_file, place)
        return
    load_params_from_pkl_file(prog, pretrained_file, place)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import numpy as np
import paddle.fluid as fluid
import logging
logger = logging.getLogger(__name__)
def load_params_from_file(exe, prog, pretrained_file, place):
    """Load parameters saved in a directory into ``prog`` and fix up shapes.

    The shapes of the program's block-0 parameters are recorded first, then
    every parameter that has a file of the same name inside
    ``pretrained_file`` is loaded with ``fluid.io.load_vars``. Afterwards
    each parameter found both in the program and in the directory is
    compared with its recorded shape:

      * tensors whose shape is unchanged are kept as loaded;
      * tensors whose first two and last two dimensions match the recorded
        shape are stacked ``recorded_shape[2]`` times along axis 2 and
        divided by that count ("inflated") before being written back.

    When ``pretrained_file`` is not a directory nothing is loaded and a
    message is logged.

    Args:
        exe: executor forwarded to ``fluid.io.load_vars``.
        prog: program whose parameters are loaded.
        pretrained_file: directory holding the saved parameters.
        place: device place forwarded to ``tensor.set``.

    Raises:
        RuntimeError: if a loaded tensor fits neither case above.
        AssertionError: if an inflated tensor does not end up with the
            recorded shape.
    """
    logger.info('load params from {}'.format(pretrained_file))
    if not os.path.isdir(pretrained_file):
        # The original format string had no placeholder, so the path was
        # silently dropped from the message.
        logger.info(
            "pretrained file {} is not in a directory, not suitable to load params".
            format(pretrained_file))
        return

    param_name_list = [p.name for p in prog.block(0).all_parameters()]

    # Record the shapes the program expects before loading overwrites them.
    param_shape = {}
    for name in param_name_list:
        param_tensor = fluid.global_scope().find_var(name).get_tensor()
        param_shape[name] = np.array(param_tensor).shape

    param_name_from_file = os.listdir(pretrained_file)
    common_names = get_common_names(param_name_list, param_name_from_file)

    logger.info('-------- loading params -----------')

    # load params from file
    def is_parameter(var):
        return isinstance(var, fluid.framework.Parameter) and \
            os.path.exists(os.path.join(pretrained_file, var.name))

    logger.info("Load pretrain weights from file {}".format(pretrained_file))
    # A concrete list (not a lazy filter object) that does not shadow the
    # builtin `vars`.
    vars_to_load = [var for var in prog.list_vars() if is_parameter(var)]
    fluid.io.load_vars(
        exe, pretrained_file, vars=vars_to_load, main_program=prog)

    # reset params if necessary
    for name in common_names:
        t = fluid.global_scope().find_var(name).get_tensor()
        t_array = np.array(t)
        origin_shape = param_shape[name]

        if t_array.shape == origin_shape:
            logger.info("load param {}".format(name))
        elif (t_array.shape[:2] == origin_shape[:2]) and (
                t_array.shape[-2:] == origin_shape[-2:]):
            # Replicate the loaded tensor along a new axis 2 and rescale so
            # the sum over that axis equals the loaded values.
            num_inflate = origin_shape[2]
            stack_t_array = np.stack(
                [t_array] * num_inflate, axis=2) / float(num_inflate)
            assert origin_shape == stack_t_array.shape, \
                "inflated shape should be the same with tensor {}".format(
                    name)
            t.set(stack_t_array.astype('float32'), place)
            logger.info("load inflated({}) param {}".format(num_inflate,
                                                            name))
        else:
            logger.info("Invalid case for name: {}".format(name))
            # A bare `raise` has no active exception here; raise an explicit
            # error that names the offending parameter.
            raise RuntimeError(
                "Invalid case for name: {} (loaded shape {}, expected shape {})".
                format(name, t_array.shape, origin_shape))
    logger.info("finished loading params from resnet pretrained model")
def get_common_names(param_name_list, param_name_from_file):
    """Return the parameter names shared by the program and the file.

    As a side effect the function logs three groups, each under its own
    banner: names found on both sides, names found only in
    ``param_name_list`` and names found only in ``param_name_from_file``.

    Args:
        param_name_list: iterable of parameter names taken from the program.
        param_name_from_file: iterable of parameter names found in the
            pretrained location (for example a directory listing).

    Returns:
        list: the names present on both sides, in ``param_name_list`` order.
    """
    # Snapshot both inputs so one-shot iterables work, and build sets so
    # that membership checks are constant-time rather than list scans.
    from_program = list(param_name_list)
    from_file = list(param_name_from_file)
    program_lookup = set(from_program)
    file_lookup = set(from_file)

    # name check and return common names both in param_name_list and file
    common_names = []
    paddle_only_names = []
    logger.info('-------- comon params -----------')
    for name in from_program:
        if name in file_lookup:
            common_names.append(name)
            logger.info(name)
        else:
            paddle_only_names.append(name)

    logger.info('-------- paddle only params ----------')
    for name in paddle_only_names:
        logger.info(name)

    logger.info('-------- file only params -----------')
    for name in from_file:
        if name not in program_lookup:
            logger.info(name)

    return common_names
......@@ -61,6 +61,11 @@ def train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feede
save_model_name = 'model', test_exe = None, test_reader = None, \
test_feeder = None, test_fetch_list = None, test_metrics = None):
for epoch in range(epochs):
lr = fluid.global_scope().find_var("learning_rate").get_tensor()
lr_count = fluid.global_scope().find_var(
"@LR_DECAY_COUNTER@").get_tensor()
logger.info("------- learning rate {}, learning rate counter {} -----"
.format(np.array(lr), np.array(lr_count)))
epoch_periods = []
for train_iter, data in enumerate(train_reader()):
cur_time = time.time()
......@@ -80,7 +85,8 @@ def train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feede
format(epoch, np.mean(epoch_periods)))
save_model(exe, train_prog, save_dir, save_model_name,
"_epoch{}".format(epoch))
if test_exe and valid_interval > 0 and (epoch + 1) % valid_interval == 0:
if test_exe and valid_interval > 0 and (epoch + 1
) % valid_interval == 0:
test_without_pyreader(test_exe, test_reader, test_feeder,
test_fetch_list, test_metrics, log_interval)
......@@ -95,6 +101,11 @@ def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \
if not train_pyreader:
logger.error("[TRAIN] get pyreader failed.")
for epoch in range(epochs):
lr = fluid.global_scope().find_var("learning_rate").get_tensor()
lr_count = fluid.global_scope().find_var(
"@LR_DECAY_COUNTER@").get_tensor()
logger.info("------- learning rate {}, learning rate counter {} -----"
.format(np.array(lr), np.array(lr_count)))
train_pyreader.start()
train_metrics.reset()
try:
......@@ -119,7 +130,8 @@ def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \
format(epoch, np.mean(epoch_periods)))
save_model(exe, train_prog, save_dir, save_model_name,
"_epoch{}".format(epoch))
if test_exe and valid_interval > 0 and (epoch + 1) % valid_interval == 0:
if test_exe and valid_interval > 0 and (epoch + 1
) % valid_interval == 0:
test_with_pyreader(test_exe, test_pyreader, test_fetch_list,
test_metrics, log_interval)
finally:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册