From 3b779eddc2581920726cdf99fbecfb269b51514d Mon Sep 17 00:00:00 2001 From: SunGaofeng Date: Thu, 4 Apr 2019 10:11:47 +0000 Subject: [PATCH] add nonlocal model --- PaddleCV/video/config.py | 19 +- PaddleCV/video/configs/attention_cluster.txt | 0 PaddleCV/video/configs/attention_lstm.txt | 0 PaddleCV/video/configs/nextvlad.txt | 0 PaddleCV/video/configs/nonlocal.txt | 92 +++++ PaddleCV/video/configs/stnet.txt | 0 PaddleCV/video/configs/tsn.txt | 0 PaddleCV/video/datareader/nonlocal_reader.py | 54 +-- PaddleCV/video/infer.py | 16 +- .../multicrop_test/multicrop_test_metrics.py | 25 +- PaddleCV/video/models/__init__.py | 2 + PaddleCV/video/models/model.py | 11 +- .../video/models/nonlocal_model/__init__.py | 1 + .../models/nonlocal_model/nonlocal_helper.py | 254 ++++++++++++ .../models/nonlocal_model/nonlocal_model.py | 299 ++++++++++++++ .../models/nonlocal_model/resnet_helper.py | 356 +++++++++++++++++ .../models/nonlocal_model/resnet_video.py | 371 ++++++++++++++++++ PaddleCV/video/models/utils.py | 0 .../scripts/infer/infer_attention_cluster.sh | 4 +- .../scripts/infer/infer_attention_lstm.sh | 4 +- .../video/scripts/infer/infer_nextvlad.sh | 4 +- .../video/scripts/infer/infer_nonlocal.sh | 2 + PaddleCV/video/scripts/infer/infer_stnet.sh | 4 +- PaddleCV/video/scripts/infer/infer_tsn.sh | 4 +- .../scripts/test/test_attention_cluster.sh | 4 +- .../video/scripts/test/test_attention_lstm.sh | 4 +- PaddleCV/video/scripts/test/test_nextvlad.sh | 4 +- PaddleCV/video/scripts/test/test_nonlocal.sh | 2 + PaddleCV/video/scripts/test/test_stnet.sh | 4 +- PaddleCV/video/scripts/test/test_tsn.sh | 4 +- .../scripts/train/train_attention_cluster.sh | 4 +- .../scripts/train/train_attention_lstm.sh | 4 +- .../video/scripts/train/train_nextvlad.sh | 4 +- .../video/scripts/train/train_nonlocal.sh | 3 + PaddleCV/video/scripts/train/train_stnet.sh | 4 +- PaddleCV/video/scripts/train/train_tsn.sh | 4 +- PaddleCV/video/test.py | 37 +- PaddleCV/video/train.py | 76 ++-- PaddleCV/video/utils.py | 1 + 39 files changed, 1545 insertions(+), 136 deletions(-) mode change 100755 => 100644 PaddleCV/video/config.py mode change 100755 => 100644 PaddleCV/video/configs/attention_cluster.txt mode change 100755 => 100644 PaddleCV/video/configs/attention_lstm.txt mode change 100755 => 100644 PaddleCV/video/configs/nextvlad.txt create mode 100644 PaddleCV/video/configs/nonlocal.txt mode change 100755 => 100644 PaddleCV/video/configs/stnet.txt mode change 100755 => 100644 PaddleCV/video/configs/tsn.txt mode change 100755 => 100644 PaddleCV/video/infer.py mode change 100755 => 100644 PaddleCV/video/models/model.py create mode 100644 PaddleCV/video/models/nonlocal_model/__init__.py create mode 100644 PaddleCV/video/models/nonlocal_model/nonlocal_helper.py create mode 100644 PaddleCV/video/models/nonlocal_model/nonlocal_model.py create mode 100644 PaddleCV/video/models/nonlocal_model/resnet_helper.py create mode 100644 PaddleCV/video/models/nonlocal_model/resnet_video.py mode change 100755 => 100644 PaddleCV/video/models/utils.py create mode 100644 PaddleCV/video/scripts/infer/infer_nonlocal.sh create mode 100644 PaddleCV/video/scripts/test/test_nonlocal.sh create mode 100644 PaddleCV/video/scripts/train/train_nonlocal.sh mode change 100755 => 100644 PaddleCV/video/test.py mode change 100755 => 100644 PaddleCV/video/train.py mode change 100755 => 100644 PaddleCV/video/utils.py diff --git a/PaddleCV/video/config.py b/PaddleCV/video/config.py old mode 100755 new mode 100644 index a534536c..bf8d55b9 --- a/PaddleCV/video/config.py +++ b/PaddleCV/video/config.py @@ -19,12 +19,15 @@ except: from utils import AttrDict +import logging +logger = logging.getLogger(__name__) + CONFIG_SECS = [ - 'train', - 'valid', - 'test', - 'infer', - ] + 'train', + 'valid', + 'test', + 'infer', +] def parse_config(cfg_file): @@ -43,6 +46,7 @@ def parse_config(cfg_file): return cfg + def merge_configs(cfg, sec, args_dict): assert sec in CONFIG_SECS, "invalid config section {}".format(sec) sec_dict = getattr(cfg, sec.upper()) @@ -56,3 +60,8 @@ def merge_configs(cfg, sec, args_dict): pass return cfg + +def print_configs(cfg): + import pprint + logger.info('Training with config:') + logger.info(pprint.pformat(cfg)) diff --git a/PaddleCV/video/configs/attention_cluster.txt b/PaddleCV/video/configs/attention_cluster.txt old mode 100755 new mode 100644 diff --git a/PaddleCV/video/configs/attention_lstm.txt b/PaddleCV/video/configs/attention_lstm.txt old mode 100755 new mode 100644 diff --git a/PaddleCV/video/configs/nextvlad.txt b/PaddleCV/video/configs/nextvlad.txt old mode 100755 new mode 100644 diff --git a/PaddleCV/video/configs/nonlocal.txt b/PaddleCV/video/configs/nonlocal.txt new file mode 100644 index 00000000..d2368682 --- /dev/null +++ b/PaddleCV/video/configs/nonlocal.txt @@ -0,0 +1,92 @@ +[MODEL] +name = "NONLOCAL" +num_classes = 400 +image_mean = 114.75 +image_std = 57.375 +depth = 50 +dataset = 'kinetics400' +video_arc_choice = 1 +use_affine = False +fc_init_std = 0.01 +bn_momentum = 0.9 +bn_epsilon = 1.0e-5 +bn_init_gamma = 0. + +[RESNETS] +num_groups = 1 +width_per_group = 64 +trans_func = bottleneck_transformation_3d + +[NONLOCAL] +bn_momentum = 0.9 +bn_epsilon = 1.0e-5 +bn_init_gamma = 0.0 +layer_mod = 2 +conv3_nonlocal = True +conv4_nonlocal = True +conv_init_std = 0.01 +no_bias = 0 +use_maxpool = True +use_softmax = True +use_scale = True +use_zero_init_conv = False +use_bn = True +use_affine = False + +[TRAIN] +num_reader_threads = 8 +batch_size = 64 +num_gpus = 8 +filelist = './dataset/nonlocal/trainlist.txt' +crop_size = 224 +sample_rate = 8 +video_length = 8 +jitter_scales = [256, 320] + +dropout_rate = 0.5 + +learning_rate = 0.01 +learning_rate_decay = 0.1 +step_sizes = [150000, 150000, 100000] +max_iter = 400000 + +weight_decay = 0.0001 +weight_decay_bn = 0.0 +momentum = 0.9 +nesterov = True +scale_momentum = True + +[VALID] +num_reader_threads = 8 +batch_size = 64 +filelist = './dataset/nonlocal/vallist.txt' +crop_size = 224 +sample_rate = 8 +video_length = 8 +jitter_scales = [256, 320] + +[TEST] +num_reader_threads = 8 +batch_size = 4 +filelist = 'dataset/nonlocal/testlist.txt' +filename_gt = 'dataset/nonlocal/vallist.txt' +checkpoint_dir = './output' +crop_size = 256 +sample_rate = 8 +video_length = 8 +jitter_scales = [256, 256] +num_test_clips = 30 +dataset_size = 19761 +use_multi_crop = 1 + +[INFER] +num_reader_threads = 8 +batch_size = 1 +filelist = 'dataset/nonlocal/inferencelist.txt' +crop_size = 256 +sample_rate = 8 +video_length = 8 +jitter_scales = [256, 256] +num_test_clips = 30 +use_multi_crop = 1 + diff --git a/PaddleCV/video/configs/stnet.txt b/PaddleCV/video/configs/stnet.txt old mode 100755 new mode 100644 diff --git a/PaddleCV/video/configs/tsn.txt b/PaddleCV/video/configs/tsn.txt old mode 100755 new mode 100644 diff --git a/PaddleCV/video/datareader/nonlocal_reader.py b/PaddleCV/video/datareader/nonlocal_reader.py index 8edb5ac2..7df266ea 100644 --- a/PaddleCV/video/datareader/nonlocal_reader.py +++ b/PaddleCV/video/datareader/nonlocal_reader.py @@ -34,7 +34,7 @@ class NonlocalReader(DataReader): image_mean image_std batch_size - list + filelist crop_size sample_rate video_length @@ -68,7 +68,7 @@ class NonlocalReader(DataReader): dataset_args['min_size'] = cfg[mode.upper()]['jitter_scales'][0] dataset_args['max_size'] = cfg[mode.upper()]['jitter_scales'][1] dataset_args['num_reader_threads'] = num_reader_threads - filelist = cfg[mode.upper()]['list'] + filelist = cfg[mode.upper()]['filelist'] batch_size = cfg[mode.upper()]['batch_size'] if self.mode == 'train': @@ -146,8 +146,8 @@ def apply_resize(rgbdata, min_size, max_size): ratio = float(side_length) / float(width) else: ratio = float(side_length) / float(height) - out_height = int(height * ratio) - out_width = int(width * ratio) + out_height = int(round(height * ratio)) + out_width = int(round(width * ratio)) outdata = np.zeros( (length, out_height, out_width, channel), dtype=rgbdata.dtype) for i in range(length): @@ -197,14 +197,13 @@ def crop_mirror_transform(rgbdata, def make_reader(filelist, batch_size, sample_times, is_training, shuffle, **dataset_args): - # should add smaple_times param - fl = open(filelist).readlines() - fl = [line.strip() for line in fl if line.strip() != ''] + def reader(): + fl = open(filelist).readlines() + fl = [line.strip() for line in fl if line.strip() != ''] - if shuffle: - random.shuffle(fl) + if shuffle: + random.shuffle(fl) - def reader(): batch_out = [] for line in fl: # start_time = time.time() @@ -253,23 +252,6 @@ def make_reader(filelist, batch_size, sample_times, is_training, shuffle, def make_multi_reader(filelist, batch_size, sample_times, is_training, shuffle, **dataset_args): - fl = open(filelist).readlines() - fl = [line.strip() for line in fl if line.strip() != ''] - - if shuffle: - random.shuffle(fl) - - n = dataset_args['num_reader_threads'] - queue_size = 20 - reader_lists = [None] * n - file_num = int(len(fl) // n) - for i in range(n): - if i < len(reader_lists) - 1: - tmp_list = fl[i * file_num:(i + 1) * file_num] - else: - tmp_list = fl[i * file_num:] - reader_lists[i] = tmp_list - def read_into_queue(flq, queue): batch_out = [] for line in flq: @@ -315,6 +297,24 @@ def make_multi_reader(filelist, batch_size, sample_times, is_training, shuffle, queue.put(None) def queue_reader(): + # split file list and shuffle + fl = open(filelist).readlines() + fl = [line.strip() for line in fl if line.strip() != ''] + + if shuffle: + random.shuffle(fl) + + n = dataset_args['num_reader_threads'] + queue_size = 20 + reader_lists = [None] * n + file_num = int(len(fl) // n) + for i in range(n): + if i < len(reader_lists) - 1: + tmp_list = fl[i * file_num:(i + 1) * file_num] + else: + tmp_list = fl[i * file_num:] + reader_lists[i] = tmp_list + queue = multiprocessing.Queue(queue_size) p_list = [None] * len(reader_lists) # for reader_list in reader_lists: diff --git a/PaddleCV/video/infer.py b/PaddleCV/video/infer.py old mode 100755 new mode 100644 index 43470ced..f44e596a --- a/PaddleCV/video/infer.py +++ b/PaddleCV/video/infer.py @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( - '--model-name', + '--model_name', type=str, default='AttentionCluster', help='name of model to train.') @@ -47,14 +47,14 @@ def parse_args(): default='configs/attention_cluster.txt', help='path to config file of model') parser.add_argument( - '--use-gpu', type=bool, default=True, help='default use gpu.') + '--use_gpu', type=bool, default=True, help='default use gpu.') parser.add_argument( '--weights', type=str, default=None, help='weight path, None to use weights from Paddle.') parser.add_argument( - '--batch-size', + '--batch_size', type=int, default=1, help='sample number in a batch for inference.') @@ -64,17 +64,17 @@ def parse_args(): default=None, help='path to inferenece data file lists file.') parser.add_argument( - '--log-interval', + '--log_interval', type=int, default=1, help='mini-batch interval to log.') parser.add_argument( - '--infer-topk', + '--infer_topk', type=int, default=20, help='topk predictions to restore.') parser.add_argument( - '--save-dir', type=str, default='./', help='directory to store results') + '--save_dir', type=str, default='./', help='directory to store results') args = parser.parse_args() return args @@ -126,8 +126,7 @@ def infer(args): topk_inds = predictions[i].argsort()[0 - args.infer_topk:] topk_inds = topk_inds[::-1] preds = predictions[i][topk_inds] - results.append( - (video_id[i], preds.tolist(), topk_inds.tolist())) + results.append((video_id[i], preds.tolist(), topk_inds.tolist())) prev_time = cur_time cur_time = time.time() period = cur_time - prev_time @@ -145,6 +144,7 @@ def infer(args): "{}_infer_result".format(args.model_name)) pickle.dump(results, open(result_file_name, 'wb')) + if __name__ == "__main__": args = parse_args() logger.info(args) diff --git a/PaddleCV/video/metrics/multicrop_test/multicrop_test_metrics.py b/PaddleCV/video/metrics/multicrop_test/multicrop_test_metrics.py index 9da8826c..fcb7c954 100644 --- a/PaddleCV/video/metrics/multicrop_test/multicrop_test_metrics.py +++ b/PaddleCV/video/metrics/multicrop_test/multicrop_test_metrics.py @@ -63,6 +63,7 @@ class MetricsCalculator(): def accumulate(self, loss, pred, labels): labels = labels.astype(int) + labels = labels[:, 0] for i in range(pred.shape[0]): probs = pred[i, :].tolist() vid = labels[i] @@ -81,6 +82,8 @@ class MetricsCalculator(): evaluate_results(self.results, self.filename_gt, self.dataset_size, \ self.num_classes, self.num_test_clips) # save temporary file + if not os.path.isdir(self.checkpoint_dir): + os.makedirs(self.checkpoint_dir) pkl_path = os.path.join(self.checkpoint_dir, "results_probs.pkl") with open(pkl_path, 'w') as f: @@ -188,26 +191,4 @@ def evaluate_results(results, filename_gt, test_dataset_size, num_classes, logger.info('top-5 accuracy: {:.2f} percent'.format(accuracy_top5 * 100)) logger.info('-' * 80) - for i in range(sample_num): - prob = probs[i] - - # top-1 - idx = prob.argmax() - if idx == gt_labels[i] and counts[i] > 0: - accuracy = accuracy + 1 - - ids = np.argsort(prob)[::-1] - for j in range(5): - if ids[j] == gt_labels[i] and counts[i] > 0: - accuracy_top5 = accuracy_top5 + 1 - break - - accuracy = float(accuracy) / float(sample_num) - accuracy_top5 = float(accuracy_top5) / float(sample_num) - - logger.info('-' * 80) - logger.info('top-1 accuracy: {:.2f} percent'.format(accuracy * 100)) - logger.info('top-5 accuracy: {:.2f} percent'.format(accuracy_top5 * 100)) - logger.info('-' * 80) - return diff --git a/PaddleCV/video/models/__init__.py b/PaddleCV/video/models/__init__.py index ae3da375..006e373d 100644 --- a/PaddleCV/video/models/__init__.py +++ b/PaddleCV/video/models/__init__.py @@ -4,6 +4,7 @@ from .nextvlad import NEXTVLAD from .tsn import TSN from .stnet import STNET from .attention_lstm import AttentionLSTM +from .nonlocal_model import NonLocal # regist models regist_model("AttentionCluster", AttentionCluster) @@ -11,3 +12,4 @@ regist_model("NEXTVLAD", NEXTVLAD) regist_model("TSN", TSN) regist_model("STNET", STNET) regist_model("AttentionLSTM", AttentionLSTM) +regist_model('NONLOCAL', NonLocal) diff --git a/PaddleCV/video/models/model.py b/PaddleCV/video/models/model.py old mode 100755 new mode 100644 index 44f888ef..41b5b655 --- a/PaddleCV/video/models/model.py +++ b/PaddleCV/video/models/model.py @@ -137,8 +137,8 @@ class ModelBase(object): if os.path.exists(path): return path - logger.info("Download pretrain weights of {} from {}".format( - self.name, url)) + logger.info("Download pretrain weights of {} from {}".format(self.name, + url)) download(url, path) return path @@ -146,6 +146,12 @@ class ModelBase(object): logger.info("Load pretrain weights from {}".format(pretrain)) fluid.io.load_params(exe, pretrain, main_program=prog) + def load_test_weights(self, exe, weights, prog, place): + def if_exist(var): + return os.path.exists(os.path.join(weights, var.name)) + + fluid.io.load_vars(exe, weights, predicate=if_exist) + def get_config_from_sec(self, sec, item, default=None): if sec.upper() not in self.cfg: return default @@ -178,4 +184,3 @@ def regist_model(name, model): def get_model(name, cfg, mode='train'): return model_zoo.get(name, cfg, mode) - diff --git a/PaddleCV/video/models/nonlocal_model/__init__.py b/PaddleCV/video/models/nonlocal_model/__init__.py new file mode 100644 index 00000000..0a127553 --- /dev/null +++ b/PaddleCV/video/models/nonlocal_model/__init__.py @@ -0,0 +1 @@ +from .nonlocal_model import * diff --git a/PaddleCV/video/models/nonlocal_model/nonlocal_helper.py b/PaddleCV/video/models/nonlocal_model/nonlocal_helper.py new file mode 100644 index 00000000..a02603e4 --- /dev/null +++ b/PaddleCV/video/models/nonlocal_model/nonlocal_helper.py @@ -0,0 +1,254 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import paddle +import paddle.fluid as fluid +from paddle.fluid import ParamAttr + + +# 3d spacetime nonlocal (v1, spatial downsample) +def spacetime_nonlocal(blob_in, dim_in, dim_out, batch_size, prefix, dim_inner, cfg, \ + test_mode = False, max_pool_stride = 2): + #------------ + cur = blob_in + # we do projection to convert each spacetime location to a feature + # theta original size + # e.g., (8, 1024, 4, 14, 14) => (8, 1024, 4, 14, 14) + theta = fluid.layers.conv3d( + input=cur, + num_filters=dim_inner, + filter_size=[1, 1, 1], + stride=[1, 1, 1], + padding=[0, 0, 0], + param_attr=ParamAttr( + name=prefix + '_theta' + "_w", + initializer=fluid.initializer.Normal( + loc=0.0, scale=cfg.NONLOCAL.conv_init_std)), + bias_attr=ParamAttr( + name=prefix + '_theta' + "_b", + initializer=fluid.initializer.Constant(value=0.)) + if (cfg.NONLOCAL.no_bias == 0) else False, + name=prefix + '_theta') + theta_shape = theta.shape + + # phi and g: half spatial size + # e.g., (8, 1024, 4, 14, 14) => (8, 1024, 4, 7, 7) + if cfg.NONLOCAL.use_maxpool: + max_pool = fluid.layers.pool3d( + input=cur, + pool_size=[1, max_pool_stride, max_pool_stride], + pool_type='max', + pool_stride=[1, max_pool_stride, max_pool_stride], + pool_padding=[0, 0, 0], + name=prefix + '_pool') + else: + max_pool = cur + + phi = fluid.layers.conv3d( + input=max_pool, + num_filters=dim_inner, + filter_size=[1, 1, 1], + stride=[1, 1, 1], + padding=[0, 0, 0], + param_attr=ParamAttr( + name=prefix + '_phi' + "_w", + initializer=fluid.initializer.Normal( + loc=0.0, scale=cfg.NONLOCAL.conv_init_std)), + bias_attr=ParamAttr( + name=prefix + '_phi' + "_b", + initializer=fluid.initializer.Constant(value=0.)) + if (cfg.NONLOCAL.no_bias == 0) else False, + name=prefix + '_phi') + phi_shape = phi.shape + g = fluid.layers.conv3d( + input=max_pool, + num_filters=dim_inner, + filter_size=[1, 1, 1], + stride=[1, 1, 1], + padding=[0, 0, 0], + param_attr=ParamAttr( + name=prefix + '_g' + "_w", + initializer=fluid.initializer.Normal( + loc=0.0, scale=cfg.NONLOCAL.conv_init_std)), + bias_attr=ParamAttr( + name=prefix + '_g' + "_b", + initializer=fluid.initializer.Constant(value=0.)) + if (cfg.NONLOCAL.no_bias == 0) else False, + name=prefix + '_g') + g_shape = g.shape + + # we have to use explicit batch size (to support arbitrary spacetime size) + # e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784) + theta = fluid.layers.reshape( + theta, [-1, 0, theta_shape[2] * theta_shape[3] * theta_shape[4]]) + theta = fluid.layers.transpose(theta, [0, 2, 1]) + phi = fluid.layers.reshape( + phi, [-1, 0, phi_shape[2] * phi_shape[3] * phi_shape[4]]) + theta_phi = fluid.layers.matmul(theta, phi, name=prefix + '_affinity') + g = fluid.layers.reshape(g, [-1, 0, g_shape[2] * g_shape[3] * g_shape[4]]) + if cfg.NONLOCAL.use_softmax: + if cfg.NONLOCAL.use_scale is True: + theta_phi_sc = fluid.layers.scale(theta_phi, scale=dim_inner**-.5) + else: + theta_phi_sc = theta_phi + p = fluid.layers.softmax( + theta_phi_sc, name=prefix + '_affinity' + '_prob') + else: + # not clear about what is doing in xlw's code + p = None # not implemented + raise "Not implemented when not use softmax" + + # note g's axis[2] corresponds to p's axis[2] + # e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1) + p = fluid.layers.transpose(p, [0, 2, 1]) + t = fluid.layers.matmul(g, p, name=prefix + '_y') + + # reshape back + # e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14) + t_shape = t.shape + # print(t_shape) + # print(theta_shape) + t_re = fluid.layers.reshape(t, shape=list(theta_shape)) + blob_out = t_re + + blob_out = fluid.layers.conv3d( + input=blob_out, + num_filters=dim_out, + filter_size=[1, 1, 1], + stride=[1, 1, 1], + padding=[0, 0, 0], + param_attr=ParamAttr( + name=prefix + '_out' + "_w", + initializer=fluid.initializer.Constant(value=0.) + if cfg.NONLOCAL.use_zero_init_conv else fluid.initializer.Normal( + loc=0.0, scale=cfg.NONLOCAL.conv_init_std)), + bias_attr=ParamAttr( + name=prefix + '_out' + "_b", + initializer=fluid.initializer.Constant(value=0.)) + if (cfg.NONLOCAL.no_bias == 0) else False, + name=prefix + '_out') + blob_out_shape = blob_out.shape + + if cfg.NONLOCAL.use_bn is True: + bn_name = prefix + "_bn" + blob_out = fluid.layers.batch_norm( + blob_out, + is_test=test_mode, + momentum=cfg.NONLOCAL.bn_momentum, + epsilon=cfg.NONLOCAL.bn_epsilon, + name=bn_name, + param_attr=ParamAttr( + name=bn_name + "_scale", + initializer=fluid.initializer.Constant( + value=cfg.NONLOCAL.bn_init_gamma), + regularizer=fluid.regularizer.L2Decay( + cfg.TRAIN.weight_decay_bn)), + bias_attr=ParamAttr( + name=bn_name + "_offset", + regularizer=fluid.regularizer.L2Decay( + cfg.TRAIN.weight_decay_bn)), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance") # add bn + + if cfg.NONLOCAL.use_affine is True: + affine_scale = fluid.layers.create_parameter( + shape=[blob_out_shape[1]], + dtype=blob_out.dtype, + attr=ParamAttr(name=prefix + '_affine' + '_s'), + default_initializer=fluid.initializer.Constant(value=1.)) + affine_bias = fluid.layers.create_parameter( + shape=[blob_out_shape[1]], + dtype=blob_out.dtype, + attr=ParamAttr(name=prefix + '_affine' + '_b'), + default_initializer=fluid.initializer.Constant(value=0.)) + blob_out = fluid.layers.affine_channel( + blob_out, + scale=affine_scale, + bias=affine_bias, + name=prefix + '_affine') # add affine + + return blob_out + + +def add_nonlocal(blob_in, + dim_in, + dim_out, + batch_size, + prefix, + dim_inner, + cfg, + test_mode=False): + blob_out = spacetime_nonlocal(blob_in, \ + dim_in, dim_out, batch_size, prefix, dim_inner, cfg, test_mode = test_mode) + blob_out = fluid.layers.elementwise_add( + blob_out, blob_in, name=prefix + '_sum') + return blob_out + + +# this is to reduce memory usage if the feature maps are big +# devide the feature maps into groups in the temporal dimension, +# and perform non-local operations inside each group. +def add_nonlocal_group(blob_in, + dim_in, + dim_out, + batch_size, + pool_stride, + height, + width, + group_size, + prefix, + dim_inner, + cfg, + test_mode=False): + group_num = int(pool_stride / group_size) + assert (pool_stride % group_size == 0), \ + 'nonlocal block {}: pool_stride({}) should be divided by group size({})'.format(prefix, pool_stride, group_size) + + if group_num > 1: + blob_in = fluid.layers.transpose( + blob_in, [0, 2, 1, 3, 4], name=prefix + '_pre_trans1') + blob_in = fluid.layers.reshape( + blob_in, + [batch_size * group_num, group_size, dim_in, height, width], + name=prefix + '_pre_reshape1') + blob_in = fluid.layers.transpose( + blob_in, [0, 2, 1, 3, 4], name=prefix + '_pre_trans2') + + blob_out = spacetime_nonlocal( + blob_in, + dim_in, + dim_out, + batch_size, + prefix, + dim_inner, + cfg, + test_mode=test_mode) + blob_out = fluid.layers.elementwise_add( + blob_out, blob_in, name=prefix + '_sum') + + if group_num > 1: + blob_out = fluid.layers.transpose( + blob_out, [0, 2, 1, 3, 4], name=prefix + '_post_trans1') + blob_out = fluid.layers.reshape( + blob_out, + [batch_size, group_num * group_size, dim_out, height, width], + name=prefix + '_post_reshape1') + blob_out = fluid.layers.transpose( + blob_out, [0, 2, 1, 3, 4], name=prefix + '_post_trans2') + + return blob_out diff --git a/PaddleCV/video/models/nonlocal_model/nonlocal_model.py b/PaddleCV/video/models/nonlocal_model/nonlocal_model.py new file mode 100644 index 00000000..db27a111 --- /dev/null +++ b/PaddleCV/video/models/nonlocal_model/nonlocal_model.py @@ -0,0 +1,299 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import os +import numpy as np +import cPickle +import paddle.fluid as fluid + +from ..model import ModelBase +import resnet_video + +import logging +logger = logging.getLogger(__name__) + +__all__ = ["NonLocal"] + +# To add new models, import them, add them to this map and models/TARGETS + + +class NonLocal(ModelBase): + def __init__(self, name, cfg, mode='train'): + super(NonLocal, self).__init__(name, cfg, mode=mode) + self.get_config() + + def get_config(self): + # video_length + self.video_length = self.get_config_from_sec(self.mode, 'video_length') + # crop size + self.crop_size = self.get_config_from_sec(self.mode, 'crop_size') + + def build_input(self, use_pyreader=True): + input_shape = [3, self.video_length, self.crop_size, self.crop_size] + label_shape = [1] + py_reader = None + if use_pyreader: + assert self.mode != 'infer', \ + 'pyreader is not recommendated when infer, please set use_pyreader to be false.' + py_reader = fluid.layers.py_reader( + capacity=20, + shapes=[[-1] + input_shape, [-1] + label_shape], + dtypes=['float32', 'int64'], + name='train_py_reader' + if self.is_training else 'test_py_reader', + use_double_buffer=True) + data, label = fluid.layers.read_file(py_reader) + self.py_reader = py_reader + else: + data = fluid.layers.data( + name='train_data' if self.is_training else 'test_data', + shape=input_shape, + dtype='float32') + if self.mode != 'infer': + label = fluid.layers.data( + name='train_label' if self.is_training else 'test_label', + shape=label_shape, + dtype='int64') + else: + label = None + self.feature_input = [data] + self.label_input = label + + def create_model_args(self): + return None + + def build_model(self): + pred, loss = resnet_video.create_model( + data=self.feature_input[0], + label=self.label_input, + cfg=self.cfg, + is_training=self.is_training, + mode=self.mode) + if loss is not None: + loss = fluid.layers.mean(loss) + self.network_outputs = [pred] + self.loss_ = loss + + def optimizer(self): + base_lr = self.get_config_from_sec('TRAIN', 'learning_rate') + lr_decay = self.get_config_from_sec('TRAIN', 'learning_rate_decay') + step_sizes = self.get_config_from_sec('TRAIN', 'step_sizes') + lr_bounds, lr_values = get_learning_rate_decay_list(base_lr, lr_decay, + step_sizes) + learning_rate = fluid.layers.piecewise_decay( + boundaries=lr_bounds, values=lr_values) + + momentum = self.get_config_from_sec('TRAIN', 'momentum') + use_nesterov = self.get_config_from_sec('TRAIN', 'nesterov') + l2_weight_decay = self.get_config_from_sec('TRAIN', 'weight_decay') + logger.info( + 'Build up optimizer, \ntype: {}, \nmomentum: {}, \nnesterov: {}, \ + \nregularization: L2 {}, \nlr_values: {}, lr_bounds: {}' + .format('Momentum', momentum, use_nesterov, l2_weight_decay, + lr_values, lr_bounds)) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=momentum, + use_nesterov=use_nesterov, + regularization=fluid.regularizer.L2Decay(l2_weight_decay)) + return optimizer + + def loss(self): + return self.loss_ + + def outputs(self): + return self.network_outputs + + def feeds(self): + return self.feature_input if self.mode == 'infer' else \ + self.feature_input + [self.label_input] + + def pretrain_info(self): + return None, None + + def weights_info(self): + pass + + def load_pretrain_params(self, exe, pretrain, prog, place): + load_params_from_file(exe, prog, pretrain, place) + + def load_test_weights(self, exe, weights, prog, place): + super(NonLocal, self).load_test_weights(exe, weights, prog, place) + pred_w = fluid.global_scope().find_var('pred_w').get_tensor() + pred_array = np.array(pred_w) + pred_w_shape = pred_array.shape + if len(pred_w_shape) == 2: + logger.info('reshape for pred_w when test') + pred_array = np.transpose(pred_array, (1, 0)) + pred_w_shape = pred_array.shape + pred_array = np.reshape( + pred_array, [pred_w_shape[0], pred_w_shape[1], 1, 1, 1]) + pred_w.set(pred_array.astype('float32'), place) + + +def get_learning_rate_decay_list(base_learning_rate, lr_decay, step_lists): + lr_bounds = [] + lr_values = [base_learning_rate * 1] + cur_step = 0 + for i in range(len(step_lists)): + cur_step += step_lists[i] + lr_bounds.append(cur_step) + decay_rate = lr_decay**(i + 1) + lr_values.append(base_learning_rate * decay_rate) + + return lr_bounds, lr_values + + +def load_params_from_pkl_file(prog, pretrained_file, place): + param_list = prog.block(0).all_parameters() + param_name_list = [p.name for p in param_list] + + if os.path.exists(pretrained_file): + params_from_file = cPickle.load(open(pretrained_file)) + if len(params_from_file.keys()) == 1: + params_from_file = params_from_file['blobs'] + param_name_from_file = params_from_file.keys() + param_list = prog.block(0).all_parameters() + param_name_list = [p.name for p in param_list] + + common_names = get_common_names(param_name_list, param_name_from_file) + + logger.info('-------- loading params -----------') + for name in common_names: + t = fluid.global_scope().find_var(name).get_tensor() + t_array = np.array(t) + f_array = params_from_file[name] + if 'pred' in name: + assert np.prod(t_array.shape) == np.prod( + f_array.shape), "number of params should be the same" + if t_array.shape == f_array.shape: + logger.info("pred param is the same {}".format(name)) + else: + re_f_array = np.reshape(f_array, t_array.shape) + t.set(re_f_array.astype('float32'), place) + logger.info("load pred param {}".format(name)) + continue + if t_array.shape == f_array.shape: + t.set(f_array.astype('float32'), place) + logger.info("load param {}".format(name)) + elif (t_array.shape[:2] == f_array.shape[:2]) and ( + t_array.shape[-2:] == f_array.shape[-2:]): + num_inflate = t_array.shape[2] + stack_f_array = np.stack( + [f_array] * num_inflate, axis=2) / float(num_inflate) + assert t_array.shape == stack_f_array.shape, "inflated shape should be the same with tensor {}".format( + name) + t.set(stack_f_array.astype('float32'), place) + logger.info("load inflated({}) param {}".format(num_inflate, + name)) + else: + logger.info("Invalid case for name: {}".format(name)) + raise + logger.info("finished loading params from resnet pretrained model") + + +def load_params_from_paddle_file(exe, prog, pretrained_file, place): + if os.path.isdir(pretrained_file): + param_list = prog.block(0).all_parameters() + param_name_list = [p.name for p in param_list] + param_shape = {} + for name in param_name_list: + param_tensor = fluid.global_scope().find_var(name).get_tensor() + param_shape[name] = np.array(param_tensor).shape + + param_name_from_file = os.listdir(pretrained_file) + common_names = get_common_names(param_name_list, param_name_from_file) + + logger.info('-------- loading params -----------') + + # load params from file + def is_parameter(var): + if isinstance(var, fluid.framework.Parameter): + return isinstance(var, fluid.framework.Parameter) and \ + os.path.exists(os.path.join(pretrained_file, var.name)) + + logger.info("Load pretrain weights from file {}".format( + pretrained_file)) + vars = filter(is_parameter, prog.list_vars()) + fluid.io.load_vars(exe, pretrained_file, vars=vars, main_program=prog) + + # reset params if necessary + for name in common_names: + t = fluid.global_scope().find_var(name).get_tensor() + t_array = np.array(t) + origin_shape = param_shape[name] + if 'pred' in name: + assert np.prod(t_array.shape) == np.prod( + origin_shape), "number of params should be the same" + if t_array.shape == origin_shape: + logger.info("pred param is the same {}".format(name)) + else: + reshaped_t_array = np.reshape(t_array, origin_shape) + t.set(reshaped_t_array.astype('float32'), place) + logger.info("load pred param {}".format(name)) + continue + if t_array.shape == origin_shape: + logger.info("load param {}".format(name)) + elif (t_array.shape[:2] == origin_shape[:2]) and ( + t_array.shape[-2:] == origin_shape[-2:]): + num_inflate = origin_shape[2] + stack_t_array = np.stack( + [t_array] * num_inflate, axis=2) / float(num_inflate) + assert origin_shape == stack_t_array.shape, "inflated shape should be the same with tensor {}".format( + name) + t.set(stack_t_array.astype('float32'), place) + logger.info("load inflated({}) param {}".format(num_inflate, + name)) + else: + logger.info("Invalid case for name: {}".format(name)) + raise + logger.info("finished loading params from resnet pretrained model") + else: + logger.info( + "pretrained file is not in a directory, not suitable to load params". + format(pretrained_file)) + pass + + +def get_common_names(param_name_list, param_name_from_file): + # name check and return common names both in param_name_list and file + common_names = [] + paddle_only_names = [] + file_only_names = [] + logger.info('-------- comon params -----------') + for name in param_name_list: + if name in param_name_from_file: + common_names.append(name) + logger.info(name) + else: + paddle_only_names.append(name) + logger.info('-------- paddle only params ----------') + for name in paddle_only_names: + logger.info(name) + logger.info('-------- file only params -----------') + for name in param_name_from_file: + if name in param_name_list: + assert name in common_names + else: + file_only_names.append(name) + logger.info(name) + return common_names + + +def load_params_from_file(exe, prog, pretrained_file, place): + logger.info('load params from {}'.format(pretrained_file)) + if '.pkl' in pretrained_file: + load_params_from_pkl_file(prog, pretrained_file, place) + else: + load_params_from_paddle_file(exe, prog, pretrained_file, place) diff --git a/PaddleCV/video/models/nonlocal_model/resnet_helper.py b/PaddleCV/video/models/nonlocal_model/resnet_helper.py new file mode 100644 index 00000000..8ba6a3bc --- /dev/null +++ b/PaddleCV/video/models/nonlocal_model/resnet_helper.py @@ -0,0 +1,356 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import unicode_literals +from __future__ import print_function +from __future__ import division + +import paddle +import paddle.fluid as fluid +from paddle.fluid import ParamAttr + +import numpy as np +import nonlocal_helper + + +def Conv3dAffine(blob_in, + prefix, + dim_in, + dim_out, + filter_size, + stride, + padding, + cfg, + group=1, + test_mode=False, + bn_init=None): + blob_out = fluid.layers.conv3d( + input=blob_in, + num_filters=dim_out, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=group, + param_attr=ParamAttr( + name=prefix + "_weights", initializer=fluid.initializer.MSRA()), + bias_attr=False, + name=prefix + "_conv") + blob_out_shape = blob_out.shape + + affine_name = "bn" + prefix[3:] + + affine_scale = fluid.layers.create_parameter( + shape=[blob_out_shape[1]], + dtype=blob_out.dtype, + attr=ParamAttr(name=affine_name + '_scale'), + default_initializer=fluid.initializer.Constant(value=1.)) + affine_bias = fluid.layers.create_parameter( + shape=[blob_out_shape[1]], + dtype=blob_out.dtype, + attr=ParamAttr(name=affine_name + '_offset'), + default_initializer=fluid.initializer.Constant(value=0.)) + blob_out = fluid.layers.affine_channel( + blob_out, scale=affine_scale, bias=affine_bias, name=affine_name) + + return blob_out + + +def Conv3dBN(blob_in, + prefix, + dim_in, + dim_out, + filter_size, + stride, + padding, + cfg, + group=1, + test_mode=False, + bn_init=None): + blob_out = fluid.layers.conv3d( + input=blob_in, + num_filters=dim_out, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=group, + param_attr=ParamAttr( + name=prefix + "_weights", initializer=fluid.initializer.MSRA()), + bias_attr=False, + name=prefix + "_conv") + + bn_name = "bn" + prefix[3:] + + blob_out = fluid.layers.batch_norm( + blob_out, + is_test=test_mode, + momentum=cfg.MODEL.bn_momentum, + epsilon=cfg.MODEL.bn_epsilon, + name=bn_name, + param_attr=ParamAttr( + name=bn_name + "_scale", + initializer=fluid.initializer.Constant(value=bn_init if + (bn_init != None) else 1.), + regularizer=fluid.regularizer.L2Decay(cfg.TRAIN.weight_decay_bn)), + bias_attr=ParamAttr( + name=bn_name + "_offset", + regularizer=fluid.regularizer.L2Decay(cfg.TRAIN.weight_decay_bn)), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance") + return blob_out + + +# 3d bottleneck +def bottleneck_transformation_3d(blob_in, + dim_in, + dim_out, + stride, + prefix, + dim_inner, + cfg, + group=1, + use_temp_conv=1, + temp_stride=1, + test_mode=False): + conv_op = Conv3dAffine if cfg.MODEL.use_affine else Conv3dBN + + # 1x1 layer + blob_out = conv_op( + blob_in, + prefix + "_branch2a", + dim_in, + dim_inner, [1 + use_temp_conv * 2, 1, 1], [temp_stride, 1, 1], + [use_temp_conv, 0, 0], + cfg, + test_mode=test_mode) + blob_out = fluid.layers.relu(blob_out, name=prefix + "_branch2a" + "_relu") + + # 3x3 layer + blob_out = conv_op( + blob_out, + prefix + '_branch2b', + dim_inner, + dim_inner, [1, 3, 3], [1, stride, stride], [0, 1, 1], + cfg, + group=group, + test_mode=test_mode) + blob_out = fluid.layers.relu(blob_out, name=prefix + "_branch2b" + "_relu") + + # 1x1 layer, no relu + blob_out = conv_op( + blob_out, + prefix + '_branch2c', + dim_inner, + dim_out, [1, 1, 1], [1, 1, 1], [0, 0, 0], + cfg, + test_mode=test_mode, + bn_init=cfg.MODEL.bn_init_gamma) + + return blob_out + + +def _add_shortcut_3d(blob_in, + prefix, + dim_in, + dim_out, + stride, + cfg, + temp_stride=1, + test_mode=False): + if ((dim_in == dim_out) and (temp_stride == 1) and (stride == 1)): + # identity mapping (do nothing) + return blob_in + else: + # when dim changes + conv_op = Conv3dAffine if cfg.MODEL.use_affine else Conv3dBN + blob_out = conv_op( + blob_in, + prefix, + dim_in, + dim_out, [1, 1, 1], [temp_stride, stride, stride], [0, 0, 0], + cfg, + test_mode=test_mode) + + return blob_out + + +# residual block abstraction +def _generic_residual_block_3d(blob_in, + dim_in, + dim_out, + stride, + prefix, + dim_inner, + cfg, + group=1, + use_temp_conv=0, + temp_stride=1, + trans_func=None, + test_mode=False): + # transformation branch (e.g. 1x1-3x3-1x1, or 3x3-3x3), namely "F(x)" + if trans_func is None: + trans_func = globals()[cfg.RESNETS.trans_func] + + tr_blob = trans_func( + blob_in, + dim_in, + dim_out, + stride, + prefix, + dim_inner, + cfg, + group=group, + use_temp_conv=use_temp_conv, + temp_stride=temp_stride, + test_mode=test_mode) + + # create short cut, namely, "x" + sc_blob = _add_shortcut_3d( + blob_in, + prefix + "_branch1", + dim_in, + dim_out, + stride, + cfg, + temp_stride=temp_stride, + test_mode=test_mode) + + # addition, namely, "x + F(x)", and relu + sum_blob = fluid.layers.elementwise_add( + tr_blob, sc_blob, act='relu', name=prefix + '_sum') + + return sum_blob + + +def res_stage_nonlocal(block_fn, + blob_in, + dim_in, + dim_out, + stride, + num_blocks, + prefix, + cfg, + dim_inner=None, + group=None, + use_temp_convs=None, + temp_strides=None, + batch_size=None, + nonlocal_name=None, + nonlocal_mod=1000, + test_mode=False): + # prefix is something like: res2, res3, etc. + # each res layer has num_blocks stacked. + + # check dtype and format of use_temp_convs and temp_strides + if use_temp_convs is None: + use_temp_convs = np.zeros(num_blocks).astype(int) + if temp_strides is None: + temp_strides = np.ones(num_blocks).astype(int) + + if len(use_temp_convs) < num_blocks: + for _ in range(num_blocks - len(use_temp_convs)): + use_temp_convs.append(0) + temp_strides.append(1) + + for idx in range(num_blocks): + block_prefix = '{}{}'.format(prefix, chr(idx + 97)) + block_stride = 2 if ((idx == 0) and (stride == 2)) else 1 + blob_in = _generic_residual_block_3d( + blob_in, + dim_in, + dim_out, + block_stride, + block_prefix, + dim_inner, + cfg, + group=group, + use_temp_conv=use_temp_convs[idx], + temp_stride=temp_strides[idx], + test_mode=test_mode) + dim_in = dim_out + + if idx % nonlocal_mod == nonlocal_mod - 1: + blob_in = nonlocal_helper.add_nonlocal( + blob_in, + dim_in, + dim_in, + batch_size, + nonlocal_name + '_{}'.format(idx), + int(dim_in / 2), + cfg, + test_mode=test_mode) + + return blob_in, dim_in + + +def res_stage_nonlocal_group(block_fn, + blob_in, + dim_in, + dim_out, + stride, + num_blocks, + prefix, + cfg, + dim_inner=None, + group=None, + use_temp_convs=None, + temp_strides=None, + batch_size=None, + pool_stride=None, + spatial_dim=None, + group_size=None, + nonlocal_name=None, + nonlocal_mod=1000, + test_mode=False): + # prefix is something like res2, res3, etc. + # each res layer has num_blocks stacked + + # check dtype and format of use_temp_convs and temp_strides + if use_temp_convs is None: + use_temp_convs = np.zeros(num_blocks).astype(int) + if temp_strides is None: + temp_strides = np.ones(num_blocks).astype(int) + + for idx in range(num_blocks): + block_prefix = "{}{}".format(prefix, chr(idx + 97)) + block_stride = 2 if (idx == 0 and stride == 2) else 1 + blob_in = _generic_residual_block_3d( + blob_in, + dim_in, + dim_out, + block_stride, + block_prefix, + dim_inner, + cfg, + group=group, + use_temp_conv=use_temp_convs[idx], + temp_stride=temp_strides[idx], + test_mode=test_mode) + dim_in = dim_out + + if idx % nonlocal_mod == nonlocal_mod - 1: + blob_in = nonlocal_helper.add_nonlocal_group( + blob_in, + dim_in, + dim_in, + batch_size, + pool_stride, + spatial_dim, + spatial_dim, + group_size, + nonlocal_name + "_{}".format(idx), + int(dim_in / 2), + cfg, + test_mode=test_mode) + + return blob_in, dim_in diff --git a/PaddleCV/video/models/nonlocal_model/resnet_video.py b/PaddleCV/video/models/nonlocal_model/resnet_video.py new file mode 100644 index 00000000..bb4b51e3 --- /dev/null +++ b/PaddleCV/video/models/nonlocal_model/resnet_video.py @@ -0,0 +1,371 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import unicode_literals +from __future__ import print_function +from __future__ import division + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + +import resnet_helper +import logging + +logger = logging.getLogger(__name__) + +# For more depths, add the block config here +BLOCK_CONFIG = { + 50: (3, 4, 6, 3), + 101: (3, 4, 23, 3), +} + + +# ------------------------------------------------------------------------ +# obtain_arc defines the temporal kernel radius and temporal strides for +# each layers residual blocks in a resnet. +# e.g. use_temp_convs = 1 means a temporal kernel of 3 is used. +# In ResNet50, it has (3, 4, 6, 3) blocks in conv2, 3, 4, 5, +# so the lengths of the corresponding lists are (3, 4, 6, 3). +# ------------------------------------------------------------------------ +def obtain_arc(arc_type, video_length): + + pool_stride = 1 + + # c2d, ResNet50 + if arc_type == 1: + use_temp_convs_1 = [0] + temp_strides_1 = [1] + use_temp_convs_2 = [0, 0, 0] + temp_strides_2 = [1, 1, 1] + use_temp_convs_3 = [0, 0, 0, 0] + temp_strides_3 = [1, 1, 1, 1] + use_temp_convs_4 = [0, ] * 6 + temp_strides_4 = [1, ] * 6 + use_temp_convs_5 = [0, 0, 0] + temp_strides_5 = [1, 1, 1] + + pool_stride = int(video_length / 2) + + # i3d, ResNet50 + if arc_type == 2: + use_temp_convs_1 = [2] + temp_strides_1 = [1] + use_temp_convs_2 = [1, 1, 1] + temp_strides_2 = [1, 1, 1] + use_temp_convs_3 = [1, 0, 1, 0] + temp_strides_3 = [1, 1, 1, 1] + use_temp_convs_4 = [1, 0, 1, 0, 1, 0] + temp_strides_4 = [1, 1, 1, 1, 1, 1] + use_temp_convs_5 = [0, 1, 0] + temp_strides_5 = [1, 1, 1] + + pool_stride = int(video_length / 2) + + # c2d, ResNet101 + if arc_type == 3: + use_temp_convs_1 = [0] + temp_strides_1 = [1] + use_temp_convs_2 = [0, 0, 0] + temp_strides_2 = [1, 1, 1] + use_temp_convs_3 = [0, 0, 0, 0] + temp_strides_3 = [1, 1, 1, 1] + use_temp_convs_4 = [0, ] * 23 + temp_strides_4 = [1, ] * 23 + use_temp_convs_5 = [0, 0, 0] + temp_strides_5 = [1, 1, 1] + + pool_stride = int(video_length / 2) + + # i3d, ResNet101 + if arc_type == 4: + use_temp_convs_1 = [2] + temp_strides_1 = [1] + use_temp_convs_2 = [1, 1, 1] + temp_strides_2 = [1, 1, 1] + use_temp_convs_3 = [1, 0, 1, 0] + temp_strides_3 = [1, 1, 1, 1] + use_temp_convs_4 = [] + for i in range(23): + if i % 2 == 0: + use_temp_convs_4.append(1) + else: + use_temp_convs_4.append(0) + + temp_strides_4 = [1] * 23 + use_temp_convs_5 = [0, 1, 0] + temp_strides_5 = [1, 1, 1] + + pool_stride = int(video_length / 2) + + use_temp_convs_set = [ + use_temp_convs_1, use_temp_convs_2, use_temp_convs_3, use_temp_convs_4, + use_temp_convs_5 + ] + temp_strides_set = [ + temp_strides_1, temp_strides_2, temp_strides_3, temp_strides_4, + temp_strides_5 + ] + + return use_temp_convs_set, temp_strides_set, pool_stride + + +def create_model(data, label, cfg, is_training=True, mode='train'): + group = cfg.RESNETS.num_groups + width_per_group = cfg.RESNETS.width_per_group + batch_size = int(cfg.TRAIN.batch_size / cfg.NUM_GPUS) + + logger.info('--------------- ResNet-{} {}x{}d-{}, {} ---------------'. + format(cfg.MODEL.depth, group, width_per_group, + cfg.RESNETS.trans_func, cfg.MODEL.dataset)) + + assert cfg.MODEL.depth in BLOCK_CONFIG.keys(), \ + "Block config is not defined for specified model depth." + (n1, n2, n3, n4) = BLOCK_CONFIG[cfg.MODEL.depth] + + res_block = resnet_helper._generic_residual_block_3d + dim_inner = group * width_per_group + + use_temp_convs_set, temp_strides_set, pool_stride = obtain_arc( + cfg.MODEL.video_arc_choice, cfg[mode.upper()]['video_length']) + logger.info(use_temp_convs_set) + logger.info(temp_strides_set) + conv_blob = fluid.layers.conv3d( + input=data, + num_filters=64, + filter_size=[1 + use_temp_convs_set[0][0] * 2, 7, 7], + stride=[temp_strides_set[0][0], 2, 2], + padding=[use_temp_convs_set[0][0], 3, 3], + param_attr=ParamAttr( + name='conv1' + "_weights", initializer=fluid.initializer.MSRA()), + bias_attr=False, + name='conv1') + + test_mode = False if (mode == 'train') else True + if cfg.MODEL.use_affine is False: + # use bn + bn_name = 'bn_conv1' + bn_blob = fluid.layers.batch_norm( + conv_blob, + is_test=test_mode, + momentum=cfg.MODEL.bn_momentum, + epsilon=cfg.MODEL.bn_epsilon, + name=bn_name, + param_attr=ParamAttr( + name=bn_name + "_scale", + regularizer=fluid.regularizer.L2Decay( + cfg.TRAIN.weight_decay_bn)), + bias_attr=ParamAttr( + name=bn_name + "_offset", + regularizer=fluid.regularizer.L2Decay( + cfg.TRAIN.weight_decay_bn)), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance") + else: + # use affine + affine_name = 'bn_conv1' + conv_blob_shape = conv_blob.shape + affine_scale = fluid.layers.create_parameter( + shape=[conv_blob_shape[1]], + dtype=conv_blob.dtype, + attr=ParamAttr(name=affine_name + '_scale'), + default_initializer=fluid.initializer.Constant(value=1.)) + affine_bias = fluid.layers.create_parameter( + shape=[conv_blob_shape[1]], + dtype=conv_blob.dtype, + attr=ParamAttr(name=affine_name + '_offset'), + default_initializer=fluid.initializer.Constant(value=0.)) + bn_blob = fluid.layers.affine_channel( + conv_blob, scale=affine_scale, bias=affine_bias, name=affine_name) + + # relu + relu_blob = fluid.layers.relu(bn_blob, name='res_conv1_bn_relu') + # max pool + max_pool = fluid.layers.pool3d( + input=relu_blob, + pool_size=[1, 3, 3], + pool_type='max', + pool_stride=[1, 2, 2], + pool_padding=[0, 0, 0], + name='pool1') + + # building res block + if cfg.MODEL.depth in [50, 101]: + blob_in, dim_in = resnet_helper.res_stage_nonlocal( + res_block, + max_pool, + 64, + 256, + stride=1, + num_blocks=n1, + prefix='res2', + cfg=cfg, + dim_inner=dim_inner, + group=group, + use_temp_convs=use_temp_convs_set[1], + temp_strides=temp_strides_set[1], + test_mode=test_mode) + + layer_mod = cfg.NONLOCAL.layer_mod + if cfg.MODEL.depth == 101: + layer_mod = 2 + if cfg.NONLOCAL.conv3_nonlocal is False: + layer_mod = 1000 + + blob_in = fluid.layers.pool3d( + blob_in, + pool_size=[2, 1, 1], + pool_type='max', + pool_stride=[2, 1, 1], + pool_padding=[0, 0, 0], + name='pool2') + + if cfg.MODEL.use_affine is False: + blob_in, dim_in = resnet_helper.res_stage_nonlocal( + res_block, + blob_in, + dim_in, + 512, + stride=2, + num_blocks=n2, + prefix='res3', + cfg=cfg, + dim_inner=dim_inner * 2, + group=group, + use_temp_convs=use_temp_convs_set[2], + temp_strides=temp_strides_set[2], + batch_size=batch_size, + nonlocal_name="nonlocal_conv3", + nonlocal_mod=layer_mod, + test_mode=test_mode) + else: + crop_size = cfg[mode.upper()]['crop_size'] + blob_in, dim_in = resnet_helper.res_stage_nonlocal_group( + res_block, + blob_in, + dim_in, + 512, + stride=2, + num_blocks=n2, + prefix='res3', + cfg=cfg, + dim_inner=dim_inner * 2, + group=group, + use_temp_convs=use_temp_convs_set[2], + temp_strides=temp_strides_set[2], + batch_size=batch_size, + pool_stride=pool_stride, + spatial_dim=int(crop_size / 8), + group_size=4, + nonlocal_name="nonlocal_conv3_group", + nonlocal_mod=layer_mod, + test_mode=test_mode) + + layer_mod = cfg.NONLOCAL.layer_mod + if cfg.MODEL.depth == 101: + layer_mod = layer_mod * 4 - 1 + if cfg.NONLOCAL.conv4_nonlocal is False: + layer_mod = 1000 + + blob_in, dim_in = resnet_helper.res_stage_nonlocal( + res_block, + blob_in, + dim_in, + 1024, + stride=2, + num_blocks=n3, + prefix='res4', + cfg=cfg, + dim_inner=dim_inner * 4, + group=group, + use_temp_convs=use_temp_convs_set[3], + temp_strides=temp_strides_set[3], + batch_size=batch_size, + nonlocal_name="nonlocal_conv4", + nonlocal_mod=layer_mod, + test_mode=test_mode) + + blob_in, dim_in = resnet_helper.res_stage_nonlocal( + res_block, + blob_in, + dim_in, + 2048, + stride=2, + num_blocks=n4, + prefix='res5', + cfg=cfg, + dim_inner=dim_inner * 8, + group=group, + use_temp_convs=use_temp_convs_set[4], + temp_strides=temp_strides_set[4], + test_mode=test_mode) + + else: + raise Exception("Unsupported network settings.") + + blob_out = fluid.layers.pool3d( + blob_in, + pool_size=[pool_stride, 7, 7], + pool_type='avg', + pool_stride=[1, 1, 1], + pool_padding=[0, 0, 0], + name='pool5') + + if (cfg.TRAIN.dropout_rate > 0) and (test_mode is False): + blob_out = fluid.layers.dropout( + blob_out, cfg.TRAIN.dropout_rate, is_test=test_mode) + + if mode in ['train', 'valid']: + blob_out = fluid.layers.fc( + blob_out, + cfg.MODEL.num_classes, + param_attr=ParamAttr( + name='pred' + "_w", + initializer=fluid.initializer.Normal( + loc=0.0, scale=cfg.MODEL.fc_init_std)), + bias_attr=ParamAttr( + name='pred' + "_b", + initializer=fluid.initializer.Constant(value=0.)), + name='pred') + elif mode in ['test', 'infer']: + blob_out = fluid.layers.conv3d( + input=blob_out, + num_filters=cfg.MODEL.num_classes, + filter_size=[1, 1, 1], + stride=[1, 1, 1], + padding=[0, 0, 0], + param_attr=ParamAttr( + name='pred' + "_w", initializer=fluid.initializer.MSRA()), + bias_attr=ParamAttr( + name='pred' + "_b", + initializer=fluid.initializer.Constant(value=0.)), + name='pred') + + if (mode == 'train') or (mode == 'valid'): + softmax = fluid.layers.softmax(blob_out) + loss = fluid.layers.cross_entropy( + softmax, label, soft_label=False, ignore_index=-100) + + elif (mode == 'test') or (mode == 'infer'): + # fully convolutional testing, when loading test model, + # params should be copied from train_prog fc layer named pred + blob_out = fluid.layers.transpose( + blob_out, [0, 2, 3, 4, 1], name='pred_tr') + blob_out = fluid.layers.softmax(blob_out, name='softmax_conv') + softmax = fluid.layers.reduce_mean( + blob_out, dim=[1, 2, 3], keep_dim=False, name='softmax') + loss = None + else: + raise 'Not implemented Error' + + return softmax, loss diff --git a/PaddleCV/video/models/utils.py b/PaddleCV/video/models/utils.py old mode 100755 new mode 100644 diff --git a/PaddleCV/video/scripts/infer/infer_attention_cluster.sh b/PaddleCV/video/scripts/infer/infer_attention_cluster.sh index be6045db..5f967a9f 100644 --- a/PaddleCV/video/scripts/infer/infer_attention_cluster.sh +++ b/PaddleCV/video/scripts/infer/infer_attention_cluster.sh @@ -1,4 +1,4 @@ -python infer.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \ +python infer.py --model_name="AttentionCluster" --config=./configs/attention_cluster.txt \ --filelist=./data/youtube8m/infer.list \ --weights=./checkpoints/AttentionCluster_epoch0 \ - --save-dir="./save" + --save_dir="./save" diff --git a/PaddleCV/video/scripts/infer/infer_attention_lstm.sh b/PaddleCV/video/scripts/infer/infer_attention_lstm.sh index 019bb346..e79e0258 100644 --- a/PaddleCV/video/scripts/infer/infer_attention_lstm.sh +++ b/PaddleCV/video/scripts/infer/infer_attention_lstm.sh @@ -1,4 +1,4 @@ -python infer.py --model-name="AttentionLSTM" --config=./configs/attention_lstm.txt \ +python infer.py --model_name="AttentionLSTM" --config=./configs/attention_lstm.txt \ --filelist=./data/youtube8m/infer.list \ --weights=./checkpoints/AttentionLSTM_epoch0 \ - --save-dir="./save" + --save_dir="./save" diff --git a/PaddleCV/video/scripts/infer/infer_nextvlad.sh b/PaddleCV/video/scripts/infer/infer_nextvlad.sh index 1a969801..3dfc3106 100644 --- a/PaddleCV/video/scripts/infer/infer_nextvlad.sh +++ b/PaddleCV/video/scripts/infer/infer_nextvlad.sh @@ -1,3 +1,3 @@ -python infer.py --model-name="NEXTVLAD" --config=./configs/nextvlad.txt --filelist=./data/youtube8m/infer.list \ +python infer.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt --filelist=./data/youtube8m/infer.list \ --weights=./checkpoints/NEXTVLAD_epoch0 \ - --save-dir="./save" + --save_dir="./save" diff --git a/PaddleCV/video/scripts/infer/infer_nonlocal.sh b/PaddleCV/video/scripts/infer/infer_nonlocal.sh new file mode 100644 index 00000000..1480908d --- /dev/null +++ b/PaddleCV/video/scripts/infer/infer_nonlocal.sh @@ -0,0 +1,2 @@ +python infer.py --model_name="NONLOCAL" --config=./configs/nonlocal.txt --filelist=./dataset/nonlocal/infer.list \ + --log_interval=10 --weights=./checkpoints/NONLOCAL_epoch0 --save_dir=./save diff --git a/PaddleCV/video/scripts/infer/infer_stnet.sh b/PaddleCV/video/scripts/infer/infer_stnet.sh index 8b27a234..fdd19af0 100644 --- a/PaddleCV/video/scripts/infer/infer_stnet.sh +++ b/PaddleCV/video/scripts/infer/infer_stnet.sh @@ -1,2 +1,2 @@ -python infer.py --model-name="STNET" --config=./configs/stnet.txt --filelist=./data/kinetics/infer.list \ - --log-interval=10 --weights=./checkpoints/STNET_epoch0 --save-dir=./save +python infer.py --model_name="STNET" --config=./configs/stnet.txt --filelist=./data/kinetics/infer.list \ + --log_interval=10 --weights=./checkpoints/STNET_epoch0 --save_dir=./save diff --git a/PaddleCV/video/scripts/infer/infer_tsn.sh b/PaddleCV/video/scripts/infer/infer_tsn.sh index 515feaf4..d6009b6c 100644 --- a/PaddleCV/video/scripts/infer/infer_tsn.sh +++ b/PaddleCV/video/scripts/infer/infer_tsn.sh @@ -1,2 +1,2 @@ -python infer.py --model-name="TSN" --config=./configs/tsn.txt --filelist=./data/kinetics/infer.list \ - --log-interval=10 --weights=./checkpoints/TSN_epoch0 --save-dir=./save +python infer.py --model_name="TSN" --config=./configs/tsn.txt --filelist=./data/kinetics/infer.list \ + --log_interval=10 --weights=./checkpoints/TSN_epoch0 --save_dir=./save diff --git a/PaddleCV/video/scripts/test/test_attention_cluster.sh b/PaddleCV/video/scripts/test/test_attention_cluster.sh index 21df1319..1bdc5acf 100644 --- a/PaddleCV/video/scripts/test/test_attention_cluster.sh +++ b/PaddleCV/video/scripts/test/test_attention_cluster.sh @@ -1,2 +1,2 @@ -python test.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt \ - --log-interval=5 --weights=./checkpoints/AttentionCluster_epoch0 +python test.py --model_name="AttentionCluster" --config=./configs/attention_cluster.txt \ + --log_interval=5 --weights=./checkpoints/AttentionCluster_epoch0 diff --git a/PaddleCV/video/scripts/test/test_attention_lstm.sh b/PaddleCV/video/scripts/test/test_attention_lstm.sh index d728dbd1..27bff350 100644 --- a/PaddleCV/video/scripts/test/test_attention_lstm.sh +++ b/PaddleCV/video/scripts/test/test_attention_lstm.sh @@ -1,2 +1,2 @@ -python test.py --model-name="AttentionLSTM" --config=./configs/attention_lstm.txt \ - --log-interval=5 --weights=./checkpoints/AttentionLSTM_epoch0 +python test.py --model_name="AttentionLSTM" --config=./configs/attention_lstm.txt \ + --log_interval=5 --weights=./checkpoints/AttentionLSTM_epoch0 diff --git a/PaddleCV/video/scripts/test/test_nextvlad.sh b/PaddleCV/video/scripts/test/test_nextvlad.sh index 239e9980..4d390a0b 100644 --- a/PaddleCV/video/scripts/test/test_nextvlad.sh +++ b/PaddleCV/video/scripts/test/test_nextvlad.sh @@ -1,2 +1,2 @@ -python test.py --model-name="NEXTVLAD" --config=./configs/nextvlad.txt \ - --log-interval=10 --weights=./checkpoints/NEXTVLAD_epoch0 +python test.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt \ + --log_interval=10 --weights=./checkpoints/NEXTVLAD_epoch0 diff --git a/PaddleCV/video/scripts/test/test_nonlocal.sh b/PaddleCV/video/scripts/test/test_nonlocal.sh new file mode 100644 index 00000000..7a42bb05 --- /dev/null +++ b/PaddleCV/video/scripts/test/test_nonlocal.sh @@ -0,0 +1,2 @@ +python -i test.py --model_name="NONLOCAL" --config=./configs/nonlocal.txt \ + --log_interval=1 --weights=./checkpoints/NONLOCAL_epoch0 diff --git a/PaddleCV/video/scripts/test/test_stnet.sh b/PaddleCV/video/scripts/test/test_stnet.sh index 6913ea69..0b471ed9 100644 --- a/PaddleCV/video/scripts/test/test_stnet.sh +++ b/PaddleCV/video/scripts/test/test_stnet.sh @@ -1,2 +1,2 @@ -python test.py --model-name="STNET" --config=./configs/stnet.txt \ - --log-interval=10 --weights=./checkpoints/STNET_epoch0 +python test.py --model_name="STNET" --config=./configs/stnet.txt \ + --log_interval=10 --weights=./checkpoints/STNET_epoch0 diff --git a/PaddleCV/video/scripts/test/test_tsn.sh b/PaddleCV/video/scripts/test/test_tsn.sh index b66bcb2c..ffe0ff51 100644 --- a/PaddleCV/video/scripts/test/test_tsn.sh +++ b/PaddleCV/video/scripts/test/test_tsn.sh @@ -1,2 +1,2 @@ -python test.py --model-name="TSN" --config=./configs/tsn.txt \ - --log-interval=10 --weights=./checkpoints/TSN_epoch0 +python test.py --model_name="TSN" --config=./configs/tsn.txt \ + --log_interval=10 --weights=./checkpoints/TSN_epoch0 diff --git a/PaddleCV/video/scripts/train/train_attention_cluster.sh b/PaddleCV/video/scripts/train/train_attention_cluster.sh index 0a0b0bbb..41a555da 100644 --- a/PaddleCV/video/scripts/train/train_attention_cluster.sh +++ b/PaddleCV/video/scripts/train/train_attention_cluster.sh @@ -1,2 +1,2 @@ -python train.py --model-name="AttentionCluster" --config=./configs/attention_cluster.txt --epoch-num=5 \ - --valid-interval=1 --log-interval=10 +python train.py --model_name="AttentionCluster" --config=./configs/attention_cluster.txt --epoch_num=5 \ + --valid_interval=1 --log_interval=10 diff --git a/PaddleCV/video/scripts/train/train_attention_lstm.sh b/PaddleCV/video/scripts/train/train_attention_lstm.sh index bb855b19..32e80a27 100644 --- a/PaddleCV/video/scripts/train/train_attention_lstm.sh +++ b/PaddleCV/video/scripts/train/train_attention_lstm.sh @@ -1,2 +1,2 @@ -python train.py --model-name="AttentionLSTM" --config=./configs/attention_lstm.txt --epoch-num=10 \ - --valid-interval=1 --log-interval=10 +python train.py --model_name="AttentionLSTM" --config=./configs/attention_lstm.txt --epoch_num=10 \ + --valid_interval=1 --log_interval=10 diff --git a/PaddleCV/video/scripts/train/train_nextvlad.sh b/PaddleCV/video/scripts/train/train_nextvlad.sh index b5857e9f..dd8e5c57 100644 --- a/PaddleCV/video/scripts/train/train_nextvlad.sh +++ b/PaddleCV/video/scripts/train/train_nextvlad.sh @@ -1,3 +1,3 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3 -python train.py --model-name="NEXTVLAD" --config=./configs/nextvlad.txt --epoch-num=6 \ - --valid-interval=1 --log-interval=10 +python train.py --model_name="NEXTVLAD" --config=./configs/nextvlad.txt --epoch_num=6 \ + --valid_interval=1 --log_interval=10 diff --git a/PaddleCV/video/scripts/train/train_nonlocal.sh b/PaddleCV/video/scripts/train/train_nonlocal.sh new file mode 100644 index 00000000..67b8efb1 --- /dev/null +++ b/PaddleCV/video/scripts/train/train_nonlocal.sh @@ -0,0 +1,3 @@ +python train.py --model_name="NONLOCAL" --config=./configs/nonlocal.txt --epoch_num=120 \ + --valid_interval=1 --log_interval=1 \ + --pretrain=./pretrained/ResNet50_pretrained diff --git a/PaddleCV/video/scripts/train/train_stnet.sh b/PaddleCV/video/scripts/train/train_stnet.sh index c595c10c..fed3eb39 100644 --- a/PaddleCV/video/scripts/train/train_stnet.sh +++ b/PaddleCV/video/scripts/train/train_stnet.sh @@ -1,2 +1,2 @@ -python train.py --model-name="STNET" --config=./configs/stnet.txt --epoch-num=60 \ - --valid-interval=1 --log-interval=10 +python train.py --model_name="STNET" --config=./configs/stnet.txt --epoch_num=60 \ + --valid_interval=1 --log_interval=10 diff --git a/PaddleCV/video/scripts/train/train_tsn.sh b/PaddleCV/video/scripts/train/train_tsn.sh index e476744d..c81c3f85 100644 --- a/PaddleCV/video/scripts/train/train_tsn.sh +++ b/PaddleCV/video/scripts/train/train_tsn.sh @@ -1,2 +1,2 @@ -python train.py --model-name="TSN" --config=./configs/tsn.txt --epoch-num=45 \ - --valid-interval=1 --log-interval=10 +python train.py --model_name="TSN" --config=./configs/tsn.txt --epoch_num=45 \ + --valid_interval=1 --log_interval=10 diff --git a/PaddleCV/video/test.py b/PaddleCV/video/test.py old mode 100755 new mode 100644 index 9698caec..a08fd908 --- a/PaddleCV/video/test.py +++ b/PaddleCV/video/test.py @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( - '--model-name', + '--model_name', type=str, default='AttentionCluster', help='name of model to train.') @@ -44,19 +44,19 @@ def parse_args(): default='configs/attention_cluster.txt', help='path to config file of model') parser.add_argument( - '--batch-size', + '--batch_size', type=int, default=None, help='traing batch size per GPU. None to use config file setting.') parser.add_argument( - '--use-gpu', type=bool, default=True, help='default use gpu.') + '--use_gpu', type=bool, default=True, help='default use gpu.') parser.add_argument( '--weights', type=str, default=None, help='weight path, None to use weights from Paddle.') parser.add_argument( - '--log-interval', + '--log_interval', type=int, default=1, help='mini-batch interval to log.') @@ -75,7 +75,7 @@ def test(args): test_model.build_model() test_feeds = test_model.feeds() test_outputs = test_model.outputs() - loss = test_model.loss() + test_loss = test_model.loss() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) @@ -85,29 +85,34 @@ def test(args): args.weights), "Given weight dir {} not exist.".format(args.weights) weights = args.weights or test_model.get_weights() - def if_exist(var): - return os.path.exists(os.path.join(weights, var.name)) - - fluid.io.load_vars(exe, weights, predicate=if_exist) + test_model.load_test_weights(exe, weights, + fluid.default_main_program(), place) # get reader and metrics test_reader = get_reader(args.model_name.upper(), 'test', test_config) test_metrics = get_metrics(args.model_name.upper(), 'test', test_config) test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds) - fetch_list = [loss.name] + [x.name - for x in test_outputs] + [test_feeds[-1].name] + if test_loss is None: + fetch_list = [x.name for x in test_outputs] + [test_feeds[-1].name] + else: + fetch_list = [test_loss.name] + [x.name for x in test_outputs + ] + [test_feeds[-1].name] epoch_period = [] for test_iter, data in enumerate(test_reader()): cur_time = time.time() - test_outs = exe.run(fetch_list=fetch_list, - feed=test_feeder.feed(data)) + test_outs = exe.run(fetch_list=fetch_list, feed=test_feeder.feed(data)) period = time.time() - cur_time epoch_period.append(period) - loss = np.array(test_outs[0]) - pred = np.array(test_outs[1]) - label = np.array(test_outs[-1]) + if test_loss is None: + loss = np.zeros(1, ).astype('float32') + pred = np.array(test_outs[0]) + label = np.array(test_outs[-1]) + else: + loss = np.array(test_outs[0]) + pred = np.array(test_outs[1]) + label = np.array(test_outs[-1]) test_metrics.accumulate(loss, pred, label) # metric here diff --git a/PaddleCV/video/train.py b/PaddleCV/video/train.py old mode 100755 new mode 100644 index 154c51ed..066ba58d --- a/PaddleCV/video/train.py +++ b/PaddleCV/video/train.py @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) def parse_args(): parser = argparse.ArgumentParser("Paddle Video train script") parser.add_argument( - '--model-name', + '--model_name', type=str, default='AttentionCluster', help='name of model to train.') @@ -45,12 +45,12 @@ def parse_args(): default='configs/attention_cluster.txt', help='path to config file of model') parser.add_argument( - '--batch-size', + '--batch_size', type=int, default=None, help='training batch size. None to use config file setting.') parser.add_argument( - '--learning-rate', + '--learning_rate', type=float, default=None, help='learning rate use for training. None to use config file setting.') @@ -65,37 +65,36 @@ def parse_args(): type=str, default=None, help='path to resume training based on previous checkpoints. ' - 'None for not resuming any checkpoints.' - ) + 'None for not resuming any checkpoints.') parser.add_argument( - '--use-gpu', type=bool, default=True, help='default use gpu.') + '--use_gpu', type=bool, default=True, help='default use gpu.') parser.add_argument( - '--no-use-pyreader', + '--no_use_pyreader', action='store_true', default=False, help='whether to use pyreader') parser.add_argument( - '--no-memory-optimize', + '--no_memory_optimize', action='store_true', default=False, help='whether to use memory optimize in train') parser.add_argument( - '--epoch-num', + '--epoch_num', type=int, default=0, help='epoch number, 0 for read from config file') parser.add_argument( - '--valid-interval', + '--valid_interval', type=int, default=1, help='validation epoch interval, 0 for no validation.') parser.add_argument( - '--save-dir', + '--save_dir', type=str, default='checkpoints', help='directory name to save train snapshoot') parser.add_argument( - '--log-interval', + '--log_interval', type=int, default=10, help='mini-batch interval to log.') @@ -108,6 +107,8 @@ def train(args): config = parse_config(args.config) train_config = merge_configs(config, 'train', vars(args)) valid_config = merge_configs(config, 'valid', vars(args)) + logger.info("############### train config ###############") + print_configs(train_config) train_model = models.get_model(args.model_name, train_config, mode='train') valid_model = models.get_model(args.model_name, valid_config, mode='valid') @@ -153,9 +154,12 @@ def train(args): # if resume weights is given, load resume weights directly assert os.path.exists(args.resume), \ "Given resume weight dir {} not exist.".format(args.resume) + def if_exist(var): return os.path.exists(os.path.join(args.resume, var.name)) - fluid.io.load_vars(exe, args.resume, predicate=if_exist, main_program=train_prog) + + fluid.io.load_vars( + exe, args.resume, predicate=if_exist, main_program=train_prog) else: # if not in resume mode, load pretrain weights if args.pretrain: @@ -199,21 +203,43 @@ def train(args): if args.no_use_pyreader: train_feeder = fluid.DataFeeder(place=place, feed_list=train_feeds) valid_feeder = fluid.DataFeeder(place=place, feed_list=valid_feeds) - train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feeder, - train_fetch_list, train_metrics, epochs = epochs, - log_interval = args.log_interval, valid_interval = args.valid_interval, - save_dir = args.save_dir, save_model_name = args.model_name, - test_exe = valid_exe, test_reader = valid_reader, test_feeder = valid_feeder, - test_fetch_list = valid_fetch_list, test_metrics = valid_metrics) + train_without_pyreader( + exe, + train_prog, + train_exe, + train_reader, + train_feeder, + train_fetch_list, + train_metrics, + epochs=epochs, + log_interval=args.log_interval, + valid_interval=args.valid_interval, + save_dir=args.save_dir, + save_model_name=args.model_name, + test_exe=valid_exe, + test_reader=valid_reader, + test_feeder=valid_feeder, + test_fetch_list=valid_fetch_list, + test_metrics=valid_metrics) else: train_pyreader.decorate_paddle_reader(train_reader) valid_pyreader.decorate_paddle_reader(valid_reader) - train_with_pyreader(exe, train_prog, train_exe, train_pyreader, train_fetch_list, train_metrics, - epochs = epochs, log_interval = args.log_interval, - valid_interval = args.valid_interval, - save_dir = args.save_dir, save_model_name = args.model_name, - test_exe = valid_exe, test_pyreader = valid_pyreader, - test_fetch_list = valid_fetch_list, test_metrics = valid_metrics) + train_with_pyreader( + exe, + train_prog, + train_exe, + train_pyreader, + train_fetch_list, + train_metrics, + epochs=epochs, + log_interval=args.log_interval, + valid_interval=args.valid_interval, + save_dir=args.save_dir, + save_model_name=args.model_name, + test_exe=valid_exe, + test_pyreader=valid_pyreader, + test_fetch_list=valid_fetch_list, + test_metrics=valid_metrics) if __name__ == "__main__": diff --git a/PaddleCV/video/utils.py b/PaddleCV/video/utils.py old mode 100755 new mode 100644 index 3b07d606..11681ca9 --- a/PaddleCV/video/utils.py +++ b/PaddleCV/video/utils.py @@ -14,6 +14,7 @@ __all__ = ['AttrDict'] + class AttrDict(dict): def __getattr__(self, key): return self[key] -- GitLab