From e12544e2c9c8c3a00b86b94492ccf60cc048859c Mon Sep 17 00:00:00 2001 From: SunGaofeng Date: Tue, 25 Jun 2019 01:51:42 +0800 Subject: [PATCH] Add ctcn model for action detection (#2529) * Add ctcn model for action detection * remove data list from codebase --- PaddleCV/video/configs/ctcn.txt | 53 ++ PaddleCV/video/datareader/__init__.py | 2 + PaddleCV/video/datareader/ctcn_reader.py | 456 ++++++++++++++++++ PaddleCV/video/metrics/detections/__init__.py | 0 .../metrics/detections/detection_metrics.py | 133 +++++ PaddleCV/video/metrics/metrics_util.py | 39 ++ PaddleCV/video/models/__init__.py | 2 + PaddleCV/video/models/ctcn/__init__.py | 1 + PaddleCV/video/models/ctcn/ctcn.py | 179 +++++++ PaddleCV/video/models/ctcn/ctcn_utils.py | 331 +++++++++++++ PaddleCV/video/models/ctcn/fpn_ctcn.py | 322 +++++++++++++ PaddleCV/video/models/model.py | 3 - PaddleCV/video/scripts/test/test_ctcn.sh | 2 + PaddleCV/video/scripts/train/train_ctcn.sh | 9 + PaddleCV/video/tools/train_utils.py | 83 +++- PaddleCV/video/train.py | 37 +- 16 files changed, 1625 insertions(+), 27 deletions(-) create mode 100644 PaddleCV/video/configs/ctcn.txt create mode 100644 PaddleCV/video/datareader/ctcn_reader.py create mode 100644 PaddleCV/video/metrics/detections/__init__.py create mode 100644 PaddleCV/video/metrics/detections/detection_metrics.py create mode 100644 PaddleCV/video/models/ctcn/__init__.py create mode 100644 PaddleCV/video/models/ctcn/ctcn.py create mode 100644 PaddleCV/video/models/ctcn/ctcn_utils.py create mode 100644 PaddleCV/video/models/ctcn/fpn_ctcn.py create mode 100644 PaddleCV/video/scripts/test/test_ctcn.sh create mode 100644 PaddleCV/video/scripts/train/train_ctcn.sh diff --git a/PaddleCV/video/configs/ctcn.txt b/PaddleCV/video/configs/ctcn.txt new file mode 100644 index 00000000..26110c8f --- /dev/null +++ b/PaddleCV/video/configs/ctcn.txt @@ -0,0 +1,53 @@ +[MODEL] +name = "CTCN" +num_classes = 201 +img_size = 512 +concept_size = 402 +num_anchors = 7 +total_num_anchors = 1785 +snippet_length = 1 +root = '/ssd3/huangjun/Paddle/feats' + +[TRAIN] +epoch = 35 +filelist = 'dataset/ctcn/Activity1.3_train_rgb.listformat' +rgb = 'senet152-201cls-rgb-70.3-5seg-331data_331img_train' +flow = 'senet152-201cls-flow-60.9-5seg-331data_train' +batch_size = 16 +num_threads = 8 +use_gpu = True +num_gpus = 8 +learning_rate = 0.0005 +learning_rate_decay = 0.1 +lr_decay_iter = 9000 +l2_weight_decay = 1e-4 +momentum = 0.9 + +[VALID] +filelist = 'dataset/ctcn/Activity1.3_val_rgb.listformat' +rgb = 'senet152-201cls-rgb-70.3-5seg-331data_331img_val' +flow = 'senet152-201cls-flow-60.9-5seg-331data_val' +batch_size = 16 +num_threads = 8 +use_gpu = True +num_gpus = 8 + +[TEST] +filelist = 'dataset/ctcn/Activity1.3_val_rgb.listformat' +rgb = 'senet152-201cls-rgb-70.3-5seg-331data_331img_val' +flow = 'senet152-201cls-flow-60.9-5seg-331data_val' +class_label_file = 'dataset/ctcn/test_val_label.list' +video_duration_file = 'dataset/ctcn/val_duration_frame.list' +batch_size = 1 +num_threads = 1 +score_thresh = 0.001 +nms_thresh = 0.08 +sigma_thresh = 0.006 +soft_thresh = 0.006 + +[INFER] +filelist = 'dataset/ctcn/Activity1.3_val_rgb.listformat' +rgb = 'senet152-201cls-rgb-70.3-5seg-331data_331img_val' +flow = 'senet152-201cls-flow-60.9-5seg-331data_val' +batch_size = 1 +num_threads = 1 diff --git a/PaddleCV/video/datareader/__init__.py b/PaddleCV/video/datareader/__init__.py index ee898672..c5c15938 100644 --- a/PaddleCV/video/datareader/__init__.py +++ b/PaddleCV/video/datareader/__init__.py @@ -2,6 +2,7 @@ from .reader_utils import regist_reader, get_reader from .feature_reader import FeatureReader from .kinetics_reader import KineticsReader from .nonlocal_reader import NonlocalReader +from .ctcn_reader import CTCNReader # regist reader, sort by alphabet regist_reader("ATTENTIONCLUSTER", FeatureReader) @@ -11,3 +12,4 @@ regist_reader("NONLOCAL", NonlocalReader) regist_reader("TSM", KineticsReader) regist_reader("TSN", KineticsReader) regist_reader("STNET", KineticsReader) +regist_reader("CTCN", CTCNReader) diff --git a/PaddleCV/video/datareader/ctcn_reader.py b/PaddleCV/video/datareader/ctcn_reader.py new file mode 100644 index 00000000..4b4b8afd --- /dev/null +++ b/PaddleCV/video/datareader/ctcn_reader.py @@ -0,0 +1,456 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import os +import random +import cv2 +import sys +import numpy as np +import gc +import copy +import multiprocessing + +import logging +logger = logging.getLogger(__name__) + +try: + import cPickle as pickle + from cStringIO import StringIO +except ImportError: + import pickle + from io import BytesIO + +from .reader_utils import DataReader +from models.ctcn.ctcn_utils import box_clamp1D, box_iou1D, BoxCoder + +python_ver = sys.version_info + +#random.seed(0) +#np.random.seed(0) + + +class CTCNReader(DataReader): + """ + Data reader for C-TCN model, which was stored as features extracted by prior networks + dataset cfg: img_size, the temporal dimension size of input data + root, the root dir of data + snippet_length, snippet length when sampling + filelist, the file list storing id and annotations of each data item + rgb, the dir of rgb data + flow, the dir of optical flow data + batch_size, batch size of input data + num_threads, number of threads of data processing + + """ + + def __init__(self, name, mode, cfg): + self.name = name + self.mode = mode + self.img_size = cfg.MODEL.img_size # 512 + self.snippet_length = cfg.MODEL.snippet_length # 1 + self.root = cfg.MODEL.root # root dir of data + self.filelist = cfg[mode.upper()]['filelist'] + self.rgb = cfg[mode.upper()]['rgb'] + self.flow = cfg[mode.upper()]['flow'] + self.batch_size = cfg[mode.upper()]['batch_size'] + self.num_threads = cfg[mode.upper()]['num_threads'] + if (mode == 'test') or (mode == 'infer'): + self.num_threads = 1 # set num_threads as 1 for test and infer + + def random_move(self, img, o_boxes, labels): + boxes = np.array(o_boxes) + mask = np.zeros(img.shape[0]) + for i in boxes: + for j in range(i[0].astype('int'), + min(i[1].astype('int'), img.shape[0])): + mask[j] = 1 + mask = (mask == 0) + bg = img[mask] + bg_len = bg.shape[0] + if bg_len < 5: + return img, boxes, labels + insert_place = random.sample(range(bg_len), len(boxes)) + index = np.argsort(insert_place) + new_img = bg[0:insert_place[index[0]], :] + new_boxes = [] + new_labels = [] + + for i in range(boxes.shape[0]): + new_boxes.append([ + new_img.shape[0], + new_img.shape[0] + boxes[index[i]][1] - boxes[index[i]][0] + ]) + new_labels.append(labels[index[i]]) + new_img = np.concatenate( + (new_img, + img[int(boxes[index[i]][0]):int(boxes[index[i]][1]), :])) + if i < boxes.shape[0] - 1: + new_img = np.concatenate( + (new_img, + bg[insert_place[index[i]]:insert_place[index[i + 1]], :])) + new_img = np.concatenate( + (new_img, bg[insert_place[index[len(boxes) - 1]]:, :])) + del img, boxes, mask, bg, labels + gc.collect() + return new_img, new_boxes, new_labels + + def random_crop(self, img, boxes, labels, min_scale=0.3): + boxes = np.array(boxes) + labels = np.array(labels) + imh, imw = img.shape[:2] + params = [(0, imh)] + for min_iou in (0, 0.1, 0.3, 0.5, 0.7, 0.9): + for _ in range(100): + scale = random.uniform(0.3, 1) + h = int(imh * scale) + + y = random.randrange(imh - h) + roi = [[y, y + h]] + ious = box_iou1D(boxes, roi) + if ious.min() >= min_iou: + params.append((y, h)) + break + y, h = random.choice(params) + img = img[y:y + h, :] + center = (boxes[:, 0] + boxes[:, 1]) / 2 + mask = (center[:] >= y) & (center[:] <= y + h) + if mask.any(): + boxes = boxes[np.squeeze(mask.nonzero())] - np.array([[y, y]]) + boxes = box_clamp1D(boxes, 0, h) + labels = labels[mask] + else: + boxes = [[0, 0]] + labels = [0] + return img, boxes, labels + + def resize(self, img, boxes, size, random_interpolation=False): + '''Resize the input PIL image to given size. + + If boxes is not None, resize boxes accordingly. + + Args: + img: image to be resized. + boxes: (tensor) object boxes, sized [#obj,2]. + size: (tuple or int) + - if is tuple, resize image to the size. + - if is int, resize the shorter side to the size while maintaining the aspect ratio. + random_interpolation: (bool) randomly choose a resize interpolation method. + + Returns: + img: (cv2's numpy.ndarray) resized image. + boxes: (tensor) resized boxes. + + Example: + >> img, boxes = resize(img, boxes, 600) # resize shorter side to 600 + ''' + h, w = img.shape[:2] + if h == size: + return img, boxes + if h == 0: + img = np.zeros((512, 402), np.float32) + return img, boxes + + ow = w + oh = size + sw = 1 + sh = float(oh) / h + method = random.choice([ + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA + ]) if random_interpolation else cv2.INTER_NEAREST + img = cv2.resize(img, (ow, oh), interpolation=method) + if boxes is not None: + boxes = boxes * np.array([sh, sh]) + return img, boxes + + def transform(self, feats, boxes, labels, mode): + feats = np.array(feats) + boxes = np.array(boxes) + labels = np.array(labels) + #print('name {}, labels {}'.format(fname, labels)) + + if mode == 'train': + feats, boxes, labels = self.random_move(feats, boxes, labels) + feats, boxes, labels = self.random_crop(feats, boxes, labels) + feats, boxes = self.resize( + feats, boxes, size=self.img_size, random_interpolation=True) + h, w = feats.shape[:2] + img = feats.reshape(1, h, w) + Coder = BoxCoder() + boxes, labels = Coder.encode(boxes, labels) + if mode == 'test' or mode == 'valid': + feats, boxes = self.resize(feats, boxes, size=self.img_size) + h, w = feats.shape[:2] + img = feats.reshape(1, h, w) + Coder = BoxCoder() + boxes, labels = Coder.encode(boxes, labels) + return img, boxes, labels + + def create_reader(self): + """reader creator for ctcn model""" + if self.num_threads == 1: + return self.make_reader() + else: + return self.make_multiprocess_reader() + + def make_reader(self): + """single process reader""" + + def reader(): + with open(self.filelist) as f: + reader_list = f.readlines() + if self.mode == 'train': + random.shuffle(reader_list) + fnames = [] + total_boxes = [] + total_labels = [] + total_label_ids = [] + for i in range(len(reader_list)): + line = reader_list[i] + splited = line.strip().split() + rgb_exist = os.path.exists( + os.path.join(self.root, self.rgb, splited[0] + '.pkl')) + flow_exist = os.path.exists( + os.path.join(self.root, self.flow, splited[0] + '.pkl')) + if not (rgb_exist and flow_exist): + print('file not exist', splited[0]) + continue + fnames.append(splited[0]) + frames_num = int(splited[1]) // self.snippet_length + num_boxes = int(splited[2]) + box = [] + label = [] + for i in range(num_boxes): + c = splited[3 + 3 * i] + xmin = splited[4 + 3 * i] + xmax = splited[5 + 3 * i] + box.append([ + float(xmin) / self.snippet_length, + float(xmax) / self.snippet_length + ]) + label.append(int(c)) + total_label_ids.append(i) + total_boxes.append(box) + total_labels.append(label) + num_videos = len(fnames) + batch_out = [] + for idx in range(num_videos): + fname = fnames[idx] + try: + if python_ver < (3, 0): + rgb_pkl = pickle.load( + open( + os.path.join(self.root, self.rgb, fname + + '.pkl'))) + flow_pkl = pickle.load( + open( + os.path.join(self.root, self.flow, fname + + '.pkl'))) + else: + rgb_pkl = pickle.load( + open( + os.path.join(self.root, self.rgb, fname + + '.pkl')), + encoding='bytes') + flow_pkl = pickle.load( + open( + os.path.join(self.root, self.flow, fname + + '.pkl')), + encoding='bytes') + + data_flow = np.array(flow_pkl['scores']) + data_rgb = np.array(rgb_pkl['scores']) + + if data_flow.shape[0] < data_rgb.shape[0]: + data_rgb = data_rgb[0:data_flow.shape[0], :] + elif data_flow.shape[0] > data_rgb.shape[0]: + data_flow = data_flow[0:data_rgb.shape[0], :] + + feats = np.concatenate((data_rgb, data_flow), axis=1) + if feats.shape[0] == 0 or feats.shape[1] == 0: + feats = np.zeros((512, 1024), np.float32) + logger.info('### file loading len = 0 {} ###'.format( + fname)) + + boxes = copy.deepcopy(total_boxes[idx]) + labels = copy.deepcopy(total_labels[idx]) + + feats, boxes, labels = self.transform(feats, boxes, labels, + self.mode) + labels = labels.astype('int64') + boxes = boxes.astype('float32') + num_pos = len(np.where(labels > 0)[0]) + except: + logger.info('Error when loading {}'.format(fname)) + continue + if (num_pos < 1) and (self.mode == 'train' or + self.mode == 'valid'): + #logger.info('=== no pos for ==='.format(fname, num_pos)) + continue + if self.mode == 'train' or self.mode == 'valid': + batch_out.append((feats, boxes, labels)) + elif self.mode == 'test': + batch_out.append( + (feats, boxes, labels, total_label_ids[idx])) + else: + raise NotImplementedError('mode {} not implemented'.format( + self.mode)) + if len(batch_out) == self.batch_size: + yield batch_out + batch_out = [] + + return reader + + def make_multiprocess_reader(self): + """multiprocess reader""" + + def read_into_queue(reader_list, queue): + fnames = [] + total_boxes = [] + total_labels = [] + total_label_ids = [] + #for line in reader_list: + for i in range(len(reader_list)): + line = reader_list[i] + splited = line.strip().split() + rgb_exist = os.path.exists( + os.path.join(self.root, self.rgb, splited[0] + '.pkl')) + flow_exist = os.path.exists( + os.path.join(self.root, self.flow, splited[0] + '.pkl')) + if not (rgb_exist and flow_exist): + logger.info('file not exist {}'.format(splited[0])) + continue + fnames.append(splited[0]) + frames_num = int(splited[1]) // self.snippet_length + num_boxes = int(splited[2]) + box = [] + label = [] + for i in range(num_boxes): + c = splited[3 + 3 * i] + xmin = splited[4 + 3 * i] + xmax = splited[5 + 3 * i] + box.append([ + float(xmin) / self.snippet_length, + float(xmax) / self.snippet_length + ]) + label.append(int(c)) + total_label_ids.append(i) + total_boxes.append(box) + total_labels.append(label) + num_videos = len(fnames) + batch_out = [] + for idx in range(num_videos): + fname = fnames[idx] + try: + if python_ver < (3, 0): + rgb_pkl = pickle.load( + open( + os.path.join(self.root, self.rgb, fname + + '.pkl'))) + flow_pkl = pickle.load( + open( + os.path.join(self.root, self.flow, fname + + '.pkl'))) + else: + rgb_pkl = pickle.load( + open( + os.path.join(self.root, self.rgb, fname + + '.pkl')), + encoding='bytes') + flow_pkl = pickle.load( + open( + os.path.join(self.root, self.flow, fname + + '.pkl')), + encoding='bytes') + + data_flow = np.array(flow_pkl['scores']) + data_rgb = np.array(rgb_pkl['scores']) + + if data_flow.shape[0] < data_rgb.shape[0]: + data_rgb = data_rgb[0:data_flow.shape[0], :] + elif data_flow.shape[0] > data_rgb.shape[0]: + data_flow = data_flow[0:data_rgb.shape[0], :] + + feats = np.concatenate((data_rgb, data_flow), axis=1) + if feats.shape[0] == 0 or feats.shape[1] == 0: + feats = np.zeros((512, 1024), np.float32) + logger.info('### file loading len = 0 {} ###'.format( + fname)) + + boxes = copy.deepcopy(total_boxes[idx]) + labels = copy.deepcopy(total_labels[idx]) + + feats, boxes, labels = self.transform(feats, boxes, labels, + self.mode) + labels = labels.astype('int64') + boxes = boxes.astype('float32') + num_pos = len(np.where(labels > 0)[0]) + except: + logger.info('Error when loading {}'.format(fname)) + continue + if (not (num_pos >= 1)) and (self.mode == 'train' or + self.mode == 'valid'): + #logger.info('=== no pos for {}, num_pos = {} ==='.format(fname, num_pos)) + continue + + if self.mode == 'train' or self.mode == 'valid': + batch_out.append((feats, boxes, labels)) + elif self.mode == 'test': + batch_out.append( + (feats, boxes, labels, total_label_ids[idx])) + else: + raise NotImplementedError('mode {} not implemented'.format( + self.mode)) + + if len(batch_out) == self.batch_size: + queue.put(batch_out) + batch_out = [] + queue.put(None) + + def queue_reader(): + with open(self.filelist) as f: + fl = f.readlines() + if self.mode == 'train': + random.shuffle(fl) + n = self.num_threads + queue_size = 20 + reader_lists = [None] * n + file_num = int(len(fl) // n) + for i in range(n): + if i < len(reader_lists) - 1: + tmp_list = fl[i * file_num:(i + 1) * file_num] + else: + tmp_list = fl[i * file_num:] + reader_lists[i] = tmp_list + + queue = multiprocessing.Queue(queue_size) + p_list = [None] * len(reader_lists) + # for reader_list in reader_lists: + for i in range(len(reader_lists)): + reader_list = reader_lists[i] + p_list[i] = multiprocessing.Process( + target=read_into_queue, args=(reader_list, queue)) + p_list[i].start() + reader_num = len(reader_lists) + finish_num = 0 + while finish_num < reader_num: + sample = queue.get() + if sample is None: + finish_num += 1 + else: + yield sample + for i in range(len(p_list)): + if p_list[i].is_alive(): + p_list[i].join() + + return queue_reader diff --git a/PaddleCV/video/metrics/detections/__init__.py b/PaddleCV/video/metrics/detections/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/PaddleCV/video/metrics/detections/detection_metrics.py b/PaddleCV/video/metrics/detections/detection_metrics.py new file mode 100644 index 00000000..6d452f70 --- /dev/null +++ b/PaddleCV/video/metrics/detections/detection_metrics.py @@ -0,0 +1,133 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and + +import numpy as np +import datetime +import logging +import json + +from models.ctcn.ctcn_utils import BoxCoder + +logger = logging.getLogger(__name__) + + +def get_class_label(class_label_file): + class_label = open(class_label_file, 'r').readlines() + return class_label + + +def get_video_time_dict(video_duration_file): + video_time_dict = dict() + fps_file = open(video_duration_file, 'r').readlines() + for line in fps_file: + contents = line.split() + video_time_dict[contents[0]] = float(contents[-1]) + return video_time_dict + + +class MetricsCalculator(): + def __init__(self, + name='CTCN', + mode='train', + score_thresh=0.001, + nms_thresh=0.8, + sigma_thresh=0.8, + soft_thresh=0.006, + gt_label_file='', + class_label_file='', + video_duration_file=''): + self.name = name + self.mode = mode # 'train', 'val', 'test' + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.sigma_thresh = sigma_thresh + self.soft_thresh = soft_thresh + self.class_label_file = class_label_file + self.video_duration_file = video_duration_file + if mode == 'test': + lines = open(gt_label_file).readlines() + self.gt_labels = [item.split(' ')[0] for item in lines] + self.box_coder = BoxCoder() + else: + self.gt_labels = None + self.box_coder = None + self.reset() + + def reset(self): + logger.info('Resetting {} metrics...'.format(self.mode)) + self.aggr_loss = 0.0 + self.aggr_loc_loss = 0.0 + self.aggr_cls_loss = 0.0 + self.aggr_batch_size = 0 + if self.mode == 'test': + self.class_label = get_class_label(self.class_label_file) + self.video_time_dict = get_video_time_dict(self.video_duration_file) + self.res_detect = dict() + self.res_detect["version"] = "VERSION 1.3" + self.res_detect["external_data"] = { + "uesd": False, + "details": "none" + } + + self.results_detect = dict() + self.box_decode_params = { + 'score_thresh': self.score_thresh, + 'nms_thresh': self.nms_thresh, + 'sigma_thresh': self.sigma_thresh, + 'soft_thresh': self.soft_thresh + } + self.out_file = 'res_decode_' + str(self.score_thresh) + '_' + \ + str(self.nms_thresh) + '_' + str(self.sigma_thresh) + \ + '_' + str(self.soft_thresh) + + def accumulate(self, loss, pred, label): + cur_batch_size = loss[0].shape[0] + self.aggr_loss += np.mean(np.array(loss[0])) + self.aggr_loc_loss += np.mean(np.array(loss[1])) + self.aggr_cls_loss += np.mean(np.array(loss[2])) + self.aggr_batch_size += cur_batch_size + if self.mode == 'test': + box_preds, label_preds, score_preds = self.box_coder.decode( + pred[0].squeeze(), pred[1].squeeze(), **self.box_decode_params) + fid = label[-1] + fname = self.gt_labels[fid] + logger.info("file {}, num of box preds {}:".format(fname, + len(box_preds))) + self.results_detect[fname] = [] + for j in range(len(label_preds)): + results_detect[fname[0]].append({ + "score": score_preds[j], + "label": self.class_label[label_preds[j]].strip(), + "segment": [ + max(0, self.video_time_dict[fname] * box_preds[j][0] / + 512.0), min(self.video_time_dict[fname], + self.video_time_dict[fname] * + box_preds[j][1] / 512.0) + ] + }) + + def finalize_metrics(self): + self.avg_loss = self.aggr_loss / self.aggr_batch_size + self.avg_loc_loss = self.aggr_loc_loss / self.aggr_batch_size + self.avg_cls_loss = self.aggr_cls_loss / self.aggr_batch_size + if self.mode == 'test': + self.res_detect['results'] = self.results_detect + with open(self.out_file, 'w') as f: + json.dump(res_detect, f) + + def get_computed_metrics(self): + json_stats = {} + json_stats['avg_loss'] = self.avg_loss + json_stats['avg_loc_loss'] = self.avg_loc_loss + json_stats['avg_cls_loss'] = self.avg_cls_loss + return json_stats diff --git a/PaddleCV/video/metrics/metrics_util.py b/PaddleCV/video/metrics/metrics_util.py index d2f9c207..416016cc 100644 --- a/PaddleCV/video/metrics/metrics_util.py +++ b/PaddleCV/video/metrics/metrics_util.py @@ -23,6 +23,7 @@ import numpy as np from metrics.youtube8m import eval_util as youtube8m_metrics from metrics.kinetics import accuracy_metrics as kinetics_metrics from metrics.multicrop_test import multicrop_test_metrics as multicrop_test_metrics +from metrics.detections import detection_metrics as detection_metrics logger = logging.getLogger(__name__) @@ -160,6 +161,43 @@ class MulticropMetrics(Metrics): self.calculator.reset() +class DetectionMetrics(Metrics): + def __init__(self, name, mode, cfg): + self.name = name + self.mode = mode + args = {} + args['score_thresh'] = cfg.TEST.score_thresh + args['nms_thresh'] = cfg.TEST.nms_thresh + args['sigma_thresh'] = cfg.TEST.sigma_thresh + args['soft_thresh'] = cfg.TEST.soft_thresh + args['class_label_file'] = cfg.TEST.class_label_file + args['video_duration_file'] = cfg.TEST.video_duration_file + args['gt_label_file'] = cfg.TEST.filelist + args['mode'] = mode + args['name'] = name + self.calculator = detection_metrics.MetricsCalculator(**args) + + def calculate_and_log_out(self, loss, pred, label, info=''): + logger.info(info + + '\tLoss = {}, \tloc_loss = {}, \tcls_loss = {}'.format( + np.mean(loss[0]), np.mean(loss[1]), np.mean(loss[2]))) + + def accumulate(self, loss, pred, label): + self.calculator.accumulate(loss, pred, label) + + def finalize_and_log_out(self, info=''): + self.calculator.finalize_metrics() + metrics_dict = self.calculator.get_computed_metrics() + loss = metrics_dict['avg_loss'] + loc_loss = metrics_dict['avg_loc_loss'] + cls_loss = metrics_dict['avg_cls_loss'] + logger.info(info + '\tLoss: {},\tloc_loss: {}, \tcls_loss: {}'.format('%.6f' % loss, \ + '%.6f' % loc_loss, '%.6f' % cls_loss)) + + def reset(self): + self.calculator.reset() + + class MetricsZoo(object): def __init__(self): self.metrics_zoo = {} @@ -196,3 +234,4 @@ regist_metrics("NONLOCAL", MulticropMetrics) regist_metrics("TSM", Kinetics400Metrics) regist_metrics("TSN", Kinetics400Metrics) regist_metrics("STNET", Kinetics400Metrics) +regist_metrics("CTCN", DetectionMetrics) diff --git a/PaddleCV/video/models/__init__.py b/PaddleCV/video/models/__init__.py index 72ee303f..40276731 100644 --- a/PaddleCV/video/models/__init__.py +++ b/PaddleCV/video/models/__init__.py @@ -6,6 +6,7 @@ from .nonlocal_model import NonLocal from .tsm import TSM from .tsn import TSN from .stnet import STNET +from .ctcn import CTCN # regist models, sort by alphabet regist_model("AttentionCluster", AttentionCluster) @@ -15,3 +16,4 @@ regist_model('NONLOCAL', NonLocal) regist_model("TSM", TSM) regist_model("TSN", TSN) regist_model("STNET", STNET) +regist_model("CTCN", CTCN) diff --git a/PaddleCV/video/models/ctcn/__init__.py b/PaddleCV/video/models/ctcn/__init__.py new file mode 100644 index 00000000..9220ebe9 --- /dev/null +++ b/PaddleCV/video/models/ctcn/__init__.py @@ -0,0 +1 @@ +from .ctcn import * diff --git a/PaddleCV/video/models/ctcn/ctcn.py b/PaddleCV/video/models/ctcn/ctcn.py new file mode 100644 index 00000000..d0623f80 --- /dev/null +++ b/PaddleCV/video/models/ctcn/ctcn.py @@ -0,0 +1,179 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.fluid as fluid +from paddle.fluid import ParamAttr +import numpy as np + +from ..model import ModelBase +from . import fpn_ctcn + +import logging +logger = logging.getLogger(__name__) + +__all__ = ["CTCN"] + + +class CTCN(ModelBase): + """C-TCN model""" + + def __init__(self, name, cfg, mode='train'): + super(CTCN, self).__init__(name, cfg, mode=mode) + self.get_config() + + def get_config(self): + self.img_size = self.get_config_from_sec('MODEL', 'img_size') + self.concept_size = self.get_config_from_sec('MODEL', 'concept_size') + self.num_classes = self.get_config_from_sec('MODEL', 'num_classes') + self.num_anchors = self.get_config_from_sec('MODEL', 'num_anchors') + self.total_num_anchors = self.get_config_from_sec('MODEL', + 'total_num_anchors') + + self.num_epochs = self.get_config_from_sec('train', 'epoch') + self.base_learning_rate = self.get_config_from_sec('train', + 'learning_rate') + self.learning_rate_decay = self.get_config_from_sec( + 'train', 'learning_rate_decay') + self.l2_weight_decay = self.get_config_from_sec('train', + 'l2_weight_decay') + self.momentum = self.get_config_from_sec('train', 'momentum') + self.lr_decay_iter = self.get_config_from_sec('train', 'lr_decay_iter') + + def build_input(self, use_pyreader=True): + image_shape = [1, self.img_size, self.concept_size] + loc_shape = [self.total_num_anchors, 2] + cls_shape = [self.total_num_anchors] + fileid_shape = [1] + self.use_pyreader = use_pyreader + # set init data to None + py_reader = None + image = None + loc_targets = None + cls_targets = None + fileid = None + if use_pyreader: + assert self.mode != 'infer', \ + 'pyreader is not recommendated when infer, please set use_pyreader to be false.' + if (self.mode == 'train') or (self.mode == 'valid'): + py_reader = fluid.layers.py_reader( + capacity=100, + shapes=[[-1] + image_shape, [-1] + loc_shape, + [-1] + cls_shape], + dtypes=['float32', 'float32', 'int64'], + name='train_py_reader' + if self.is_training else 'test_py_reader', + use_double_buffer=True) + image, loc_targets, cls_targets = fluid.layers.read_file( + py_reader) + elif self.mode == 'test': + py_reader = fluid.layers.py_reader( + capacity=100, + shapes=[[-1] + image_shape, [-1] + loc_shape, [-1] + + cls_shape] + [-1, 1], + dtypes=['float32', 'float32', 'int64', 'int64'], + use_double_buffer=True) + image, loc_targets, cls_targets, fileid = fluid.layers.read_file( + pyreader) + else: + raise NotImplementedError('mode {} not implemented'.format( + self.mode)) + self.py_reader = py_reader + else: + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + if (self.mode == 'train') or (self.mode == 'valid'): + loc_targets = fluid.layers.data( + name='loc_targets', shape=loc_shape, dtype='float32') + cls_targets = fluid.layers.data( + name='cls_targets', shape=cls_shape, dtype='int64') + elif self.mode == 'test': + loc_targets = fluid.layers.data( + name='loc_targets', shape=loc_shape, dtype='float32') + cls_targets = fluid.layers.data( + name='cls_targets', shape=cls_shape, dtype='int64') + fileid = fluid.layers.data( + name='fileid', shape=fileid_shape, dtype='int64') + elif self.mode == 'infer': + fileid = fluid.layers.data( + name='fileid', shape=fileid_shape, dtype='int64') + else: + raise NotImplementedError('mode {} not implemented'.format( + self.mode)) + self.feature_input = [image] + self.cls_targets = cls_targets + self.loc_targets = loc_targets + self.fileid = fileid + + def create_model_args(self): + cfg = {} + cfg['num_anchors'] = self.num_anchors + cfg['concept_size'] = self.concept_size + cfg['num_classes'] = self.num_classes + return cfg + + def build_model(self): + cfg = self.create_model_args() + self.videomodel = fpn_ctcn.FPNCTCN( + num_anchors=cfg['num_anchors'], + concept_size=cfg['concept_size'], + num_classes=cfg['num_classes'], + mode=self.mode) + loc_preds, cls_preds = self.videomodel.net(input=self.feature_input[0]) + self.network_outputs = [loc_preds, cls_preds] + + def optimizer(self): + bd = [self.lr_decay_iter] + base_lr = self.base_learning_rate + lr_decay = self.learning_rate_decay + lr = [base_lr, base_lr * lr_decay] + l2_weight_decay = self.l2_weight_decay + momentum = self.momentum + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr), + momentum=momentum, + regularization=fluid.regularizer.L2Decay(l2_weight_decay)) + + return optimizer + + def loss(self): + assert self.mode != 'infer', "invalid loss calculationg in infer mode" + self.loss_ = self.videomodel.loss(self.network_outputs[0], + self.network_outputs[1], + self.loc_targets, self.cls_targets) + return self.loss_ + + def outputs(self): + loc_preds = self.network_outputs[0] + cls_preds = fluid.layers.softmax(self.network_outputs[1]) + return [loc_preds, cls_preds] + + def feeds(self): + if (self.mode == 'train') or (self.mode == 'valid'): + return self.feature_input + [self.loc_targets, self.cls_targets] + elif self.mode == 'test': + return self.feature_input + [ + self.loc_targets, self.cls_targets, self.fileid + ] + elif self.mode == 'infer': + return self.feature_input + [self.fileid] + else: + raise NotImplemented + + def pretrain_info(self): + return (None, None) + + def weights_info(self): + return (None, None) diff --git a/PaddleCV/video/models/ctcn/ctcn_utils.py b/PaddleCV/video/models/ctcn/ctcn_utils.py new file mode 100644 index 00000000..c34038e7 --- /dev/null +++ b/PaddleCV/video/models/ctcn/ctcn_utils.py @@ -0,0 +1,331 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and + +import numpy as np +from paddle.fluid.initializer import Uniform + + +# This file includes initializer, box encode, box decode +# initializer +def get_ctcn_conv_initializer(x, filter_size): + c_in = x.shape[1] + if isinstance(filter_size, int): + fan_in = c_in * filter_size * filter_size + else: + fan_in = c_in * filter_size[0] * filter_size[1] + std = np.sqrt(1.0 / fan_in) + return Uniform(0. - std, std) + + +#box tools +def box_clamp1D(boxes, xmin, xmax): + '''Clamp boxes. + Args: + boxes: (tensor) bounding boxes of (xmin,ymin,xmax,ymax), sized [N,2]. + xmin: (number) min value of x. + xmax: (number) max value of x. + ''' + np.clip(boxes[:, 0], xmin, xmax, out=boxes[:, 0]) + np.clip(boxes[:, 1], xmin, xmax, out=boxes[:, 1]) + return boxes + + +def box_iou1D(box1, box2): + '''Compute the intersection over union of two set of boxes. + + The box order must be (xmin, xmax). + + Args: + box1: (tensor) bounding boxes, sized [N,2]. + box2: (tensor) bounding boxes, sized [M,2]. + + Return: + (tensor) iou, sized [N,M]. + ''' + box1 = np.array(box1) + box2 = np.array(box2) + N = box1.shape[0] + M = box2.shape[0] + + left = np.maximum(box1[:, None, 0], box2[:, 0]) + right = np.minimum(box1[:, None, 1], box2[:, 1]) + inter = (right - left).clip(min=0) + area1 = np.abs(box1[:, 0] - box1[:, 1]) + area2 = np.abs(box2[:, 0] - box2[:, 1]) + iou = inter / (area1[:, None] + area2 - inter) + return iou + + +def change_box_order(boxes, order): + assert order in ['yy2yh', 'yh2yy'] + a = boxes[:, 0, None] + b = boxes[:, 1, None] + if order == 'yy2yh': + return np.concatenate(((a + b) / 2, b - a), axis=1) + return np.concatenate((a - b / 2, a + b / 2), axis=1) + + +def box_nms(bboxes, scores, threshold=0.5, mode='union'): + '''Non maximum suppression. + Args: + bboxes: (tensor) bounding boxes, sized [N,2]. + scores: (tensor) confidence scores, sized [N,]. + threshold: (float) overlap threshold. + mode: (str) 'union' or 'min'. + + Returns: + keep: (tensor) selected indices. + + Reference: + https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py + ''' + y1 = bboxes[:, 0] + y2 = bboxes[:, 1] + + areas = (y2 - y1) + order = np.argsort(-scores, axis=0) + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + if order.size == 1: + break + + yy1 = np.clip(y1[order[1:]], y1[i], None) + yy2 = np.clip(y2[order[1:]], None, y2[i]) + h = np.clip(yy2 - yy1, 0, None) + inter = h + + if mode == 'union': + ovr = inter / (areas[i] + areas[order[1:]] - inter) + elif mode == 'min': + ovr = inter / np.clip(areas[order[1:]], None, areas[i]) + else: + raise TypeError('Unknown nms mode: %s.' % mode) + + ids = (ovr <= threshold).nonzero()[0] + if ids.size == 0: + break + order = order[ids + 1] + return np.array(keep, dtype='int64') + + +def soft_nms(props, method=0, sigma=1., Nt=0.7, threshold=0.001): + ''' + param dets: dection results, 2 dims [N, 3] + param props: predicted scores + ''' + N = props.shape[0] + + for i in range(N): + maxscore = props[i, 2] + maxpos = i + + tx = props[i, 0] + ty = props[i, 1] + ts = props[i, 2] + + pos = i + 1 + while pos < N: + if maxscore < props[pos, 2]: + maxscore = props[pos, 2] + maxpos = pos + pos += 1 + + props[i, 0] = props[maxpos, 0] + props[i, 1] = props[maxpos, 1] + props[i, 2] = props[maxpos, 2] + + props[maxpos, 0] = tx + props[maxpos, 1] = ty + props[maxpos, 2] = ts + + tx = props[i, 0] + ty = props[i, 1] + ts = props[i, 2] + + pos = i + 1 + while pos < N: + x = props[pos, 0] + y = props[pos, 1] + s = props[pos, 2] + + max_begin = max(x, tx) + min_end = min(y, ty) + inter = max(0.0, min_end - max_begin) + overlap = inter / (y - x + ty - tx - inter) + + if method == 1: + if overlap > Nt: + weight = 1 - overlap + else: + weight = 1 + elif method == 2: + weight = np.exp(-(overlap**2) / sigma) + else: + if overlap > Nt: + weight = 0 + else: + weight = 1 + + props[pos, 2] = weight * props[pos, 2] + + if props[pos, 2] < threshold: + props[pos, 0] = props[N - 1, 0] + props[pos, 1] = props[N - 1, 1] + props[pos, 2] = props[N - 1, 2] + N -= 1 + pos -= 1 + + pos += 1 + keep = [i for i in range(N)] + return props[keep] + + +# box encode and decode +class BoxCoder(): + def __init__(self): + self.steps = (4, 8, 16, 32, 64, 128, 256, 512) + self.fm_sizes = (128, 64, 32, 16, 8, 4, 2, 1) + self.anchor_num = 3 + self.default_boxes = self._get_default_boxes() + + def _get_default_boxes(self): + boxes = [] + for i, fm_size in enumerate(self.fm_sizes): + for h in range(fm_size): + cy = (h + 0.5) * self.steps[i] + base_s = self.steps[i] + boxes.append((cy, base_s)) + for p in range(self.anchor_num): + s = (base_s * 4.5 / 15.0) * (1.0 + p) / self.anchor_num + boxes.append((cy, base_s - s)) + if base_s == 512: + step_s = (base_s * 4.5 / 15.0) / (2 * self.anchor_num) + boxes.append((cy, base_s - s - step_s)) + else: + boxes.append((cy, base_s + s)) + return np.array(boxes) + + def encode(self, boxes, labels): + def argmax(x): + v = x.max(0) # sort by cols, max_v, index + i = np.argmax(x, 0) + j = np.argmax(v, 0) # v.max(0)[1][0] # sort v, index + return (i[j], j) # return max index (row,col) + + labels = np.array(labels) + default_boxes = self.default_boxes + default_boxes = change_box_order(default_boxes, 'yh2yy') + + ious = box_iou1D(default_boxes, boxes) # [#anchors, #obj] + index = np.full(len(default_boxes), fill_value=-1, dtype='int64') + + masked_ious = ious.copy() + + while True: + i, j = argmax(masked_ious) + if masked_ious[i, j] < 1e-6: + break + index[i] = j + masked_ious[i, :] = 0 + masked_ious[:, j] = 0 + + mask = (index < 0) & (ious.max(1) >= 0.5) + if mask.any(): + if np.squeeze(mask.nonzero()).size > 1: + index[mask] = np.argmax(ious[np.squeeze(mask.nonzero())], 1) + + boxes = boxes[np.clip(index, a_min=0, a_max=None)] + boxes = change_box_order(boxes, 'yy2yh') + default_boxes = change_box_order(default_boxes, 'yy2yh') + + variances = (0.1, 0.2) + loc_xy = (boxes[:, 0, None] - default_boxes[:, 0, None] + ) / default_boxes[:, 1, None] / variances[0] + loc_wh = ( + boxes[:, 1, None] / default_boxes[:, 1, None] - 1.0) / variances[1] + + loc_targets = np.concatenate((loc_xy, loc_wh), axis=1) + cls_targets = labels[index.clip(0, None)] + cls_targets[index < 0] = 0 + + return loc_targets, cls_targets + + def decode(self, + loc_preds, + cls_preds, + score_thresh=0.6, + nms_thresh=0.45, + sigma_thresh=1.0, + soft_thresh=0.01): + '''Decode predicted loc/cls back to real box locations and class labels. + + Args: + loc_preds: (tensor) predicted loc, sized [8732,2]. + cls_preds: (tensor) predicted conf, sized [8732,201]. + score_thresh: (float) threshold for object confidence score. + nms_thresh: (float) threshold for box nms. + + Returns: + boxes: (tensor) bbox locations, sized [#obj,2]. + labels: (tensor) class labels, sized [#obj,]. + ''' + + variances = (0.1, 0.2) + y = loc_preds[:, 0, None] * variances[ + 0] * self.default_boxes[:, 1, None] + self.default_boxes[:, 0, None] + h = (loc_preds[:, 1, None] * variances[1] + 1.0 + ) * self.default_boxes[:, 1, None] + box_preds = np.concatenate((y - h / 2.0, y + h / 2.0), axis=1) + + boxes = [] + labels = [] + scores = [] + num_classes = cls_preds.shape[1] + max_num = -1 + max_id = -1 + for i in range(num_classes - 1): + score = cls_preds[:, i + 1] + mask = score > score_thresh + if not mask.any(): + continue + box = box_preds[mask] + score = score[mask] + if len(score) > max_num: + max_num = len(score) + max_id = i + + keep = box_nms(box, score, nms_thresh) + box = box[keep] + score = score[keep] + + now_vector = np.concatenate((box, score[:, None]), axis=1) + + res = soft_nms( + now_vector, method=2, sigma=sigma_thresh, threshold=soft_thresh) + + final_box = res[:, :2] + final_score = res[:, 2] + boxes.append(final_box) + labels.append(np.full(len(final_box), fill_value=i, dtype='int64')) + scores.append(final_score) + if len(boxes) == 0: + boxes.append(np.array([[0, 1.0]], dtype='float32')) + labels.append(np.full(1, fill_value=1, dtype='int64')) + scores.append(np.full(1, fill_value=1, dtype='float32')) + boxes = np.concatenate(boxes, 0) + labels = np.concatenate(labels, 0) + scores = np.concatenate(scores, 0) + return boxes, labels, scores diff --git a/PaddleCV/video/models/ctcn/fpn_ctcn.py b/PaddleCV/video/models/ctcn/fpn_ctcn.py new file mode 100644 index 00000000..5dcd9673 --- /dev/null +++ b/PaddleCV/video/models/ctcn/fpn_ctcn.py @@ -0,0 +1,322 @@ +#coding=UTF-8 + +import paddle.fluid as fluid +from paddle.fluid import ParamAttr +import numpy as np +from .ctcn_utils import get_ctcn_conv_initializer as get_init + +DATATYPE = 'float32' + + +class FPNCTCN(object): + def __init__(self, num_anchors, concept_size, num_classes, mode='train'): + self.num_anchors = num_anchors + self.concept_size = concept_size + self.num_classes = num_classes + self.is_training = (mode == 'train') + + def conv_bn_layer(self, + input, + ch_out, + filter_size, + stride=1, + padding=0, + act='relu'): + conv = fluid.layers.conv2d( + input=input, + num_filters=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + act=None, + param_attr=ParamAttr(initializer=get_init(input, filter_size)), + bias_attr=False) + return fluid.layers.batch_norm( + input=conv, + act=act, + is_test=(not self.is_training), ) + + def shortcut(self, input, planes, stride): + if (input.shape[1] == planes * 4) and (stride == 1): + return input + else: + return self.conv_bn_layer(input, planes * 4, 1, stride, act=None) + + def bottleneck_block(self, input, planes, stride=1): + conv0 = self.conv_bn_layer(input, planes, filter_size=1) + conv1 = self.conv_bn_layer( + conv0, planes, filter_size=(3, 1), stride=stride, padding=(1, 0)) + conv2 = self.conv_bn_layer(conv1, planes * 4, filter_size=1, act=None) + short = self.shortcut(input, planes, stride) + return fluid.layers.elementwise_add(x=short, y=conv2, act='relu') + + def layer_warp(self, input, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + for stride in strides: + input = self.bottleneck_block(input, planes, stride) + return input + + def upsample_add(self, x, y): + _, _, H, W = y.shape + upsample = fluid.layers.image_resize( + x, out_shape=[H, W], resample='BILINEAR') + return upsample + y + + def extractor(self, input): + num_blocks = [3, 4, 6, 3] + + c1 = self.conv_bn_layer( + input, ch_out=32, filter_size=(7, 1), stride=(2, 1), padding=(3, 0)) + + c1 = self.conv_bn_layer( + c1, ch_out=64, filter_size=(7, 1), stride=(2, 1), padding=(3, 0)) + + c2 = self.layer_warp(c1, 64, num_blocks[0], 1) + c3 = self.layer_warp(c2, 128, num_blocks[1], (2, 1)) + c4 = self.layer_warp(c3, 256, num_blocks[2], (2, 1)) + c5 = self.layer_warp(c4, 512, num_blocks[3], (2, 1)) + + #feature pyramid + p6 = fluid.layers.conv2d( + c5, + num_filters=512, + filter_size=(3, 1), + stride=(2, 1), + padding=(1, 0), + param_attr=ParamAttr(initializer=get_init(c5, (3, 1)))) + + p7 = fluid.layers.relu(p6) + p7 = fluid.layers.conv2d( + p7, + num_filters=512, + filter_size=(3, 1), + stride=(2, 1), + padding=(1, 0), + param_attr=ParamAttr(initializer=get_init(p7, (3, 1)))) + + p8 = fluid.layers.relu(p7) + p8 = fluid.layers.conv2d( + p8, + num_filters=512, + filter_size=(3, 1), + stride=(2, 1), + padding=(1, 0), + param_attr=ParamAttr(initializer=get_init(p8, (3, 1)))) + + p9 = fluid.layers.relu(p8) + p9 = fluid.layers.conv2d( + p9, + num_filters=512, + filter_size=(3, 1), + stride=(2, 1), + padding=(1, 0), + param_attr=ParamAttr(initializer=get_init(p9, (3, 1)))) + + #top_down + p5 = fluid.layers.conv2d( + c5, + 512, + 1, + 1, + 0, + param_attr=ParamAttr(initializer=get_init(c5, 1)), ) + + p4 = self.upsample_add( + p5, + fluid.layers.conv2d( + c4, + 512, + 1, + 1, + 0, + param_attr=ParamAttr(initializer=get_init(c4, 1)), )) + + p3 = self.upsample_add( + p4, + fluid.layers.conv2d( + c3, + 512, + 1, + 1, + 0, + param_attr=ParamAttr(initializer=get_init(c3, 1)), )) + + p2 = self.upsample_add( + p3, + fluid.layers.conv2d( + c2, + 512, + 1, + 1, + 0, + param_attr=ParamAttr(initializer=get_init(c2, 1)))) + #smooth + p4 = fluid.layers.conv2d( + p4, + num_filters=512, + filter_size=(3, 1), + stride=1, + padding=(1, 0), + param_attr=ParamAttr(initializer=get_init(p4, (3, 1))), ) + + p3 = fluid.layers.conv2d( + p3, + num_filters=512, + filter_size=(3, 1), + stride=1, + padding=(1, 0), + param_attr=ParamAttr(initializer=get_init(p3, (3, 1))), ) + + p2 = fluid.layers.conv2d( + p2, + num_filters=512, + filter_size=(3, 1), + stride=1, + padding=(1, 0), + param_attr=ParamAttr(initializer=get_init(p2, (3, 1))), ) + + return p2, p3, p4, p5, p6, p7, p8, p9 + + def net(self, input): + fm_sizes = self.concept_size # 402 + num_anchors = self.num_anchors # 7 + + loc_preds = [] + cls_preds = [] + # build fpn network + xs = self.extractor(input) + # build predict head + for i, x in enumerate(xs): + loc_pred = fluid.layers.dropout( + x, dropout_prob=0.5, is_test=(not self.is_training)) + loc_pred = fluid.layers.conv2d( + loc_pred, + num_filters=256, + filter_size=(3, 1), + stride=1, + padding=(1, 0), + param_attr=ParamAttr( + name='loc_pred_conv1_weights', + initializer=get_init(loc_pred, (3, 1))), + bias_attr=ParamAttr( + name='loc_pred_conv1_bias', )) + + loc_pred = fluid.layers.conv2d( + loc_pred, + num_filters=num_anchors * 2, + filter_size=(1, fm_sizes), + stride=1, + padding=0, + param_attr=ParamAttr( + name='loc_pred_conv2_weights', + initializer=get_init(loc_pred, (1, fm_sizes))), + bias_attr=ParamAttr( + name='loc_pred_conv2_bias', )) + + loc_pred = 10.0 * fluid.layers.sigmoid(loc_pred) - 5.0 + loc_pred = fluid.layers.transpose(loc_pred, perm=[0, 2, 3, 1]) + tmp_size1 = loc_pred.shape[1] * loc_pred.shape[2] * loc_pred.shape[ + 3] // 2 + loc_pred = fluid.layers.reshape( + x=loc_pred, shape=[loc_pred.shape[0], tmp_size1, 2]) + loc_preds.append(loc_pred) + + cls_pred = fluid.layers.dropout( + x, dropout_prob=0.5, is_test=(not self.is_training)) + cls_pred = fluid.layers.conv2d( + cls_pred, + num_filters=512, + filter_size=(3, 1), + stride=1, + padding=(1, 0), + param_attr=ParamAttr( + name='cls_pred_conv1_weights', + initializer=get_init(cls_pred, (3, 1))), + bias_attr=ParamAttr( + name='cls_pred_conv1_bias', )) + + cls_pred = fluid.layers.conv2d( + cls_pred, + num_filters=num_anchors * self.num_classes, + filter_size=(1, fm_sizes), + stride=1, + padding=0, + param_attr=ParamAttr( + name='cls_pred_conv2_weights', + initializer=get_init(cls_pred, (1, fm_sizes))), + bias_attr=ParamAttr( + name='cls_pred_conv2_bias', )) + + cls_pred = fluid.layers.transpose(cls_pred, perm=[0, 2, 3, 1]) + tmp_size2 = cls_pred.shape[1] * cls_pred.shape[2] * cls_pred.shape[ + 3] // self.num_classes + cls_pred = fluid.layers.reshape( + x=cls_pred, + shape=[cls_pred.shape[0], tmp_size2, self.num_classes]) + cls_preds.append(cls_pred) + + loc_preds = fluid.layers.concat(input=loc_preds, axis=1) + cls_preds = fluid.layers.concat(input=cls_preds, axis=1) + return loc_preds, cls_preds + + def hard_negative_mining(self, cls_loss, pos_bool): + pos = fluid.layers.cast(pos_bool, dtype=DATATYPE) + cls_loss = cls_loss * (pos - 1) + _, indices = fluid.layers.argsort(cls_loss, axis=1) + indices = fluid.layers.cast(indices, dtype=DATATYPE) + _, rank = fluid.layers.argsort(indices, axis=1) + + num_neg = 3 * fluid.layers.reduce_sum(pos, dim=1) + num_neg = fluid.layers.reshape(x=num_neg, shape=[-1, 1]) + neg = rank < num_neg + return neg + + def loss(self, loc_preds, cls_preds, loc_targets, cls_targets): + """ + param loc_targets: [N, 1785,2] + param cls_targets: [N, 1785] + """ + + loc_targets.stop_gradient = True + cls_targets.stop_gradient = True + + pos = cls_targets > 0 + pos_bool = pos + pos = fluid.layers.cast(pos, dtype=DATATYPE) + num_pos = fluid.layers.reduce_sum(pos) + pos = fluid.layers.unsqueeze(pos, axes=[2]) + mask = fluid.layers.expand(pos, expand_times=[1, 1, 2]) + mask.stop_gradient = True + + loc_loss = fluid.layers.smooth_l1( + loc_preds, loc_targets, inside_weight=mask, outside_weight=mask) + loc_loss = fluid.layers.reduce_sum(loc_loss) + + cls_loss = fluid.layers.softmax_with_cross_entropy( + logits=fluid.layers.reshape( + cls_preds, shape=[-1, self.num_classes]), + label=fluid.layers.reshape( + cls_targets, shape=[-1, 1]), + numeric_stable_mode=True) + + cls_loss = fluid.layers.reshape( + cls_loss, shape=[-1, loc_targets.shape[1]]) + not_ignore = cls_targets >= 0 + not_ignore = fluid.layers.cast(not_ignore, dtype=DATATYPE) + not_ignore.stop_gradient = True + + cls_loss = cls_loss * not_ignore + + neg = self.hard_negative_mining(cls_loss, pos_bool) + neg = fluid.layers.cast(neg, dtype='bool') + pos_bool = fluid.layers.cast(pos_bool, dtype='bool') + + selects = fluid.layers.logical_or(pos_bool, neg) + selects = fluid.layers.cast(selects, dtype=DATATYPE) + selects.stop_gradient = True + cls_loss = cls_loss * selects + cls_loss = fluid.layers.reduce_sum(cls_loss) + alpha = 2.0 + loss = (alpha * loc_loss + cls_loss) / num_pos + num_pos.stop_gradient = True + return loss, alpha * loc_loss / num_pos, cls_loss / num_pos diff --git a/PaddleCV/video/models/model.py b/PaddleCV/video/models/model.py index bf5947f4..b3f9d4f9 100644 --- a/PaddleCV/video/models/model.py +++ b/PaddleCV/video/models/model.py @@ -20,8 +20,6 @@ except: from ConfigParser import ConfigParser import paddle.fluid as fluid -from datareader import get_reader -from metrics import get_metrics from .utils import download, AttrDict WEIGHT_DIR = os.path.expanduser("~/.paddle/weights") @@ -68,7 +66,6 @@ class ModelBase(object): self.cfg = cfg self.py_reader = None - def build_model(self): "build model struct" raise NotImplementError(self, self.build_model) diff --git a/PaddleCV/video/scripts/test/test_ctcn.sh b/PaddleCV/video/scripts/test/test_ctcn.sh new file mode 100644 index 00000000..9003a142 --- /dev/null +++ b/PaddleCV/video/scripts/test/test_ctcn.sh @@ -0,0 +1,2 @@ +python test.py --model_name="CTCN" --config=./configs/ctcn.txt \ + --log_interval=10 --weights=./checkpoints/CTCN_epoch0 diff --git a/PaddleCV/video/scripts/train/train_ctcn.sh b/PaddleCV/video/scripts/train/train_ctcn.sh new file mode 100644 index 00000000..bddbe01f --- /dev/null +++ b/PaddleCV/video/scripts/train/train_ctcn.sh @@ -0,0 +1,9 @@ +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +#export CUDA_VISIBLE_DEVICES=0 + +export FLAGS_fast_eager_deletion_mode=1 +export FLAGS_eager_delete_tensor_gb=0.0 +export FLAGS_fraction_of_gpu_memory_to_use=1.0 + +python train.py --model_name="CTCN" --config=./configs/ctcn.txt --epoch=35 \ + --valid_interval=1 --log_interval=1 diff --git a/PaddleCV/video/tools/train_utils.py b/PaddleCV/video/tools/train_utils.py index f31e0077..3347de1f 100644 --- a/PaddleCV/video/tools/train_utils.py +++ b/PaddleCV/video/tools/train_utils.py @@ -37,13 +37,26 @@ def test_without_pyreader(test_exe, test_feeder, test_fetch_list, test_metrics, - log_interval=0): + log_interval=0, + save_model_name=''): test_metrics.reset() for test_iter, data in enumerate(test_reader()): test_outs = test_exe.run(test_fetch_list, feed=test_feeder.feed(data)) - loss = np.array(test_outs[0]) - pred = np.array(test_outs[1]) - label = np.array(test_outs[-1]) + if save_model_name in ['CTCN']: + # for detection + total_loss = np.array(test_outs[0]) + loc_loss = np.array(test_outs[1]) + cls_loss = np.array(test_outs[2]) + loc_preds = np.array(test_outs[3]) + cls_preds = np.array(test_outs[4]) + label = np.array(test_outs[-1]) + loss = [total_loss, loc_loss, cls_loss] + pred = [loc_preds, cls_preds] + else: + # for classification + loss = np.array(test_outs[0]) + pred = np.array(test_outs[1]) + label = np.array(test_outs[-1]) test_metrics.accumulate(loss, pred, label) if log_interval > 0 and test_iter % log_interval == 0: test_metrics.calculate_and_log_out(loss, pred, label, \ @@ -55,7 +68,8 @@ def test_with_pyreader(test_exe, test_pyreader, test_fetch_list, test_metrics, - log_interval=0): + log_interval=0, + save_model_name=''): if not test_pyreader: logger.error("[TEST] get pyreader failed.") test_pyreader.start() @@ -64,9 +78,20 @@ def test_with_pyreader(test_exe, try: while True: test_outs = test_exe.run(fetch_list=test_fetch_list) - loss = np.array(test_outs[0]) - pred = np.array(test_outs[1]) - label = np.array(test_outs[-1]) + if save_model_name in ['CTCN']: + # for detection + total_loss = np.array(test_outs[0]) + loc_loss = np.array(test_outs[1]) + cls_loss = np.array(test_outs[2]) + loc_preds = np.array(test_outs[3]) + cls_preds = np.array(test_outs[4]) + label = np.array(test_outs[-1]) + loss = [total_loss, loc_loss, cls_loss] + pred = [loc_preds, cls_preds] + else: + loss = np.array(test_outs[0]) + pred = np.array(test_outs[1]) + label = np.array(test_outs[-1]) test_metrics.accumulate(loss, pred, label) if log_interval > 0 and test_iter % log_interval == 0: test_metrics.calculate_and_log_out(loss, pred, label, \ @@ -92,9 +117,21 @@ def train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feede feed=train_feeder.feed(data)) period = time.time() - cur_time epoch_periods.append(period) - loss = np.array(train_outs[0]) - pred = np.array(train_outs[1]) - label = np.array(train_outs[-1]) + if save_model_name in ['CTCN']: + # detection model + total_loss = np.array(train_outs[0]) + loc_loss = np.array(train_outs[1]) + cls_loss = np.array(train_outs[2]) + loc_preds = np.array(train_outs[3]) + cls_preds = np.array(train_outs[4]) + label = np.array(train_outs[-1]) + loss = [total_loss, loc_loss, cls_loss] + pred = [loc_preds, cls_preds] + else: + # classification model + loss = np.array(train_outs[0]) + pred = np.array(train_outs[1]) + label = np.array(train_outs[-1]) if log_interval > 0 and (train_iter % log_interval == 0): # eval here train_metrics.calculate_and_log_out(loss, pred, label, \ @@ -107,8 +144,8 @@ def train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feede if test_exe and valid_interval > 0 and (epoch + 1 ) % valid_interval == 0: test_without_pyreader(test_exe, test_reader, test_feeder, - test_fetch_list, test_metrics, log_interval) - + test_fetch_list, test_metrics, log_interval, + save_model_name) def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \ @@ -133,9 +170,21 @@ def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \ train_outs = train_exe.run(fetch_list=train_fetch_list) period = time.time() - cur_time epoch_periods.append(period) - loss = np.array(train_outs[0]) - pred = np.array(train_outs[1]) - label = np.array(train_outs[-1]) + if save_model_name in ['CTCN']: + # for detection + total_loss = np.array(train_outs[0]) + loc_loss = np.array(train_outs[1]) + cls_loss = np.array(train_outs[2]) + loc_preds = np.array(train_outs[3]) + cls_preds = np.array(train_outs[4]) + label = np.array(train_outs[-1]) + loss = [total_loss, loc_loss, cls_loss] + pred = [loc_preds, cls_preds] + else: + # for classification + loss = np.array(train_outs[0]) + pred = np.array(train_outs[1]) + label = np.array(train_outs[-1]) if log_interval > 0 and (train_iter % log_interval == 0): # eval here train_loss = train_metrics.calculate_and_log_out(loss, pred, label, \ @@ -150,7 +199,7 @@ def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \ if test_exe and valid_interval > 0 and (epoch + 1 ) % valid_interval == 0: test_with_pyreader(test_exe, test_pyreader, test_fetch_list, - test_metrics, log_interval) + test_metrics, log_interval, save_model_name) finally: epoch_period = [] train_pyreader.reset() diff --git a/PaddleCV/video/train.py b/PaddleCV/video/train.py index efaa0edf..eb2a72e5 100644 --- a/PaddleCV/video/train.py +++ b/PaddleCV/video/train.py @@ -130,11 +130,20 @@ def train(args): train_feeds = train_model.feeds() train_feeds[-1].persistable = True # for the output of classification model, has the form [pred] + # for the output of detection model, has the form [loc_pred, cls_pred] train_outputs = train_model.outputs() for output in train_outputs: output.persistable = True - train_loss = train_model.loss() - train_loss.persistable = True + train_losses = train_model.loss() + if isinstance(train_losses, list) or isinstance(train_losses, + tuple): + # for detection model, train_losses has the form [total_loss, loc_loss, cls_loss] + train_loss = train_losses[0] + for item in train_losses: + item.persistable = True + else: + train_loss = train_losses + train_loss.persistable = True # outputs, loss, label should be fetched, so set persistable to be true optimizer = train_model.optimizer() optimizer.minimize(train_loss) @@ -146,8 +155,10 @@ def train(args): valid_model.build_input(not args.no_use_pyreader) valid_model.build_model() valid_feeds = valid_model.feeds() + # for the output of classification model, has the form [pred] + # for the output of detection model, has the form [loc_pred, cls_pred] valid_outputs = valid_model.outputs() - valid_loss = valid_model.loss() + valid_losses = valid_model.loss() valid_pyreader = valid_model.pyreader() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() @@ -175,6 +186,8 @@ def train(args): build_strategy = fluid.BuildStrategy() build_strategy.enable_inplace = True + if args.model_name in ['CTCN']: + build_strategy.enable_sequential_execution = True #build_strategy.memory_optimize = True train_exe = fluid.ParallelExecutor( @@ -202,10 +215,20 @@ def train(args): train_metrics = get_metrics(args.model_name.upper(), 'train', train_config) valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config) - train_fetch_list = [train_loss.name] + [x.name for x in train_outputs - ] + [train_feeds[-1].name] - valid_fetch_list = [valid_loss.name] + [x.name for x in valid_outputs - ] + [valid_feeds[-1].name] + if isinstance(train_losses, tuple) or isinstance(train_losses, list): + # for detection + train_fetch_list = [item.name for item in train_losses] + \ + [x.name for x in train_outputs] + [train_feeds[-1].name] + valid_fetch_list = [item.name for item in valid_losses] + \ + [x.name for x in valid_outputs] + [valid_feeds[-1].name] + else: + # for classification + train_fetch_list = [train_losses.name] + [ + x.name for x in train_outputs + ] + [train_feeds[-1].name] + valid_fetch_list = [valid_losses.name] + [ + x.name for x in valid_outputs + ] + [valid_feeds[-1].name] epochs = args.epoch or train_model.epoch_num() -- GitLab