Commit 3b779edd authored by SunGaofeng

add nonlocal model

Parent df1d8f80
...@@ -19,12 +19,15 @@ except:
from utils import AttrDict

import logging
logger = logging.getLogger(__name__)

CONFIG_SECS = [
    'train',
    'valid',
    'test',
    'infer',
]


def parse_config(cfg_file):
...@@ -43,6 +46,7 @@ def parse_config(cfg_file):
    return cfg


def merge_configs(cfg, sec, args_dict):
    assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
    sec_dict = getattr(cfg, sec.upper())
...@@ -56,3 +60,8 @@ def merge_configs(cfg, sec, args_dict):
            pass
    return cfg
def print_configs(cfg):
    import pprint

    logger.info('Training with config:')
    logger.info(pprint.pformat(cfg))
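
# Example (not part of this commit) of how the helpers above chain together;
# the config path and the override dict are hypothetical placeholders.
if __name__ == '__main__':
    cfg = parse_config('configs/nonlocal.txt')  # AttrDict with TRAIN/VALID/TEST/INFER sections
    cfg = merge_configs(cfg, 'train', {'batch_size': 32})  # non-None overrides win over file values
    print_configs(cfg)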
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
[MODEL]
name = "NONLOCAL"
num_classes = 400
image_mean = 114.75
image_std = 57.375
depth = 50
dataset = 'kinetics400'
video_arc_choice = 1
use_affine = False
fc_init_std = 0.01
bn_momentum = 0.9
bn_epsilon = 1.0e-5
bn_init_gamma = 0.
[RESNETS]
num_groups = 1
width_per_group = 64
trans_func = bottleneck_transformation_3d
[NONLOCAL]
bn_momentum = 0.9
bn_epsilon = 1.0e-5
bn_init_gamma = 0.0
layer_mod = 2
conv3_nonlocal = True
conv4_nonlocal = True
conv_init_std = 0.01
no_bias = 0
use_maxpool = True
use_softmax = True
use_scale = True
use_zero_init_conv = False
use_bn = True
use_affine = False
[TRAIN]
num_reader_threads = 8
batch_size = 64
num_gpus = 8
filelist = './dataset/nonlocal/trainlist.txt'
crop_size = 224
sample_rate = 8
video_length = 8
jitter_scales = [256, 320]
dropout_rate = 0.5
learning_rate = 0.01
learning_rate_decay = 0.1
step_sizes = [150000, 150000, 100000]
max_iter = 400000
weight_decay = 0.0001
weight_decay_bn = 0.0
momentum = 0.9
nesterov = True
scale_momentum = True
[VALID]
num_reader_threads = 8
batch_size = 64
filelist = './dataset/nonlocal/vallist.txt'
crop_size = 224
sample_rate = 8
video_length = 8
jitter_scales = [256, 320]
[TEST]
num_reader_threads = 8
batch_size = 4
filelist = 'dataset/nonlocal/testlist.txt'
filename_gt = 'dataset/nonlocal/vallist.txt'
checkpoint_dir = './output'
crop_size = 256
sample_rate = 8
video_length = 8
jitter_scales = [256, 256]
num_test_clips = 30
dataset_size = 19761
use_multi_crop = 1
[INFER]
num_reader_threads = 8
batch_size = 1
filelist = 'dataset/nonlocal/inferencelist.txt'
crop_size = 256
sample_rate = 8
video_length = 8
jitter_scales = [256, 256]
num_test_clips = 30
use_multi_crop = 1
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
...@@ -34,7 +34,7 @@ class NonlocalReader(DataReader):
        image_mean
        image_std
        batch_size
        filelist
        crop_size
        sample_rate
        video_length
...@@ -68,7 +68,7 @@ class NonlocalReader(DataReader):
        dataset_args['min_size'] = cfg[mode.upper()]['jitter_scales'][0]
        dataset_args['max_size'] = cfg[mode.upper()]['jitter_scales'][1]
        dataset_args['num_reader_threads'] = num_reader_threads
        filelist = cfg[mode.upper()]['filelist']
        batch_size = cfg[mode.upper()]['batch_size']

        if self.mode == 'train':
...@@ -146,8 +146,8 @@ def apply_resize(rgbdata, min_size, max_size):
        ratio = float(side_length) / float(width)
    else:
        ratio = float(side_length) / float(height)
    out_height = int(round(height * ratio))
    out_width = int(round(width * ratio))
    outdata = np.zeros(
        (length, out_height, out_width, channel), dtype=rgbdata.dtype)
    for i in range(length):
...@@ -197,14 +197,13 @@ def crop_mirror_transform(rgbdata,
def make_reader(filelist, batch_size, sample_times, is_training, shuffle,
                **dataset_args):
    def reader():
        fl = open(filelist).readlines()
        fl = [line.strip() for line in fl if line.strip() != '']
        if shuffle:
            random.shuffle(fl)

        batch_out = []
        for line in fl:
            # start_time = time.time()
...@@ -253,23 +252,6 @@ def make_reader(filelist, batch_size, sample_times, is_training, shuffle,
def make_multi_reader(filelist, batch_size, sample_times, is_training, shuffle,
                      **dataset_args):
    def read_into_queue(flq, queue):
        batch_out = []
        for line in flq:
...@@ -315,6 +297,24 @@ def make_multi_reader(filelist, batch_size, sample_times, is_training, shuffle,
            queue.put(None)

    def queue_reader():
        # split file list and shuffle
        fl = open(filelist).readlines()
        fl = [line.strip() for line in fl if line.strip() != '']
        if shuffle:
            random.shuffle(fl)

        n = dataset_args['num_reader_threads']
        queue_size = 20
        reader_lists = [None] * n
        file_num = int(len(fl) // n)
        for i in range(n):
            if i < len(reader_lists) - 1:
                tmp_list = fl[i * file_num:(i + 1) * file_num]
            else:
                tmp_list = fl[i * file_num:]
            reader_lists[i] = tmp_list

        queue = multiprocessing.Queue(queue_size)
        p_list = [None] * len(reader_lists)
        # for reader_list in reader_lists:
...
...@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to train.')
...@@ -47,14 +47,14 @@ def parse_args():
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--use_gpu', type=bool, default=True, help='default use gpu.')
    parser.add_argument(
        '--weights',
        type=str,
        default=None,
        help='weight path, None to use weights from Paddle.')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='sample number in a batch for inference.')
...@@ -64,17 +64,17 @@ def parse_args():
        default=None,
        help='path to inference data file lists file.')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=1,
        help='mini-batch interval to log.')
    parser.add_argument(
        '--infer_topk',
        type=int,
        default=20,
        help='topk predictions to restore.')
    parser.add_argument(
        '--save_dir', type=str, default='./', help='directory to store results')
    args = parser.parse_args()
    return args
...@@ -126,8 +126,7 @@ def infer(args):
        topk_inds = predictions[i].argsort()[0 - args.infer_topk:]
        topk_inds = topk_inds[::-1]
        preds = predictions[i][topk_inds]
        results.append((video_id[i], preds.tolist(), topk_inds.tolist()))
    prev_time = cur_time
    cur_time = time.time()
    period = cur_time - prev_time
...@@ -145,6 +144,7 @@ def infer(args):
        "{}_infer_result".format(args.model_name))
    pickle.dump(results, open(result_file_name, 'wb'))


if __name__ == "__main__":
    args = parse_args()
    logger.info(args)
...
...@@ -63,6 +63,7 @@ class MetricsCalculator():
    def accumulate(self, loss, pred, labels):
        labels = labels.astype(int)
        labels = labels[:, 0]
        for i in range(pred.shape[0]):
            probs = pred[i, :].tolist()
            vid = labels[i]
...@@ -81,6 +82,8 @@ class MetricsCalculator():
        evaluate_results(self.results, self.filename_gt, self.dataset_size, \
                         self.num_classes, self.num_test_clips)
        # save temporary file
        if not os.path.isdir(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        pkl_path = os.path.join(self.checkpoint_dir, "results_probs.pkl")
        with open(pkl_path, 'w') as f:
...@@ -188,26 +191,4 @@ def evaluate_results(results, filename_gt, test_dataset_size, num_classes,
    logger.info('top-5 accuracy: {:.2f} percent'.format(accuracy_top5 * 100))
    logger.info('-' * 80)
    return
...@@ -4,6 +4,7 @@ from .nextvlad import NEXTVLAD
from .tsn import TSN
from .stnet import STNET
from .attention_lstm import AttentionLSTM
from .nonlocal_model import NonLocal

# regist models
regist_model("AttentionCluster", AttentionCluster)
...@@ -11,3 +12,4 @@ regist_model("NEXTVLAD", NEXTVLAD)
regist_model("TSN", TSN)
regist_model("STNET", STNET)
regist_model("AttentionLSTM", AttentionLSTM)
regist_model('NONLOCAL', NonLocal)
...@@ -137,8 +137,8 @@ class ModelBase(object):
        if os.path.exists(path):
            return path

        logger.info("Download pretrain weights of {} from {}".format(self.name,
                                                                     url))
        download(url, path)
        return path
...@@ -146,6 +146,12 @@ class ModelBase(object):
        logger.info("Load pretrain weights from {}".format(pretrain))
        fluid.io.load_params(exe, pretrain, main_program=prog)

    def load_test_weights(self, exe, weights, prog, place):
        def if_exist(var):
            return os.path.exists(os.path.join(weights, var.name))

        fluid.io.load_vars(exe, weights, predicate=if_exist)

    def get_config_from_sec(self, sec, item, default=None):
        if sec.upper() not in self.cfg:
            return default
...@@ -178,4 +184,3 @@ def regist_model(name, model):
def get_model(name, cfg, mode='train'):
    return model_zoo.get(name, cfg, mode)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import paddle
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
# 3d spacetime nonlocal (v1, spatial downsample)
def spacetime_nonlocal(blob_in, dim_in, dim_out, batch_size, prefix, dim_inner, cfg, \
test_mode = False, max_pool_stride = 2):
#------------
cur = blob_in
# we do projection to convert each spacetime location to a feature
# theta original size
# e.g., (8, 1024, 4, 14, 14) => (8, 1024, 4, 14, 14)
theta = fluid.layers.conv3d(
input=cur,
num_filters=dim_inner,
filter_size=[1, 1, 1],
stride=[1, 1, 1],
padding=[0, 0, 0],
param_attr=ParamAttr(
name=prefix + '_theta' + "_w",
initializer=fluid.initializer.Normal(
loc=0.0, scale=cfg.NONLOCAL.conv_init_std)),
bias_attr=ParamAttr(
name=prefix + '_theta' + "_b",
initializer=fluid.initializer.Constant(value=0.))
if (cfg.NONLOCAL.no_bias == 0) else False,
name=prefix + '_theta')
theta_shape = theta.shape
# phi and g: half spatial size
# e.g., (8, 1024, 4, 14, 14) => (8, 1024, 4, 7, 7)
if cfg.NONLOCAL.use_maxpool:
max_pool = fluid.layers.pool3d(
input=cur,
pool_size=[1, max_pool_stride, max_pool_stride],
pool_type='max',
pool_stride=[1, max_pool_stride, max_pool_stride],
pool_padding=[0, 0, 0],
name=prefix + '_pool')
else:
max_pool = cur
phi = fluid.layers.conv3d(
input=max_pool,
num_filters=dim_inner,
filter_size=[1, 1, 1],
stride=[1, 1, 1],
padding=[0, 0, 0],
param_attr=ParamAttr(
name=prefix + '_phi' + "_w",
initializer=fluid.initializer.Normal(
loc=0.0, scale=cfg.NONLOCAL.conv_init_std)),
bias_attr=ParamAttr(
name=prefix + '_phi' + "_b",
initializer=fluid.initializer.Constant(value=0.))
if (cfg.NONLOCAL.no_bias == 0) else False,
name=prefix + '_phi')
phi_shape = phi.shape
g = fluid.layers.conv3d(
input=max_pool,
num_filters=dim_inner,
filter_size=[1, 1, 1],
stride=[1, 1, 1],
padding=[0, 0, 0],
param_attr=ParamAttr(
name=prefix + '_g' + "_w",
initializer=fluid.initializer.Normal(
loc=0.0, scale=cfg.NONLOCAL.conv_init_std)),
bias_attr=ParamAttr(
name=prefix + '_g' + "_b",
initializer=fluid.initializer.Constant(value=0.))
if (cfg.NONLOCAL.no_bias == 0) else False,
name=prefix + '_g')
g_shape = g.shape
# we have to use explicit batch size (to support arbitrary spacetime size)
# e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784)
theta = fluid.layers.reshape(
theta, [-1, 0, theta_shape[2] * theta_shape[3] * theta_shape[4]])
theta = fluid.layers.transpose(theta, [0, 2, 1])
phi = fluid.layers.reshape(
phi, [-1, 0, phi_shape[2] * phi_shape[3] * phi_shape[4]])
theta_phi = fluid.layers.matmul(theta, phi, name=prefix + '_affinity')
g = fluid.layers.reshape(g, [-1, 0, g_shape[2] * g_shape[3] * g_shape[4]])
if cfg.NONLOCAL.use_softmax:
if cfg.NONLOCAL.use_scale is True:
theta_phi_sc = fluid.layers.scale(theta_phi, scale=dim_inner**-.5)
else:
theta_phi_sc = theta_phi
p = fluid.layers.softmax(
theta_phi_sc, name=prefix + '_affinity' + '_prob')
    else:
        # normalization without softmax is not implemented
        raise NotImplementedError(
            "spacetime_nonlocal is only implemented for use_softmax = True")
# note g's axis[2] corresponds to p's axis[2]
# e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1)
p = fluid.layers.transpose(p, [0, 2, 1])
t = fluid.layers.matmul(g, p, name=prefix + '_y')
# reshape back
# e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14)
t_shape = t.shape
# print(t_shape)
# print(theta_shape)
t_re = fluid.layers.reshape(t, shape=list(theta_shape))
blob_out = t_re
blob_out = fluid.layers.conv3d(
input=blob_out,
num_filters=dim_out,
filter_size=[1, 1, 1],
stride=[1, 1, 1],
padding=[0, 0, 0],
param_attr=ParamAttr(
name=prefix + '_out' + "_w",
initializer=fluid.initializer.Constant(value=0.)
if cfg.NONLOCAL.use_zero_init_conv else fluid.initializer.Normal(
loc=0.0, scale=cfg.NONLOCAL.conv_init_std)),
bias_attr=ParamAttr(
name=prefix + '_out' + "_b",
initializer=fluid.initializer.Constant(value=0.))
if (cfg.NONLOCAL.no_bias == 0) else False,
name=prefix + '_out')
blob_out_shape = blob_out.shape
if cfg.NONLOCAL.use_bn is True:
bn_name = prefix + "_bn"
blob_out = fluid.layers.batch_norm(
blob_out,
is_test=test_mode,
momentum=cfg.NONLOCAL.bn_momentum,
epsilon=cfg.NONLOCAL.bn_epsilon,
name=bn_name,
param_attr=ParamAttr(
name=bn_name + "_scale",
initializer=fluid.initializer.Constant(
value=cfg.NONLOCAL.bn_init_gamma),
regularizer=fluid.regularizer.L2Decay(
cfg.TRAIN.weight_decay_bn)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
regularizer=fluid.regularizer.L2Decay(
cfg.TRAIN.weight_decay_bn)),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance") # add bn
if cfg.NONLOCAL.use_affine is True:
affine_scale = fluid.layers.create_parameter(
shape=[blob_out_shape[1]],
dtype=blob_out.dtype,
attr=ParamAttr(name=prefix + '_affine' + '_s'),
default_initializer=fluid.initializer.Constant(value=1.))
affine_bias = fluid.layers.create_parameter(
shape=[blob_out_shape[1]],
dtype=blob_out.dtype,
attr=ParamAttr(name=prefix + '_affine' + '_b'),
default_initializer=fluid.initializer.Constant(value=0.))
blob_out = fluid.layers.affine_channel(
blob_out,
scale=affine_scale,
bias=affine_bias,
name=prefix + '_affine') # add affine
return blob_out
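
# A self-contained numpy sketch (not part of this commit) of what
# spacetime_nonlocal computes after the 1x1x1 projections, assuming
# use_softmax and use_scale: embedded-Gaussian attention over all
# space-time positions, y = softmax(theta^T phi / sqrt(d)) applied to g.
def _nonlocal_numpy_sketch(theta, phi, g):
    import numpy as np
    # theta: (N, d, P1); phi, g: (N, d, P2), where P = T*H*W positions
    d = theta.shape[1]
    affinity = np.einsum('ndp,ndq->npq', theta, phi) / np.sqrt(d)  # use_scale
    p = np.exp(affinity - affinity.max(axis=2, keepdims=True))
    p = p / p.sum(axis=2, keepdims=True)  # softmax over the phi/g positions
    return np.einsum('ndq,npq->ndp', g, p)  # (N, d, P1); reshaped back to 5-D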
def add_nonlocal(blob_in,
dim_in,
dim_out,
batch_size,
prefix,
dim_inner,
cfg,
test_mode=False):
blob_out = spacetime_nonlocal(blob_in, \
dim_in, dim_out, batch_size, prefix, dim_inner, cfg, test_mode = test_mode)
blob_out = fluid.layers.elementwise_add(
blob_out, blob_in, name=prefix + '_sum')
return blob_out
# This is to reduce memory usage if the feature maps are big:
# divide the feature maps into groups along the temporal dimension,
# and perform the non-local operation inside each group.
def add_nonlocal_group(blob_in,
dim_in,
dim_out,
batch_size,
pool_stride,
height,
width,
group_size,
prefix,
dim_inner,
cfg,
test_mode=False):
    group_num = int(pool_stride / group_size)
    assert (pool_stride % group_size == 0), \
        'nonlocal block {}: pool_stride({}) should be divisible by group_size({})'.format(prefix, pool_stride, group_size)
if group_num > 1:
blob_in = fluid.layers.transpose(
blob_in, [0, 2, 1, 3, 4], name=prefix + '_pre_trans1')
blob_in = fluid.layers.reshape(
blob_in,
[batch_size * group_num, group_size, dim_in, height, width],
name=prefix + '_pre_reshape1')
blob_in = fluid.layers.transpose(
blob_in, [0, 2, 1, 3, 4], name=prefix + '_pre_trans2')
blob_out = spacetime_nonlocal(
blob_in,
dim_in,
dim_out,
batch_size,
prefix,
dim_inner,
cfg,
test_mode=test_mode)
blob_out = fluid.layers.elementwise_add(
blob_out, blob_in, name=prefix + '_sum')
if group_num > 1:
blob_out = fluid.layers.transpose(
blob_out, [0, 2, 1, 3, 4], name=prefix + '_post_trans1')
blob_out = fluid.layers.reshape(
blob_out,
[batch_size, group_num * group_size, dim_out, height, width],
name=prefix + '_post_reshape1')
blob_out = fluid.layers.transpose(
blob_out, [0, 2, 1, 3, 4], name=prefix + '_post_trans2')
return blob_out
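
# A toy numpy check (not part of this commit) that the transpose/reshape
# pair above folds the temporal axis into batch groups and then restores
# the original layout exactly.
def _group_roundtrip_check():
    import numpy as np
    N, C, T, H, W = 1, 3, 4, 2, 2
    group_size, group_num = 2, 2
    x = np.arange(N * C * T * H * W).reshape(N, C, T, H, W)
    g = x.transpose(0, 2, 1, 3, 4)                     # (N, T, C, H, W)
    g = g.reshape(N * group_num, group_size, C, H, W)  # groups folded into batch
    g = g.transpose(0, 2, 1, 3, 4)                     # per-group (N', C, T', H, W)
    # ... spacetime_nonlocal would run here on each group independently ...
    y = g.transpose(0, 2, 1, 3, 4).reshape(N, group_num * group_size, C, H, W)
    y = y.transpose(0, 2, 1, 3, 4)                     # back to (N, C, T, H, W)
    assert (x == y).all()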
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import numpy as np
import cPickle
import paddle.fluid as fluid
from ..model import ModelBase
import resnet_video
import logging
logger = logging.getLogger(__name__)
__all__ = ["NonLocal"]
# To add new models, import them, add them to this map and models/TARGETS
class NonLocal(ModelBase):
def __init__(self, name, cfg, mode='train'):
super(NonLocal, self).__init__(name, cfg, mode=mode)
self.get_config()
def get_config(self):
# video_length
self.video_length = self.get_config_from_sec(self.mode, 'video_length')
# crop size
self.crop_size = self.get_config_from_sec(self.mode, 'crop_size')
def build_input(self, use_pyreader=True):
input_shape = [3, self.video_length, self.crop_size, self.crop_size]
label_shape = [1]
py_reader = None
if use_pyreader:
            assert self.mode != 'infer', \
                'pyreader is not recommended for infer mode, please set use_pyreader to False.'
py_reader = fluid.layers.py_reader(
capacity=20,
shapes=[[-1] + input_shape, [-1] + label_shape],
dtypes=['float32', 'int64'],
name='train_py_reader'
if self.is_training else 'test_py_reader',
use_double_buffer=True)
data, label = fluid.layers.read_file(py_reader)
self.py_reader = py_reader
else:
data = fluid.layers.data(
name='train_data' if self.is_training else 'test_data',
shape=input_shape,
dtype='float32')
if self.mode != 'infer':
label = fluid.layers.data(
name='train_label' if self.is_training else 'test_label',
shape=label_shape,
dtype='int64')
else:
label = None
self.feature_input = [data]
self.label_input = label
def create_model_args(self):
return None
def build_model(self):
pred, loss = resnet_video.create_model(
data=self.feature_input[0],
label=self.label_input,
cfg=self.cfg,
is_training=self.is_training,
mode=self.mode)
if loss is not None:
loss = fluid.layers.mean(loss)
self.network_outputs = [pred]
self.loss_ = loss
def optimizer(self):
base_lr = self.get_config_from_sec('TRAIN', 'learning_rate')
lr_decay = self.get_config_from_sec('TRAIN', 'learning_rate_decay')
step_sizes = self.get_config_from_sec('TRAIN', 'step_sizes')
lr_bounds, lr_values = get_learning_rate_decay_list(base_lr, lr_decay,
step_sizes)
learning_rate = fluid.layers.piecewise_decay(
boundaries=lr_bounds, values=lr_values)
momentum = self.get_config_from_sec('TRAIN', 'momentum')
use_nesterov = self.get_config_from_sec('TRAIN', 'nesterov')
l2_weight_decay = self.get_config_from_sec('TRAIN', 'weight_decay')
logger.info(
'Build up optimizer, \ntype: {}, \nmomentum: {}, \nnesterov: {}, \
\nregularization: L2 {}, \nlr_values: {}, lr_bounds: {}'
.format('Momentum', momentum, use_nesterov, l2_weight_decay,
lr_values, lr_bounds))
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=momentum,
use_nesterov=use_nesterov,
regularization=fluid.regularizer.L2Decay(l2_weight_decay))
return optimizer
def loss(self):
return self.loss_
def outputs(self):
return self.network_outputs
def feeds(self):
return self.feature_input if self.mode == 'infer' else \
self.feature_input + [self.label_input]
def pretrain_info(self):
return None, None
def weights_info(self):
pass
def load_pretrain_params(self, exe, pretrain, prog, place):
load_params_from_file(exe, prog, pretrain, place)
def load_test_weights(self, exe, weights, prog, place):
super(NonLocal, self).load_test_weights(exe, weights, prog, place)
pred_w = fluid.global_scope().find_var('pred_w').get_tensor()
pred_array = np.array(pred_w)
pred_w_shape = pred_array.shape
if len(pred_w_shape) == 2:
logger.info('reshape for pred_w when test')
pred_array = np.transpose(pred_array, (1, 0))
pred_w_shape = pred_array.shape
pred_array = np.reshape(
pred_array, [pred_w_shape[0], pred_w_shape[1], 1, 1, 1])
pred_w.set(pred_array.astype('float32'), place)
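
# A numpy sanity check (not part of this commit) for the reshape above: an
# fc layer with weight (dim_in, num_classes) and a 1x1x1 conv3d with filters
# (num_classes, dim_in, 1, 1, 1) produce the same logits on a pooled feature.
def _fc_to_conv_check():
    import numpy as np
    dim_in, num_classes = 16, 4
    w_fc = np.random.rand(dim_in, num_classes)  # fluid fc weight layout: (in, out)
    w_conv = np.transpose(w_fc, (1, 0)).reshape(num_classes, dim_in, 1, 1, 1)
    feat = np.random.rand(dim_in)  # one globally pooled (C,) feature vector
    logits_fc = feat.dot(w_fc)
    logits_conv = (w_conv.reshape(num_classes, dim_in) * feat).sum(axis=1)
    assert np.allclose(logits_fc, logits_conv)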
def get_learning_rate_decay_list(base_learning_rate, lr_decay, step_lists):
lr_bounds = []
lr_values = [base_learning_rate * 1]
cur_step = 0
for i in range(len(step_lists)):
cur_step += step_lists[i]
lr_bounds.append(cur_step)
decay_rate = lr_decay**(i + 1)
lr_values.append(base_learning_rate * decay_rate)
return lr_bounds, lr_values
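
# Worked example (not part of this commit) using the [TRAIN] values from
# configs/nonlocal.txt: base_lr=0.01, lr_decay=0.1 and
# step_sizes=[150000, 150000, 100000] give
#     lr_bounds = [150000, 300000, 400000]
#     lr_values = [0.01, 0.001, 0.0001, 0.00001]
# so piecewise_decay divides the learning rate by 10 at each boundary, and
# max_iter=400000 ends training at the last one.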
def load_params_from_pkl_file(prog, pretrained_file, place):
param_list = prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list]
    if os.path.exists(pretrained_file):
        params_from_file = cPickle.load(open(pretrained_file, 'rb'))
if len(params_from_file.keys()) == 1:
params_from_file = params_from_file['blobs']
param_name_from_file = params_from_file.keys()
param_list = prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list]
common_names = get_common_names(param_name_list, param_name_from_file)
logger.info('-------- loading params -----------')
for name in common_names:
t = fluid.global_scope().find_var(name).get_tensor()
t_array = np.array(t)
f_array = params_from_file[name]
if 'pred' in name:
assert np.prod(t_array.shape) == np.prod(
f_array.shape), "number of params should be the same"
if t_array.shape == f_array.shape:
logger.info("pred param is the same {}".format(name))
else:
re_f_array = np.reshape(f_array, t_array.shape)
t.set(re_f_array.astype('float32'), place)
logger.info("load pred param {}".format(name))
continue
if t_array.shape == f_array.shape:
t.set(f_array.astype('float32'), place)
logger.info("load param {}".format(name))
elif (t_array.shape[:2] == f_array.shape[:2]) and (
t_array.shape[-2:] == f_array.shape[-2:]):
num_inflate = t_array.shape[2]
stack_f_array = np.stack(
[f_array] * num_inflate, axis=2) / float(num_inflate)
assert t_array.shape == stack_f_array.shape, "inflated shape should be the same with tensor {}".format(
name)
t.set(stack_f_array.astype('float32'), place)
logger.info("load inflated({}) param {}".format(num_inflate,
name))
            else:
                logger.info("Invalid case for name: {}".format(name))
                raise ValueError("cannot load param {}".format(name))
logger.info("finished loading params from resnet pretrained model")
def load_params_from_paddle_file(exe, prog, pretrained_file, place):
if os.path.isdir(pretrained_file):
param_list = prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list]
param_shape = {}
for name in param_name_list:
param_tensor = fluid.global_scope().find_var(name).get_tensor()
param_shape[name] = np.array(param_tensor).shape
param_name_from_file = os.listdir(pretrained_file)
common_names = get_common_names(param_name_list, param_name_from_file)
logger.info('-------- loading params -----------')
# load params from file
        def is_parameter(var):
            return isinstance(var, fluid.framework.Parameter) and \
                os.path.exists(os.path.join(pretrained_file, var.name))
logger.info("Load pretrain weights from file {}".format(
pretrained_file))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrained_file, vars=vars, main_program=prog)
# reset params if necessary
for name in common_names:
t = fluid.global_scope().find_var(name).get_tensor()
t_array = np.array(t)
origin_shape = param_shape[name]
if 'pred' in name:
assert np.prod(t_array.shape) == np.prod(
origin_shape), "number of params should be the same"
if t_array.shape == origin_shape:
logger.info("pred param is the same {}".format(name))
else:
reshaped_t_array = np.reshape(t_array, origin_shape)
t.set(reshaped_t_array.astype('float32'), place)
logger.info("load pred param {}".format(name))
continue
if t_array.shape == origin_shape:
logger.info("load param {}".format(name))
elif (t_array.shape[:2] == origin_shape[:2]) and (
t_array.shape[-2:] == origin_shape[-2:]):
num_inflate = origin_shape[2]
stack_t_array = np.stack(
[t_array] * num_inflate, axis=2) / float(num_inflate)
assert origin_shape == stack_t_array.shape, "inflated shape should be the same with tensor {}".format(
name)
t.set(stack_t_array.astype('float32'), place)
logger.info("load inflated({}) param {}".format(num_inflate,
name))
            else:
                logger.info("Invalid case for name: {}".format(name))
                raise ValueError("cannot load param {}".format(name))
logger.info("finished loading params from resnet pretrained model")
    else:
        logger.info("pretrained file {} is not a directory, "
                    "cannot load params from it".format(pretrained_file))
def get_common_names(param_name_list, param_name_from_file):
# name check and return common names both in param_name_list and file
common_names = []
paddle_only_names = []
file_only_names = []
    logger.info('-------- common params -----------')
for name in param_name_list:
if name in param_name_from_file:
common_names.append(name)
logger.info(name)
else:
paddle_only_names.append(name)
logger.info('-------- paddle only params ----------')
for name in paddle_only_names:
logger.info(name)
logger.info('-------- file only params -----------')
for name in param_name_from_file:
if name in param_name_list:
assert name in common_names
else:
file_only_names.append(name)
logger.info(name)
return common_names
def load_params_from_file(exe, prog, pretrained_file, place):
logger.info('load params from {}'.format(pretrained_file))
if '.pkl' in pretrained_file:
load_params_from_pkl_file(prog, pretrained_file, place)
else:
load_params_from_paddle_file(exe, prog, pretrained_file, place)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import paddle
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
import nonlocal_helper
def Conv3dAffine(blob_in,
prefix,
dim_in,
dim_out,
filter_size,
stride,
padding,
cfg,
group=1,
test_mode=False,
bn_init=None):
blob_out = fluid.layers.conv3d(
input=blob_in,
num_filters=dim_out,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=group,
param_attr=ParamAttr(
name=prefix + "_weights", initializer=fluid.initializer.MSRA()),
bias_attr=False,
name=prefix + "_conv")
blob_out_shape = blob_out.shape
affine_name = "bn" + prefix[3:]
affine_scale = fluid.layers.create_parameter(
shape=[blob_out_shape[1]],
dtype=blob_out.dtype,
attr=ParamAttr(name=affine_name + '_scale'),
default_initializer=fluid.initializer.Constant(value=1.))
affine_bias = fluid.layers.create_parameter(
shape=[blob_out_shape[1]],
dtype=blob_out.dtype,
attr=ParamAttr(name=affine_name + '_offset'),
default_initializer=fluid.initializer.Constant(value=0.))
blob_out = fluid.layers.affine_channel(
blob_out, scale=affine_scale, bias=affine_bias, name=affine_name)
return blob_out
def Conv3dBN(blob_in,
prefix,
dim_in,
dim_out,
filter_size,
stride,
padding,
cfg,
group=1,
test_mode=False,
bn_init=None):
blob_out = fluid.layers.conv3d(
input=blob_in,
num_filters=dim_out,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=group,
param_attr=ParamAttr(
name=prefix + "_weights", initializer=fluid.initializer.MSRA()),
bias_attr=False,
name=prefix + "_conv")
bn_name = "bn" + prefix[3:]
blob_out = fluid.layers.batch_norm(
blob_out,
is_test=test_mode,
momentum=cfg.MODEL.bn_momentum,
epsilon=cfg.MODEL.bn_epsilon,
name=bn_name,
        param_attr=ParamAttr(
            name=bn_name + "_scale",
            initializer=fluid.initializer.Constant(
                value=bn_init if bn_init is not None else 1.),
            regularizer=fluid.regularizer.L2Decay(cfg.TRAIN.weight_decay_bn)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
regularizer=fluid.regularizer.L2Decay(cfg.TRAIN.weight_decay_bn)),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance")
return blob_out
# 3d bottleneck
def bottleneck_transformation_3d(blob_in,
dim_in,
dim_out,
stride,
prefix,
dim_inner,
cfg,
group=1,
use_temp_conv=1,
temp_stride=1,
test_mode=False):
conv_op = Conv3dAffine if cfg.MODEL.use_affine else Conv3dBN
# 1x1 layer
blob_out = conv_op(
blob_in,
prefix + "_branch2a",
dim_in,
dim_inner, [1 + use_temp_conv * 2, 1, 1], [temp_stride, 1, 1],
[use_temp_conv, 0, 0],
cfg,
test_mode=test_mode)
blob_out = fluid.layers.relu(blob_out, name=prefix + "_branch2a" + "_relu")
# 3x3 layer
blob_out = conv_op(
blob_out,
prefix + '_branch2b',
dim_inner,
dim_inner, [1, 3, 3], [1, stride, stride], [0, 1, 1],
cfg,
group=group,
test_mode=test_mode)
blob_out = fluid.layers.relu(blob_out, name=prefix + "_branch2b" + "_relu")
# 1x1 layer, no relu
blob_out = conv_op(
blob_out,
prefix + '_branch2c',
dim_inner,
dim_out, [1, 1, 1], [1, 1, 1], [0, 0, 0],
cfg,
test_mode=test_mode,
bn_init=cfg.MODEL.bn_init_gamma)
return blob_out
def _add_shortcut_3d(blob_in,
prefix,
dim_in,
dim_out,
stride,
cfg,
temp_stride=1,
test_mode=False):
if ((dim_in == dim_out) and (temp_stride == 1) and (stride == 1)):
# identity mapping (do nothing)
return blob_in
else:
# when dim changes
conv_op = Conv3dAffine if cfg.MODEL.use_affine else Conv3dBN
blob_out = conv_op(
blob_in,
prefix,
dim_in,
dim_out, [1, 1, 1], [temp_stride, stride, stride], [0, 0, 0],
cfg,
test_mode=test_mode)
return blob_out
# residual block abstraction
def _generic_residual_block_3d(blob_in,
dim_in,
dim_out,
stride,
prefix,
dim_inner,
cfg,
group=1,
use_temp_conv=0,
temp_stride=1,
trans_func=None,
test_mode=False):
# transformation branch (e.g. 1x1-3x3-1x1, or 3x3-3x3), namely "F(x)"
if trans_func is None:
trans_func = globals()[cfg.RESNETS.trans_func]
tr_blob = trans_func(
blob_in,
dim_in,
dim_out,
stride,
prefix,
dim_inner,
cfg,
group=group,
use_temp_conv=use_temp_conv,
temp_stride=temp_stride,
test_mode=test_mode)
# create short cut, namely, "x"
sc_blob = _add_shortcut_3d(
blob_in,
prefix + "_branch1",
dim_in,
dim_out,
stride,
cfg,
temp_stride=temp_stride,
test_mode=test_mode)
# addition, namely, "x + F(x)", and relu
sum_blob = fluid.layers.elementwise_add(
tr_blob, sc_blob, act='relu', name=prefix + '_sum')
return sum_blob
def res_stage_nonlocal(block_fn,
blob_in,
dim_in,
dim_out,
stride,
num_blocks,
prefix,
cfg,
dim_inner=None,
group=None,
use_temp_convs=None,
temp_strides=None,
batch_size=None,
nonlocal_name=None,
nonlocal_mod=1000,
test_mode=False):
# prefix is something like: res2, res3, etc.
# each res layer has num_blocks stacked.
# check dtype and format of use_temp_convs and temp_strides
if use_temp_convs is None:
use_temp_convs = np.zeros(num_blocks).astype(int)
if temp_strides is None:
temp_strides = np.ones(num_blocks).astype(int)
if len(use_temp_convs) < num_blocks:
for _ in range(num_blocks - len(use_temp_convs)):
use_temp_convs.append(0)
temp_strides.append(1)
for idx in range(num_blocks):
block_prefix = '{}{}'.format(prefix, chr(idx + 97))
block_stride = 2 if ((idx == 0) and (stride == 2)) else 1
blob_in = _generic_residual_block_3d(
blob_in,
dim_in,
dim_out,
block_stride,
block_prefix,
dim_inner,
cfg,
group=group,
use_temp_conv=use_temp_convs[idx],
temp_stride=temp_strides[idx],
test_mode=test_mode)
dim_in = dim_out
if idx % nonlocal_mod == nonlocal_mod - 1:
blob_in = nonlocal_helper.add_nonlocal(
blob_in,
dim_in,
dim_in,
batch_size,
nonlocal_name + '_{}'.format(idx),
int(dim_in / 2),
cfg,
test_mode=test_mode)
return blob_in, dim_in
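
# Placement example (not part of this commit): with the res4 settings from
# configs/nonlocal.txt for ResNet50 (num_blocks=6, nonlocal_mod=2), the
# condition idx % nonlocal_mod == nonlocal_mod - 1 holds for idx = 1, 3, 5,
# so a non-local block follows res4b, res4d and res4f; nonlocal_mod=1000
# effectively disables insertion.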
def res_stage_nonlocal_group(block_fn,
blob_in,
dim_in,
dim_out,
stride,
num_blocks,
prefix,
cfg,
dim_inner=None,
group=None,
use_temp_convs=None,
temp_strides=None,
batch_size=None,
pool_stride=None,
spatial_dim=None,
group_size=None,
nonlocal_name=None,
nonlocal_mod=1000,
test_mode=False):
# prefix is something like res2, res3, etc.
# each res layer has num_blocks stacked
# check dtype and format of use_temp_convs and temp_strides
if use_temp_convs is None:
use_temp_convs = np.zeros(num_blocks).astype(int)
if temp_strides is None:
temp_strides = np.ones(num_blocks).astype(int)
for idx in range(num_blocks):
block_prefix = "{}{}".format(prefix, chr(idx + 97))
block_stride = 2 if (idx == 0 and stride == 2) else 1
blob_in = _generic_residual_block_3d(
blob_in,
dim_in,
dim_out,
block_stride,
block_prefix,
dim_inner,
cfg,
group=group,
use_temp_conv=use_temp_convs[idx],
temp_stride=temp_strides[idx],
test_mode=test_mode)
dim_in = dim_out
if idx % nonlocal_mod == nonlocal_mod - 1:
blob_in = nonlocal_helper.add_nonlocal_group(
blob_in,
dim_in,
dim_in,
batch_size,
pool_stride,
spatial_dim,
spatial_dim,
group_size,
nonlocal_name + "_{}".format(idx),
int(dim_in / 2),
cfg,
test_mode=test_mode)
return blob_in, dim_in
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import resnet_helper
import logging
logger = logging.getLogger(__name__)
# For more depths, add the block config here
BLOCK_CONFIG = {
50: (3, 4, 6, 3),
101: (3, 4, 23, 3),
}
# ------------------------------------------------------------------------
# obtain_arc defines the temporal kernel radius and temporal strides for
# each layer's residual blocks in a ResNet.
# e.g. use_temp_convs = 1 means a temporal kernel of 3 is used.
# In ResNet50, conv2 through conv5 have (3, 4, 6, 3) blocks,
# so the lengths of the corresponding lists are (3, 4, 6, 3).
# ------------------------------------------------------------------------
def obtain_arc(arc_type, video_length):
pool_stride = 1
# c2d, ResNet50
if arc_type == 1:
use_temp_convs_1 = [0]
temp_strides_1 = [1]
use_temp_convs_2 = [0, 0, 0]
temp_strides_2 = [1, 1, 1]
use_temp_convs_3 = [0, 0, 0, 0]
temp_strides_3 = [1, 1, 1, 1]
use_temp_convs_4 = [0, ] * 6
temp_strides_4 = [1, ] * 6
use_temp_convs_5 = [0, 0, 0]
temp_strides_5 = [1, 1, 1]
pool_stride = int(video_length / 2)
# i3d, ResNet50
if arc_type == 2:
use_temp_convs_1 = [2]
temp_strides_1 = [1]
use_temp_convs_2 = [1, 1, 1]
temp_strides_2 = [1, 1, 1]
use_temp_convs_3 = [1, 0, 1, 0]
temp_strides_3 = [1, 1, 1, 1]
use_temp_convs_4 = [1, 0, 1, 0, 1, 0]
temp_strides_4 = [1, 1, 1, 1, 1, 1]
use_temp_convs_5 = [0, 1, 0]
temp_strides_5 = [1, 1, 1]
pool_stride = int(video_length / 2)
# c2d, ResNet101
if arc_type == 3:
use_temp_convs_1 = [0]
temp_strides_1 = [1]
use_temp_convs_2 = [0, 0, 0]
temp_strides_2 = [1, 1, 1]
use_temp_convs_3 = [0, 0, 0, 0]
temp_strides_3 = [1, 1, 1, 1]
use_temp_convs_4 = [0, ] * 23
temp_strides_4 = [1, ] * 23
use_temp_convs_5 = [0, 0, 0]
temp_strides_5 = [1, 1, 1]
pool_stride = int(video_length / 2)
# i3d, ResNet101
if arc_type == 4:
use_temp_convs_1 = [2]
temp_strides_1 = [1]
use_temp_convs_2 = [1, 1, 1]
temp_strides_2 = [1, 1, 1]
use_temp_convs_3 = [1, 0, 1, 0]
temp_strides_3 = [1, 1, 1, 1]
use_temp_convs_4 = []
for i in range(23):
if i % 2 == 0:
use_temp_convs_4.append(1)
else:
use_temp_convs_4.append(0)
temp_strides_4 = [1] * 23
use_temp_convs_5 = [0, 1, 0]
temp_strides_5 = [1, 1, 1]
pool_stride = int(video_length / 2)
use_temp_convs_set = [
use_temp_convs_1, use_temp_convs_2, use_temp_convs_3, use_temp_convs_4,
use_temp_convs_5
]
temp_strides_set = [
temp_strides_1, temp_strides_2, temp_strides_3, temp_strides_4,
temp_strides_5
]
return use_temp_convs_set, temp_strides_set, pool_stride
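
# Example (not part of this commit): for the C2D ResNet50 setting in
# configs/nonlocal.txt (video_arc_choice=1, video_length=8), obtain_arc(1, 8)
# returns all-zero use_temp_convs (purely spatial kernels), all-one
# temp_strides, and pool_stride = 8 / 2 = 4, which matches the
# [pool_stride, 7, 7] average pool at the end of create_model below.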
def create_model(data, label, cfg, is_training=True, mode='train'):
group = cfg.RESNETS.num_groups
width_per_group = cfg.RESNETS.width_per_group
batch_size = int(cfg.TRAIN.batch_size / cfg.NUM_GPUS)
logger.info('--------------- ResNet-{} {}x{}d-{}, {} ---------------'.
format(cfg.MODEL.depth, group, width_per_group,
cfg.RESNETS.trans_func, cfg.MODEL.dataset))
assert cfg.MODEL.depth in BLOCK_CONFIG.keys(), \
"Block config is not defined for specified model depth."
(n1, n2, n3, n4) = BLOCK_CONFIG[cfg.MODEL.depth]
res_block = resnet_helper._generic_residual_block_3d
dim_inner = group * width_per_group
use_temp_convs_set, temp_strides_set, pool_stride = obtain_arc(
cfg.MODEL.video_arc_choice, cfg[mode.upper()]['video_length'])
logger.info(use_temp_convs_set)
logger.info(temp_strides_set)
conv_blob = fluid.layers.conv3d(
input=data,
num_filters=64,
filter_size=[1 + use_temp_convs_set[0][0] * 2, 7, 7],
stride=[temp_strides_set[0][0], 2, 2],
padding=[use_temp_convs_set[0][0], 3, 3],
param_attr=ParamAttr(
name='conv1' + "_weights", initializer=fluid.initializer.MSRA()),
bias_attr=False,
name='conv1')
    test_mode = mode != 'train'
if cfg.MODEL.use_affine is False:
# use bn
bn_name = 'bn_conv1'
bn_blob = fluid.layers.batch_norm(
conv_blob,
is_test=test_mode,
momentum=cfg.MODEL.bn_momentum,
epsilon=cfg.MODEL.bn_epsilon,
name=bn_name,
param_attr=ParamAttr(
name=bn_name + "_scale",
regularizer=fluid.regularizer.L2Decay(
cfg.TRAIN.weight_decay_bn)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
regularizer=fluid.regularizer.L2Decay(
cfg.TRAIN.weight_decay_bn)),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance")
else:
# use affine
affine_name = 'bn_conv1'
conv_blob_shape = conv_blob.shape
affine_scale = fluid.layers.create_parameter(
shape=[conv_blob_shape[1]],
dtype=conv_blob.dtype,
attr=ParamAttr(name=affine_name + '_scale'),
default_initializer=fluid.initializer.Constant(value=1.))
affine_bias = fluid.layers.create_parameter(
shape=[conv_blob_shape[1]],
dtype=conv_blob.dtype,
attr=ParamAttr(name=affine_name + '_offset'),
default_initializer=fluid.initializer.Constant(value=0.))
bn_blob = fluid.layers.affine_channel(
conv_blob, scale=affine_scale, bias=affine_bias, name=affine_name)
# relu
relu_blob = fluid.layers.relu(bn_blob, name='res_conv1_bn_relu')
# max pool
max_pool = fluid.layers.pool3d(
input=relu_blob,
pool_size=[1, 3, 3],
pool_type='max',
pool_stride=[1, 2, 2],
pool_padding=[0, 0, 0],
name='pool1')
# building res block
if cfg.MODEL.depth in [50, 101]:
blob_in, dim_in = resnet_helper.res_stage_nonlocal(
res_block,
max_pool,
64,
256,
stride=1,
num_blocks=n1,
prefix='res2',
cfg=cfg,
dim_inner=dim_inner,
group=group,
use_temp_convs=use_temp_convs_set[1],
temp_strides=temp_strides_set[1],
test_mode=test_mode)
layer_mod = cfg.NONLOCAL.layer_mod
if cfg.MODEL.depth == 101:
layer_mod = 2
if cfg.NONLOCAL.conv3_nonlocal is False:
layer_mod = 1000
blob_in = fluid.layers.pool3d(
blob_in,
pool_size=[2, 1, 1],
pool_type='max',
pool_stride=[2, 1, 1],
pool_padding=[0, 0, 0],
name='pool2')
if cfg.MODEL.use_affine is False:
blob_in, dim_in = resnet_helper.res_stage_nonlocal(
res_block,
blob_in,
dim_in,
512,
stride=2,
num_blocks=n2,
prefix='res3',
cfg=cfg,
dim_inner=dim_inner * 2,
group=group,
use_temp_convs=use_temp_convs_set[2],
temp_strides=temp_strides_set[2],
batch_size=batch_size,
nonlocal_name="nonlocal_conv3",
nonlocal_mod=layer_mod,
test_mode=test_mode)
else:
crop_size = cfg[mode.upper()]['crop_size']
blob_in, dim_in = resnet_helper.res_stage_nonlocal_group(
res_block,
blob_in,
dim_in,
512,
stride=2,
num_blocks=n2,
prefix='res3',
cfg=cfg,
dim_inner=dim_inner * 2,
group=group,
use_temp_convs=use_temp_convs_set[2],
temp_strides=temp_strides_set[2],
batch_size=batch_size,
pool_stride=pool_stride,
spatial_dim=int(crop_size / 8),
group_size=4,
nonlocal_name="nonlocal_conv3_group",
nonlocal_mod=layer_mod,
test_mode=test_mode)
layer_mod = cfg.NONLOCAL.layer_mod
if cfg.MODEL.depth == 101:
layer_mod = layer_mod * 4 - 1
if cfg.NONLOCAL.conv4_nonlocal is False:
layer_mod = 1000
blob_in, dim_in = resnet_helper.res_stage_nonlocal(
res_block,
blob_in,
dim_in,
1024,
stride=2,
num_blocks=n3,
prefix='res4',
cfg=cfg,
dim_inner=dim_inner * 4,
group=group,
use_temp_convs=use_temp_convs_set[3],
temp_strides=temp_strides_set[3],
batch_size=batch_size,
nonlocal_name="nonlocal_conv4",
nonlocal_mod=layer_mod,
test_mode=test_mode)
blob_in, dim_in = resnet_helper.res_stage_nonlocal(
res_block,
blob_in,
dim_in,
2048,
stride=2,
num_blocks=n4,
prefix='res5',
cfg=cfg,
dim_inner=dim_inner * 8,
group=group,
use_temp_convs=use_temp_convs_set[4],
temp_strides=temp_strides_set[4],
test_mode=test_mode)
else:
raise Exception("Unsupported network settings.")
blob_out = fluid.layers.pool3d(
blob_in,
pool_size=[pool_stride, 7, 7],
pool_type='avg',
pool_stride=[1, 1, 1],
pool_padding=[0, 0, 0],
name='pool5')
if (cfg.TRAIN.dropout_rate > 0) and (test_mode is False):
blob_out = fluid.layers.dropout(
blob_out, cfg.TRAIN.dropout_rate, is_test=test_mode)
if mode in ['train', 'valid']:
blob_out = fluid.layers.fc(
blob_out,
cfg.MODEL.num_classes,
param_attr=ParamAttr(
name='pred' + "_w",
initializer=fluid.initializer.Normal(
loc=0.0, scale=cfg.MODEL.fc_init_std)),
bias_attr=ParamAttr(
name='pred' + "_b",
initializer=fluid.initializer.Constant(value=0.)),
name='pred')
elif mode in ['test', 'infer']:
blob_out = fluid.layers.conv3d(
input=blob_out,
num_filters=cfg.MODEL.num_classes,
filter_size=[1, 1, 1],
stride=[1, 1, 1],
padding=[0, 0, 0],
param_attr=ParamAttr(
name='pred' + "_w", initializer=fluid.initializer.MSRA()),
bias_attr=ParamAttr(
name='pred' + "_b",
initializer=fluid.initializer.Constant(value=0.)),
name='pred')
if (mode == 'train') or (mode == 'valid'):
softmax = fluid.layers.softmax(blob_out)
loss = fluid.layers.cross_entropy(
softmax, label, soft_label=False, ignore_index=-100)
elif (mode == 'test') or (mode == 'infer'):
        # fully convolutional testing: when loading the test model, params
        # are copied from the fc layer named 'pred' in the train program
blob_out = fluid.layers.transpose(
blob_out, [0, 2, 3, 4, 1], name='pred_tr')
blob_out = fluid.layers.softmax(blob_out, name='softmax_conv')
softmax = fluid.layers.reduce_mean(
blob_out, dim=[1, 2, 3], keep_dim=False, name='softmax')
loss = None
    else:
        raise NotImplementedError("Unsupported mode: {}".format(mode))
return softmax, loss
File mode changed from 100755 to 100644
python infer.py --model_name="AttentionCluster" --config=./configs/attention_cluster.txt \
    --filelist=./data/youtube8m/infer.list \
    --weights=./checkpoints/AttentionCluster_epoch0 \
    --save_dir="./save"

python infer.py --model_name="AttentionLSTM" --config=./configs/attention_lstm.txt \
    --filelist=./data/youtube8m/infer.list \
    --weights=./checkpoints/AttentionLSTM_epoch0 \
    --save_dir="./save"

python infer.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt --filelist=./data/youtube8m/infer.list \
    --weights=./checkpoints/NEXTVLAD_epoch0 \
    --save_dir="./save"

python infer.py --model_name="NONLOCAL" --config=./configs/nonlocal.txt --filelist=./dataset/nonlocal/infer.list \
    --log_interval=10 --weights=./checkpoints/NONLOCAL_epoch0 --save_dir=./save

python infer.py --model_name="STNET" --config=./configs/stnet.txt --filelist=./data/kinetics/infer.list \
    --log_interval=10 --weights=./checkpoints/STNET_epoch0 --save_dir=./save

python infer.py --model_name="TSN" --config=./configs/tsn.txt --filelist=./data/kinetics/infer.list \
    --log_interval=10 --weights=./checkpoints/TSN_epoch0 --save_dir=./save

python test.py --model_name="AttentionCluster" --config=./configs/attention_cluster.txt \
    --log_interval=5 --weights=./checkpoints/AttentionCluster_epoch0

python test.py --model_name="AttentionLSTM" --config=./configs/attention_lstm.txt \
    --log_interval=5 --weights=./checkpoints/AttentionLSTM_epoch0

python test.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt \
    --log_interval=10 --weights=./checkpoints/NEXTVLAD_epoch0

python test.py --model_name="NONLOCAL" --config=./configs/nonlocal.txt \
    --log_interval=1 --weights=./checkpoints/NONLOCAL_epoch0

python test.py --model_name="STNET" --config=./configs/stnet.txt \
    --log_interval=10 --weights=./checkpoints/STNET_epoch0

python test.py --model_name="TSN" --config=./configs/tsn.txt \
    --log_interval=10 --weights=./checkpoints/TSN_epoch0

python train.py --model_name="AttentionCluster" --config=./configs/attention_cluster.txt --epoch_num=5 \
    --valid_interval=1 --log_interval=10

python train.py --model_name="AttentionLSTM" --config=./configs/attention_lstm.txt --epoch_num=10 \
    --valid_interval=1 --log_interval=10

export CUDA_VISIBLE_DEVICES=0,1,2,3
python train.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt --epoch_num=6 \
    --valid_interval=1 --log_interval=10

python train.py --model_name="NONLOCAL" --config=./configs/nonlocal.txt --epoch_num=120 \
    --valid_interval=1 --log_interval=1 \
    --pretrain=./pretrained/ResNet50_pretrained

python train.py --model_name="STNET" --config=./configs/stnet.txt --epoch_num=60 \
    --valid_interval=1 --log_interval=10

python train.py --model_name="TSN" --config=./configs/tsn.txt --epoch_num=45 \
    --valid_interval=1 --log_interval=10
...@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name',
        type=str,
        default='AttentionCluster',
        help='name of model to train.')
...@@ -44,19 +44,19 @@ def parse_args():
        default='configs/attention_cluster.txt',
        help='path to config file of model')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='training batch size per GPU. None to use config file setting.')
    parser.add_argument(
        '--use_gpu', type=bool, default=True, help='default use gpu.')
    parser.add_argument(
        '--weights',
        type=str,
        default=None,
        help='weight path, None to use weights from Paddle.')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=1,
        help='mini-batch interval to log.')
...@@ -75,7 +75,7 @@ def test(args): ...@@ -75,7 +75,7 @@ def test(args):
test_model.build_model() test_model.build_model()
test_feeds = test_model.feeds() test_feeds = test_model.feeds()
test_outputs = test_model.outputs() test_outputs = test_model.outputs()
loss = test_model.loss() test_loss = test_model.loss()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -85,29 +85,34 @@ def test(args): ...@@ -85,29 +85,34 @@ def test(args):
args.weights), "Given weight dir {} not exist.".format(args.weights) args.weights), "Given weight dir {} not exist.".format(args.weights)
weights = args.weights or test_model.get_weights() weights = args.weights or test_model.get_weights()
def if_exist(var): test_model.load_test_weights(exe, weights,
return os.path.exists(os.path.join(weights, var.name)) fluid.default_main_program(), place)
fluid.io.load_vars(exe, weights, predicate=if_exist)
# get reader and metrics # get reader and metrics
test_reader = get_reader(args.model_name.upper(), 'test', test_config) test_reader = get_reader(args.model_name.upper(), 'test', test_config)
test_metrics = get_metrics(args.model_name.upper(), 'test', test_config) test_metrics = get_metrics(args.model_name.upper(), 'test', test_config)
test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds) test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds)
fetch_list = [loss.name] + [x.name if test_loss is None:
for x in test_outputs] + [test_feeds[-1].name] fetch_list = [x.name for x in test_outputs] + [test_feeds[-1].name]
else:
fetch_list = [test_loss.name] + [x.name for x in test_outputs
] + [test_feeds[-1].name]
epoch_period = [] epoch_period = []
for test_iter, data in enumerate(test_reader()): for test_iter, data in enumerate(test_reader()):
cur_time = time.time() cur_time = time.time()
test_outs = exe.run(fetch_list=fetch_list, test_outs = exe.run(fetch_list=fetch_list, feed=test_feeder.feed(data))
feed=test_feeder.feed(data))
period = time.time() - cur_time period = time.time() - cur_time
epoch_period.append(period) epoch_period.append(period)
loss = np.array(test_outs[0]) if test_loss is None:
pred = np.array(test_outs[1]) loss = np.zeros(1, ).astype('float32')
label = np.array(test_outs[-1]) pred = np.array(test_outs[0])
label = np.array(test_outs[-1])
else:
loss = np.array(test_outs[0])
pred = np.array(test_outs[1])
label = np.array(test_outs[-1])
test_metrics.accumulate(loss, pred, label) test_metrics.accumulate(loss, pred, label)
# metric here # metric here
......
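The hunk above lets test.py handle models whose loss() returns None in test mode, both when building the fetch list and when unpacking the results. A minimal standalone sketch of that pattern, with illustrative names (build_fetch_list is not a repo function):

    # Sketch of the optional-loss fetch pattern; names are illustrative.
    def build_fetch_list(loss_name, output_names, label_name):
        # With no loss in the graph, fetch only outputs and labels;
        # otherwise the loss comes first, matching the unpacking order.
        if loss_name is None:
            return output_names + [label_name]
        return [loss_name] + output_names + [label_name]

    print(build_fetch_list(None, ['pred'], 'label'))    # ['pred', 'label']
    print(build_fetch_list('loss', ['pred'], 'label'))  # ['loss', 'pred', 'label']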
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
 def parse_args():
     parser = argparse.ArgumentParser("Paddle Video train script")
     parser.add_argument(
-        '--model-name',
+        '--model_name',
         type=str,
         default='AttentionCluster',
         help='name of model to train.')
@@ -45,12 +45,12 @@ def parse_args():
         default='configs/attention_cluster.txt',
         help='path to config file of model')
     parser.add_argument(
-        '--batch-size',
+        '--batch_size',
         type=int,
         default=None,
         help='training batch size. None to use config file setting.')
     parser.add_argument(
-        '--learning-rate',
+        '--learning_rate',
         type=float,
         default=None,
         help='learning rate used for training. None to use config file setting.')
...@@ -65,37 +65,36 @@ def parse_args(): ...@@ -65,37 +65,36 @@ def parse_args():
type=str, type=str,
default=None, default=None,
help='path to resume training based on previous checkpoints. ' help='path to resume training based on previous checkpoints. '
'None for not resuming any checkpoints.' 'None for not resuming any checkpoints.')
)
parser.add_argument( parser.add_argument(
'--use-gpu', type=bool, default=True, help='default use gpu.') '--use_gpu', type=bool, default=True, help='default use gpu.')
parser.add_argument( parser.add_argument(
'--no-use-pyreader', '--no_use_pyreader',
action='store_true', action='store_true',
default=False, default=False,
help='whether to use pyreader') help='whether to use pyreader')
parser.add_argument( parser.add_argument(
'--no-memory-optimize', '--no_memory_optimize',
action='store_true', action='store_true',
default=False, default=False,
help='whether to use memory optimize in train') help='whether to use memory optimize in train')
parser.add_argument( parser.add_argument(
'--epoch-num', '--epoch_num',
type=int, type=int,
default=0, default=0,
help='epoch number, 0 for read from config file') help='epoch number, 0 for read from config file')
parser.add_argument( parser.add_argument(
'--valid-interval', '--valid_interval',
type=int, type=int,
default=1, default=1,
help='validation epoch interval, 0 for no validation.') help='validation epoch interval, 0 for no validation.')
parser.add_argument( parser.add_argument(
'--save-dir', '--save_dir',
type=str, type=str,
default='checkpoints', default='checkpoints',
help='directory name to save train snapshoot') help='directory name to save train snapshoot')
parser.add_argument( parser.add_argument(
'--log-interval', '--log_interval',
type=int, type=int,
default=10, default=10,
help='mini-batch interval to log.') help='mini-batch interval to log.')
@@ -108,6 +107,8 @@ def train(args):
     config = parse_config(args.config)
     train_config = merge_configs(config, 'train', vars(args))
     valid_config = merge_configs(config, 'valid', vars(args))
+    logger.info("############### train config ###############")
+    print_configs(train_config)
     train_model = models.get_model(args.model_name, train_config, mode='train')
     valid_model = models.get_model(args.model_name, valid_config, mode='valid')
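The flag rename pays off here: vars(args) yields underscore keys such as batch_size that line up with the lowercase keys in each config section, so CLI values can override the file. A simplified illustration of that override step (not the repo's merge_configs implementation):

    # Simplified sketch of CLI-over-config overriding; not the repo's code.
    config = {'batch_size': 64, 'learning_rate': 0.01}
    cli_args = {'batch_size': 32, 'learning_rate': None}  # None = keep config value

    for key, value in cli_args.items():
        if value is not None and key in config:
            config[key] = value

    print(config)  # {'batch_size': 32, 'learning_rate': 0.01}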
@@ -153,9 +154,12 @@ def train(args):
         # if resume weights is given, load resume weights directly
         assert os.path.exists(args.resume), \
                 "Given resume weight dir {} not exist.".format(args.resume)
         def if_exist(var):
             return os.path.exists(os.path.join(args.resume, var.name))
-        fluid.io.load_vars(exe, args.resume, predicate=if_exist, main_program=train_prog)
+        fluid.io.load_vars(
+            exe, args.resume, predicate=if_exist, main_program=train_prog)
     else:
         # if not in resume mode, load pretrain weights
         if args.pretrain:
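fluid.io.load_vars with a predicate, as used above, restores only the variables that have a matching file in the checkpoint directory. A small sketch of that predicate logic in isolation (DummyVar is a stand-in for a Paddle variable, not a repo type):

    import os
    from collections import namedtuple

    # DummyVar mimics the .name attribute the predicate inspects.
    DummyVar = namedtuple('DummyVar', ['name'])

    def make_if_exist(ckpt_dir):
        def if_exist(var):
            # Load a variable only if the checkpoint dir holds a file named after it.
            return os.path.exists(os.path.join(ckpt_dir, var.name))
        return if_exist

    if_exist = make_if_exist('./checkpoints')
    print(if_exist(DummyVar(name='conv1_weights')))  # False unless the file exists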
...@@ -199,21 +203,43 @@ def train(args): ...@@ -199,21 +203,43 @@ def train(args):
if args.no_use_pyreader: if args.no_use_pyreader:
train_feeder = fluid.DataFeeder(place=place, feed_list=train_feeds) train_feeder = fluid.DataFeeder(place=place, feed_list=train_feeds)
valid_feeder = fluid.DataFeeder(place=place, feed_list=valid_feeds) valid_feeder = fluid.DataFeeder(place=place, feed_list=valid_feeds)
train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feeder, train_without_pyreader(
train_fetch_list, train_metrics, epochs = epochs, exe,
log_interval = args.log_interval, valid_interval = args.valid_interval, train_prog,
save_dir = args.save_dir, save_model_name = args.model_name, train_exe,
test_exe = valid_exe, test_reader = valid_reader, test_feeder = valid_feeder, train_reader,
test_fetch_list = valid_fetch_list, test_metrics = valid_metrics) train_feeder,
train_fetch_list,
train_metrics,
epochs=epochs,
log_interval=args.log_interval,
valid_interval=args.valid_interval,
save_dir=args.save_dir,
save_model_name=args.model_name,
test_exe=valid_exe,
test_reader=valid_reader,
test_feeder=valid_feeder,
test_fetch_list=valid_fetch_list,
test_metrics=valid_metrics)
else: else:
train_pyreader.decorate_paddle_reader(train_reader) train_pyreader.decorate_paddle_reader(train_reader)
valid_pyreader.decorate_paddle_reader(valid_reader) valid_pyreader.decorate_paddle_reader(valid_reader)
train_with_pyreader(exe, train_prog, train_exe, train_pyreader, train_fetch_list, train_metrics, train_with_pyreader(
epochs = epochs, log_interval = args.log_interval, exe,
valid_interval = args.valid_interval, train_prog,
save_dir = args.save_dir, save_model_name = args.model_name, train_exe,
test_exe = valid_exe, test_pyreader = valid_pyreader, train_pyreader,
test_fetch_list = valid_fetch_list, test_metrics = valid_metrics) train_fetch_list,
train_metrics,
epochs=epochs,
log_interval=args.log_interval,
valid_interval=args.valid_interval,
save_dir=args.save_dir,
save_model_name=args.model_name,
test_exe=valid_exe,
test_pyreader=valid_pyreader,
test_fetch_list=valid_fetch_list,
test_metrics=valid_metrics)
if __name__ == "__main__": if __name__ == "__main__":
......
@@ -14,6 +14,7 @@
 __all__ = ['AttrDict']
 class AttrDict(dict):
     def __getattr__(self, key):
         return self[key]
...
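For reference, AttrDict simply exposes dict keys as attributes, which is what allows config sections to be read with dot access throughout the training scripts. A minimal usage sketch (note that a missing key raises KeyError rather than AttributeError):

    # Minimal usage sketch of the AttrDict pattern shown above.
    class AttrDict(dict):
        def __getattr__(self, key):
            return self[key]

    cfg = AttrDict(batch_size=64, crop_size=224)
    print(cfg.batch_size)    # 64 -- attribute-style access
    print(cfg['crop_size'])  # 224 -- plain dict access still works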