提交 e12544e2 编写于 作者: S SunGaofeng 提交者: ruri

Add ctcn model for action detection (#2529)

* Add ctcn model for action detection

* remove data list from codebase
上级 4ccf1884
[MODEL]
name = "CTCN"
num_classes = 201
img_size = 512
concept_size = 402
num_anchors = 7
total_num_anchors = 1785
snippet_length = 1
root = '/ssd3/huangjun/Paddle/feats'
[TRAIN]
epoch = 35
filelist = 'dataset/ctcn/Activity1.3_train_rgb.listformat'
rgb = 'senet152-201cls-rgb-70.3-5seg-331data_331img_train'
flow = 'senet152-201cls-flow-60.9-5seg-331data_train'
batch_size = 16
num_threads = 8
use_gpu = True
num_gpus = 8
learning_rate = 0.0005
learning_rate_decay = 0.1
lr_decay_iter = 9000
l2_weight_decay = 1e-4
momentum = 0.9
[VALID]
filelist = 'dataset/ctcn/Activity1.3_val_rgb.listformat'
rgb = 'senet152-201cls-rgb-70.3-5seg-331data_331img_val'
flow = 'senet152-201cls-flow-60.9-5seg-331data_val'
batch_size = 16
num_threads = 8
use_gpu = True
num_gpus = 8
[TEST]
filelist = 'dataset/ctcn/Activity1.3_val_rgb.listformat'
rgb = 'senet152-201cls-rgb-70.3-5seg-331data_331img_val'
flow = 'senet152-201cls-flow-60.9-5seg-331data_val'
class_label_file = 'dataset/ctcn/test_val_label.list'
video_duration_file = 'dataset/ctcn/val_duration_frame.list'
batch_size = 1
num_threads = 1
score_thresh = 0.001
nms_thresh = 0.08
sigma_thresh = 0.006
soft_thresh = 0.006
[INFER]
filelist = 'dataset/ctcn/Activity1.3_val_rgb.listformat'
rgb = 'senet152-201cls-rgb-70.3-5seg-331data_331img_val'
flow = 'senet152-201cls-flow-60.9-5seg-331data_val'
batch_size = 1
num_threads = 1
......@@ -2,6 +2,7 @@ from .reader_utils import regist_reader, get_reader
from .feature_reader import FeatureReader
from .kinetics_reader import KineticsReader
from .nonlocal_reader import NonlocalReader
from .ctcn_reader import CTCNReader
# register readers, sorted alphabetically
regist_reader("ATTENTIONCLUSTER", FeatureReader)
......@@ -11,3 +12,4 @@ regist_reader("NONLOCAL", NonlocalReader)
regist_reader("TSM", KineticsReader)
regist_reader("TSN", KineticsReader)
regist_reader("STNET", KineticsReader)
regist_reader("CTCN", CTCNReader)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import random
import cv2
import sys
import numpy as np
import gc
import copy
import multiprocessing
import logging
logger = logging.getLogger(__name__)
try:
import cPickle as pickle
from cStringIO import StringIO
except ImportError:
import pickle
from io import BytesIO
from .reader_utils import DataReader
from models.ctcn.ctcn_utils import box_clamp1D, box_iou1D, BoxCoder
python_ver = sys.version_info
#random.seed(0)
#np.random.seed(0)
class CTCNReader(DataReader):
    """
    Data reader for the C-TCN model, whose inputs are features extracted by
    prior networks and stored as one pickle file per video.

    dataset cfg: img_size, the temporal dimension size of input data
                 root, the root dir of data
                 snippet_length, snippet length when sampling
                 filelist, the file list storing id and annotations of each data item
                 rgb, the dir of rgb data
                 flow, the dir of optical flow data
                 batch_size, batch size of input data
                 num_threads, number of worker processes used for data processing
    """

    # Modes for which the reader knows how to assemble an output sample.
    _SUPPORTED_MODES = ('train', 'valid', 'test')

    def __init__(self, name, mode, cfg):
        self.name = name
        self.mode = mode
        self.img_size = cfg.MODEL.img_size  # 512
        self.snippet_length = cfg.MODEL.snippet_length  # 1
        self.root = cfg.MODEL.root  # root dir of data
        section = cfg[mode.upper()]
        self.filelist = section['filelist']
        self.rgb = section['rgb']
        self.flow = section['flow']
        self.batch_size = section['batch_size']
        self.num_threads = section['num_threads']
        if mode in ('test', 'infer'):
            self.num_threads = 1  # set num_threads as 1 for test and infer

    def random_move(self, img, o_boxes, labels):
        """Rebuild the sequence with the annotated segments moved to random
        positions inside the background (un-annotated) frames.

        Args:
            img: feature array; axis 0 is indexed by the box coordinates.
            o_boxes: list/array of [start, end] segments.
            labels: per-segment labels, indexable like ``o_boxes``.

        Returns:
            (new_img, new_boxes, new_labels). The inputs are returned
            unchanged (boxes as ndarray) when there are fewer than 5
            background frames, no boxes, or more boxes than background
            frames (no distinct insert positions can be drawn).
        """
        boxes = np.array(o_boxes)
        num_frames = img.shape[0]
        is_background = np.ones(num_frames, dtype=bool)
        for box in boxes:
            start = max(int(box[0]), 0)
            end = min(int(box[1]), num_frames)
            is_background[start:end] = False
        bg = img[is_background]
        bg_len = bg.shape[0]
        if bg_len < 5 or len(boxes) == 0 or len(boxes) > bg_len:
            return img, boxes, labels

        insert_place = random.sample(range(bg_len), len(boxes))
        order = np.argsort(insert_place)
        cuts = [insert_place[k] for k in order]  # ascending cut positions

        # Collect the pieces and concatenate once instead of growing the
        # array piece by piece.
        pieces = [bg[0:cuts[0], :]]
        length = pieces[0].shape[0]
        new_boxes = []
        new_labels = []
        for rank, src in enumerate(order):
            start, end = boxes[src][0], boxes[src][1]
            new_boxes.append([length, length + end - start])
            new_labels.append(labels[src])
            segment = img[int(start):int(end), :]
            pieces.append(segment)
            length += segment.shape[0]
            if rank < len(cuts) - 1:
                gap = bg[cuts[rank]:cuts[rank + 1], :]
            else:
                gap = bg[cuts[rank]:, :]
            pieces.append(gap)
            length += gap.shape[0]
        new_img = np.concatenate(pieces)

        del img, boxes, bg, labels, pieces
        gc.collect()
        return new_img, new_boxes, new_labels

    def random_crop(self, img, boxes, labels, min_scale=0.3):
        """Randomly crop a temporal window and keep the boxes centred in it.

        For each IoU threshold up to 100 windows are sampled; the first one
        whose minimum IoU with the boxes reaches the threshold becomes a
        candidate. One candidate (or the full, uncropped window) is then
        chosen at random.

        Args:
            img: feature array, cropped along axis 0.
            boxes: [#obj, 2] segments.
            labels: [#obj] labels.
            min_scale: smallest window length as a fraction of the input
                length.

        Returns:
            (img, boxes, labels); when no box centre falls inside the window
            a single dummy box [0, 0] with label 0 is returned.
        """
        boxes = np.array(boxes)
        labels = np.array(labels)
        imh = img.shape[0]
        params = [(0, imh)]
        for min_iou in (0, 0.1, 0.3, 0.5, 0.7, 0.9):
            for _ in range(100):
                scale = random.uniform(min_scale, 1)
                h = int(imh * scale)
                # randrange(0) raises; a full-length window can only start at 0.
                y = random.randrange(imh - h) if imh > h else 0
                ious = box_iou1D(boxes, [[y, y + h]])
                if ious.min() >= min_iou:
                    params.append((y, h))
                    break
        y, h = random.choice(params)
        img = img[y:y + h, :]
        center = (boxes[:, 0] + boxes[:, 1]) / 2
        mask = (center >= y) & (center <= y + h)
        if mask.any():
            boxes = boxes[mask] - np.array([[y, y]])
            boxes = box_clamp1D(boxes, 0, h)
            labels = labels[mask]
        else:
            # Arrays (not lists) so later stages can index them uniformly.
            boxes = np.zeros((1, 2))
            labels = np.zeros(1, dtype='int64')
        return img, boxes, labels

    def resize(self, img, boxes, size, random_interpolation=False):
        '''Resize the feature map along axis 0 to ``size`` rows.

        The width (axis 1) is kept; boxes are scaled by the same factor as
        the rows.

        Args:
          img: 2-D array to be resized.
          boxes: segments sized [#obj,2], or None.
          size: (int) target number of rows.
          random_interpolation: (bool) randomly choose a cv2 interpolation
            method instead of cv2.INTER_NEAREST.

        Returns:
          img: resized array (returned untouched when it already has
            ``size`` rows; replaced by zeros of shape (512, 402) when it has
            no rows).
          boxes: rescaled boxes (unchanged on the two early-return paths).
        '''
        h, w = img.shape[:2]
        if h == size:
            return img, boxes
        if h == 0:
            img = np.zeros((512, 402), np.float32)
            return img, boxes
        scale = float(size) / h
        if random_interpolation:
            method = random.choice([
                cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC,
                cv2.INTER_AREA
            ])
        else:
            method = cv2.INTER_NEAREST
        # cv2.resize takes (width, height).
        img = cv2.resize(img, (w, size), interpolation=method)
        if boxes is not None:
            boxes = boxes * np.array([scale, scale])
        return img, boxes

    def transform(self, feats, boxes, labels, mode):
        """Augment (train only), resize to ``img_size`` rows and encode the
        boxes/labels into anchor targets with ``BoxCoder.encode``.

        Returns:
            (img, loc_targets, cls_targets) where img has shape (1, h, w).

        Raises:
            NotImplementedError: for modes other than train/valid/test.
        """
        feats = np.array(feats)
        boxes = np.array(boxes)
        labels = np.array(labels)
        if mode == 'train':
            feats, boxes, labels = self.random_move(feats, boxes, labels)
            feats, boxes, labels = self.random_crop(feats, boxes, labels)
            feats, boxes = self.resize(
                feats, boxes, size=self.img_size, random_interpolation=True)
        elif mode == 'test' or mode == 'valid':
            feats, boxes = self.resize(feats, boxes, size=self.img_size)
        else:
            raise NotImplementedError('mode {} not implemented'.format(mode))
        h, w = feats.shape[:2]
        img = feats.reshape(1, h, w)
        boxes, labels = BoxCoder().encode(boxes, labels)
        return img, boxes, labels

    def create_reader(self):
        """reader creator for ctcn model"""
        if self.num_threads == 1:
            return self.make_reader()
        return self.make_multiprocess_reader()

    def _check_mode(self):
        """Fail early for modes the reader cannot produce samples for."""
        if self.mode not in self._SUPPORTED_MODES:
            raise NotImplementedError('mode {} not implemented'.format(
                self.mode))

    def _parse_filelist(self, reader_list):
        """Parse annotation lines into (fname, boxes, labels, line_index).

        Each line is ``<id> <num_frames> <num_boxes> (<label> <start> <end>)*``.
        Entries whose rgb or flow feature file is missing are skipped.
        ``line_index`` is the position of the line inside ``reader_list``.
        """
        samples = []
        for line_idx, line in enumerate(reader_list):
            splited = line.strip().split()
            if not splited:
                continue  # tolerate blank lines
            fname = splited[0]
            rgb_path = os.path.join(self.root, self.rgb, fname + '.pkl')
            flow_path = os.path.join(self.root, self.flow, fname + '.pkl')
            if not (os.path.exists(rgb_path) and os.path.exists(flow_path)):
                logger.info('file not exist {}'.format(fname))
                continue
            num_boxes = int(splited[2])
            boxes = []
            labels = []
            # A separate loop variable keeps line_idx intact; reusing the
            # outer index here would record a box index as the file id.
            for k in range(num_boxes):
                c = splited[3 + 3 * k]
                xmin = splited[4 + 3 * k]
                xmax = splited[5 + 3 * k]
                boxes.append([
                    float(xmin) / self.snippet_length,
                    float(xmax) / self.snippet_length
                ])
                labels.append(int(c))
            samples.append((fname, boxes, labels, line_idx))
        return samples

    def _load_scores(self, subdir, fname):
        """Load the 'scores' array from ``root/subdir/fname.pkl``.

        NOTE(review): pickle executes arbitrary code on load; the feature
        files must come from a trusted source.
        """
        path = os.path.join(self.root, subdir, fname + '.pkl')
        # pickle needs a binary file object; the handle is closed on exit.
        with open(path, 'rb') as f:
            if python_ver < (3, 0):
                data = pickle.load(f)
            else:
                data = pickle.load(f, encoding='bytes')
        # With encoding='bytes', str keys written by Python 2 load as bytes.
        scores = data['scores'] if 'scores' in data else data[b'scores']
        return np.array(scores)

    def _load_sample(self, fname, boxes, labels):
        """Load rgb+flow features of one video and build its targets."""
        data_rgb = self._load_scores(self.rgb, fname)
        data_flow = self._load_scores(self.flow, fname)
        # Truncate both streams to the shorter one before concatenating.
        num_rows = min(data_rgb.shape[0], data_flow.shape[0])
        feats = np.concatenate(
            (data_rgb[0:num_rows, :], data_flow[0:num_rows, :]), axis=1)
        if feats.shape[0] == 0 or feats.shape[1] == 0:
            feats = np.zeros((512, 1024), np.float32)
            logger.info('### file loading len = 0 {} ###'.format(fname))
        feats, loc, cls = self.transform(feats,
                                         copy.deepcopy(boxes),
                                         copy.deepcopy(labels), self.mode)
        return feats, loc.astype('float32'), cls.astype('int64')

    def _iter_batches(self, reader_list):
        """Yield full batches built from ``reader_list``.

        Samples that fail to load are logged and skipped; in train/valid
        mode samples without any positive anchor are skipped as well. A
        trailing batch smaller than ``batch_size`` is dropped.
        """
        self._check_mode()
        keep_fileid = self.mode == 'test'
        batch_out = []
        for fname, boxes, labels, file_id in self._parse_filelist(reader_list):
            try:
                feats, loc, cls = self._load_sample(fname, boxes, labels)
            except Exception:
                logger.info('Error when loading {}'.format(fname))
                logger.debug('traceback for {}'.format(fname), exc_info=True)
                continue
            num_pos = int((cls > 0).sum())
            if keep_fileid:
                batch_out.append((feats, loc, cls, file_id))
            elif num_pos < 1:
                continue
            else:
                batch_out.append((feats, loc, cls))
            if len(batch_out) == self.batch_size:
                yield batch_out
                batch_out = []

    def make_reader(self):
        """single process reader"""

        def reader():
            with open(self.filelist) as f:
                reader_list = f.readlines()
            if self.mode == 'train':
                random.shuffle(reader_list)
            for batch in self._iter_batches(reader_list):
                yield batch

        return reader

    def make_multiprocess_reader(self):
        """multiprocess reader"""

        def read_into_queue(reader_list, queue):
            try:
                for batch in self._iter_batches(reader_list):
                    queue.put(batch)
            finally:
                # Always signal completion, even on failure, so the consumer
                # never blocks forever waiting for this worker.
                queue.put(None)

        def queue_reader():
            self._check_mode()
            with open(self.filelist) as f:
                fl = f.readlines()
            if self.mode == 'train':
                random.shuffle(fl)
            n = self.num_threads
            queue_size = 20
            file_num = len(fl) // n
            # Equal-sized shards; the last worker also takes the remainder.
            reader_lists = [
                fl[i * file_num:(i + 1) * file_num] for i in range(n - 1)
            ]
            reader_lists.append(fl[(n - 1) * file_num:])

            queue = multiprocessing.Queue(queue_size)
            workers = []
            for reader_list in reader_lists:
                worker = multiprocessing.Process(
                    target=read_into_queue, args=(reader_list, queue))
                worker.start()
                workers.append(worker)

            finish_num = 0
            while finish_num < len(workers):
                sample = queue.get()
                if sample is None:
                    finish_num += 1
                else:
                    yield sample

            for worker in workers:
                worker.join()

        return queue_reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
import numpy as np
import datetime
import logging
import json
from models.ctcn.ctcn_utils import BoxCoder
logger = logging.getLogger(__name__)
def get_class_label(class_label_file):
    """Return the lines of ``class_label_file`` as a list.

    Line terminators are kept; callers strip them when needed. The file is
    closed before returning.
    """
    with open(class_label_file, 'r') as f:
        return f.readlines()
def get_video_time_dict(video_duration_file):
    """Build a dict mapping the first field of each line to its last field.

    Each non-blank line is split on whitespace; the first token is the key
    and the last token, converted to float, is the value. Blank lines are
    ignored and the file is closed before returning.
    """
    video_time_dict = dict()
    with open(video_duration_file, 'r') as f:
        for line in f:
            contents = line.split()
            if not contents:
                continue
            video_time_dict[contents[0]] = float(contents[-1])
    return video_time_dict
class MetricsCalculator():
    """Accumulates C-TCN losses and, in test mode, decodes the predictions
    into a detection-result dict that is written out as JSON."""

    def __init__(self,
                 name='CTCN',
                 mode='train',
                 score_thresh=0.001,
                 nms_thresh=0.8,
                 sigma_thresh=0.8,
                 soft_thresh=0.006,
                 gt_label_file='',
                 class_label_file='',
                 video_duration_file=''):
        """
        Args:
            name: model name.
            mode: run mode; only 'test' enables decoding and JSON output.
            score_thresh, nms_thresh, sigma_thresh, soft_thresh: forwarded
                to ``BoxCoder.decode``.
            gt_label_file: file whose first space-separated field per line
                names a video (read in test mode only).
            class_label_file: file with one class name per line.
            video_duration_file: file read by ``get_video_time_dict``.
        """
        self.name = name
        self.mode = mode  # only 'test' is treated specially here
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.sigma_thresh = sigma_thresh
        self.soft_thresh = soft_thresh
        self.class_label_file = class_label_file
        self.video_duration_file = video_duration_file
        if mode == 'test':
            with open(gt_label_file) as f:
                self.gt_labels = [item.split(' ')[0] for item in f]
            self.box_coder = BoxCoder()
        else:
            self.gt_labels = None
            self.box_coder = None
        self.reset()

    def reset(self):
        """Clear accumulated losses and (test mode) reload label/duration
        files and start an empty result set."""
        logger.info('Resetting {} metrics...'.format(self.mode))
        self.aggr_loss = 0.0
        self.aggr_loc_loss = 0.0
        self.aggr_cls_loss = 0.0
        self.aggr_batch_size = 0
        if self.mode == 'test':
            self.class_label = get_class_label(self.class_label_file)
            self.video_time_dict = get_video_time_dict(
                self.video_duration_file)
            self.res_detect = dict()
            self.res_detect["version"] = "VERSION 1.3"
            # Key spelled "used" (was "uesd").
            # NOTE(review): presumably this mirrors the ActivityNet result
            # schema — confirm against the evaluation tooling.
            self.res_detect["external_data"] = {
                "used": False,
                "details": "none"
            }
            self.results_detect = dict()
            self.box_decode_params = {
                'score_thresh': self.score_thresh,
                'nms_thresh': self.nms_thresh,
                'sigma_thresh': self.sigma_thresh,
                'soft_thresh': self.soft_thresh
            }
            self.out_file = 'res_decode_' + str(self.score_thresh) + '_' + \
                            str(self.nms_thresh) + '_' + str(self.sigma_thresh) + \
                            '_' + str(self.soft_thresh)

    def accumulate(self, loss, pred, label):
        """Add one batch of losses; in test mode also decode and store the
        detections of the batch's video.

        Args:
            loss: sequence (total, loc, cls) of loss arrays.
            pred: sequence (loc_preds, cls_preds); both are squeezed before
                decoding.
            label: sequence whose last element is the file id used to index
                ``gt_labels``.  # assumes a scalar-like index — TODO confirm
        """
        cur_batch_size = loss[0].shape[0]
        self.aggr_loss += np.mean(np.array(loss[0]))
        self.aggr_loc_loss += np.mean(np.array(loss[1]))
        self.aggr_cls_loss += np.mean(np.array(loss[2]))
        self.aggr_batch_size += cur_batch_size
        if self.mode == 'test':
            box_preds, label_preds, score_preds = self.box_coder.decode(
                pred[0].squeeze(), pred[1].squeeze(), **self.box_decode_params)
            fid = label[-1]
            fname = self.gt_labels[fid]
            logger.info("file {}, num of box preds {}:".format(fname,
                                                               len(box_preds)))
            duration = self.video_time_dict[fname]
            detections = []
            for box, cls_id, score in zip(box_preds, label_preds,
                                          score_preds):
                # Box coordinates are on a 0..512 axis; rescale to the
                # video duration and clamp to [0, duration].
                start = max(0, duration * box[0] / 512.0)
                end = min(duration, duration * box[1] / 512.0)
                detections.append({
                    # Plain floats: NumPy scalars are not JSON serialisable.
                    "score": float(score),
                    "label": self.class_label[cls_id].strip(),
                    "segment": [float(start), float(end)]
                })
            self.results_detect[fname] = detections

    def finalize_metrics(self):
        """Compute average losses and (test mode) dump results to
        ``self.out_file`` as JSON."""
        self.avg_loss = self.aggr_loss / self.aggr_batch_size
        self.avg_loc_loss = self.aggr_loc_loss / self.aggr_batch_size
        self.avg_cls_loss = self.aggr_cls_loss / self.aggr_batch_size
        if self.mode == 'test':
            self.res_detect['results'] = self.results_detect
            with open(self.out_file, 'w') as f:
                json.dump(self.res_detect, f)

    def get_computed_metrics(self):
        """Return the averages computed by ``finalize_metrics`` as a dict."""
        json_stats = {}
        json_stats['avg_loss'] = self.avg_loss
        json_stats['avg_loc_loss'] = self.avg_loc_loss
        json_stats['avg_cls_loss'] = self.avg_cls_loss
        return json_stats
......@@ -23,6 +23,7 @@ import numpy as np
from metrics.youtube8m import eval_util as youtube8m_metrics
from metrics.kinetics import accuracy_metrics as kinetics_metrics
from metrics.multicrop_test import multicrop_test_metrics as multicrop_test_metrics
from metrics.detections import detection_metrics as detection_metrics
logger = logging.getLogger(__name__)
......@@ -160,6 +161,43 @@ class MulticropMetrics(Metrics):
self.calculator.reset()
class DetectionMetrics(Metrics):
    """Metrics wrapper that delegates to detection_metrics.MetricsCalculator,
    configured from the TEST section of the config."""

    def __init__(self, name, mode, cfg):
        self.name = name
        self.mode = mode
        calculator_kwargs = {
            'score_thresh': cfg.TEST.score_thresh,
            'nms_thresh': cfg.TEST.nms_thresh,
            'sigma_thresh': cfg.TEST.sigma_thresh,
            'soft_thresh': cfg.TEST.soft_thresh,
            'class_label_file': cfg.TEST.class_label_file,
            'video_duration_file': cfg.TEST.video_duration_file,
            # the test file list doubles as the ground-truth label file
            'gt_label_file': cfg.TEST.filelist,
            'mode': mode,
            'name': name,
        }
        self.calculator = detection_metrics.MetricsCalculator(
            **calculator_kwargs)

    def calculate_and_log_out(self, loss, pred, label, info=''):
        """Log the mean of the three loss terms; pred and label are unused."""
        message = '\tLoss = {}, \tloc_loss = {}, \tcls_loss = {}'.format(
            np.mean(loss[0]), np.mean(loss[1]), np.mean(loss[2]))
        logger.info(info + message)

    def accumulate(self, loss, pred, label):
        """Forward one batch to the calculator."""
        self.calculator.accumulate(loss, pred, label)

    def finalize_and_log_out(self, info=''):
        """Finalize the calculator and log the averaged losses."""
        self.calculator.finalize_metrics()
        stats = self.calculator.get_computed_metrics()
        formatted = [
            '%.6f' % stats[key]
            for key in ('avg_loss', 'avg_loc_loss', 'avg_cls_loss')
        ]
        logger.info(info + '\tLoss: {},\tloc_loss: {}, \tcls_loss: {}'.format(
            *formatted))

    def reset(self):
        """Reset the calculator state."""
        self.calculator.reset()
class MetricsZoo(object):
def __init__(self):
self.metrics_zoo = {}
......@@ -196,3 +234,4 @@ regist_metrics("NONLOCAL", MulticropMetrics)
regist_metrics("TSM", Kinetics400Metrics)
regist_metrics("TSN", Kinetics400Metrics)
regist_metrics("STNET", Kinetics400Metrics)
regist_metrics("CTCN", DetectionMetrics)
......@@ -6,6 +6,7 @@ from .nonlocal_model import NonLocal
from .tsm import TSM
from .tsn import TSN
from .stnet import STNET
from .ctcn import CTCN
# register models, sorted alphabetically
regist_model("AttentionCluster", AttentionCluster)
......@@ -15,3 +16,4 @@ regist_model('NONLOCAL', NonLocal)
regist_model("TSM", TSM)
regist_model("TSN", TSN)
regist_model("STNET", STNET)
regist_model("CTCN", CTCN)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
from ..model import ModelBase
from . import fpn_ctcn
import logging
logger = logging.getLogger(__name__)
__all__ = ["CTCN"]
class CTCN(ModelBase):
    """C-TCN model"""

    def __init__(self, name, cfg, mode='train'):
        super(CTCN, self).__init__(name, cfg, mode=mode)
        self.get_config()

    def get_config(self):
        """Read model and training hyper-parameters from the config."""
        self.img_size = self.get_config_from_sec('MODEL', 'img_size')
        self.concept_size = self.get_config_from_sec('MODEL', 'concept_size')
        self.num_classes = self.get_config_from_sec('MODEL', 'num_classes')
        self.num_anchors = self.get_config_from_sec('MODEL', 'num_anchors')
        self.total_num_anchors = self.get_config_from_sec('MODEL',
                                                          'total_num_anchors')

        self.num_epochs = self.get_config_from_sec('train', 'epoch')
        self.base_learning_rate = self.get_config_from_sec('train',
                                                           'learning_rate')
        self.learning_rate_decay = self.get_config_from_sec(
            'train', 'learning_rate_decay')
        self.l2_weight_decay = self.get_config_from_sec('train',
                                                        'l2_weight_decay')
        self.momentum = self.get_config_from_sec('train', 'momentum')
        self.lr_decay_iter = self.get_config_from_sec('train', 'lr_decay_iter')

    def build_input(self, use_pyreader=True):
        """Create the input variables for the current mode.

        Sets ``feature_input``, ``loc_targets``, ``cls_targets`` and
        ``fileid`` (None where a mode has no such input) and, when
        ``use_pyreader`` is true, ``py_reader``.

        Raises:
            NotImplementedError: for an unknown mode.
        """
        image_shape = [1, self.img_size, self.concept_size]
        loc_shape = [self.total_num_anchors, 2]
        cls_shape = [self.total_num_anchors]
        fileid_shape = [1]
        self.use_pyreader = use_pyreader
        # set init data to None
        py_reader = None
        image = None
        loc_targets = None
        cls_targets = None
        fileid = None
        if use_pyreader:
            assert self.mode != 'infer', \
                'pyreader is not recommendated when infer, please set use_pyreader to be false.'
            if (self.mode == 'train') or (self.mode == 'valid'):
                py_reader = fluid.layers.py_reader(
                    capacity=100,
                    shapes=[[-1] + image_shape, [-1] + loc_shape,
                            [-1] + cls_shape],
                    dtypes=['float32', 'float32', 'int64'],
                    name='train_py_reader'
                    if self.is_training else 'test_py_reader',
                    use_double_buffer=True)
                image, loc_targets, cls_targets = fluid.layers.read_file(
                    py_reader)
            elif self.mode == 'test':
                py_reader = fluid.layers.py_reader(
                    capacity=100,
                    # One shape per dtype: the file-id shape must be its own
                    # list element, not two loose ints appended to the list.
                    shapes=[[-1] + image_shape, [-1] + loc_shape,
                            [-1] + cls_shape, [-1] + fileid_shape],
                    dtypes=['float32', 'float32', 'int64', 'int64'],
                    use_double_buffer=True)
                image, loc_targets, cls_targets, fileid = fluid.layers.read_file(
                    py_reader)
            else:
                raise NotImplementedError('mode {} not implemented'.format(
                    self.mode))
            self.py_reader = py_reader
        else:
            image = fluid.layers.data(
                name='image', shape=image_shape, dtype='float32')
            if (self.mode == 'train') or (self.mode == 'valid'):
                loc_targets = fluid.layers.data(
                    name='loc_targets', shape=loc_shape, dtype='float32')
                cls_targets = fluid.layers.data(
                    name='cls_targets', shape=cls_shape, dtype='int64')
            elif self.mode == 'test':
                loc_targets = fluid.layers.data(
                    name='loc_targets', shape=loc_shape, dtype='float32')
                cls_targets = fluid.layers.data(
                    name='cls_targets', shape=cls_shape, dtype='int64')
                fileid = fluid.layers.data(
                    name='fileid', shape=fileid_shape, dtype='int64')
            elif self.mode == 'infer':
                fileid = fluid.layers.data(
                    name='fileid', shape=fileid_shape, dtype='int64')
            else:
                raise NotImplementedError('mode {} not implemented'.format(
                    self.mode))
        self.feature_input = [image]
        self.cls_targets = cls_targets
        self.loc_targets = loc_targets
        self.fileid = fileid

    def create_model_args(self):
        """Collect the arguments passed to the FPNCTCN constructor."""
        cfg = {}
        cfg['num_anchors'] = self.num_anchors
        cfg['concept_size'] = self.concept_size
        cfg['num_classes'] = self.num_classes
        return cfg

    def build_model(self):
        """Instantiate FPNCTCN and store its (loc_preds, cls_preds) outputs."""
        cfg = self.create_model_args()
        self.videomodel = fpn_ctcn.FPNCTCN(
            num_anchors=cfg['num_anchors'],
            concept_size=cfg['concept_size'],
            num_classes=cfg['num_classes'],
            mode=self.mode)
        loc_preds, cls_preds = self.videomodel.net(input=self.feature_input[0])
        self.network_outputs = [loc_preds, cls_preds]

    def optimizer(self):
        """Momentum optimizer with L2 decay; the learning rate drops once,
        by ``learning_rate_decay``, at ``lr_decay_iter``."""
        bd = [self.lr_decay_iter]
        base_lr = self.base_learning_rate
        lr_decay = self.learning_rate_decay
        lr = [base_lr, base_lr * lr_decay]
        l2_weight_decay = self.l2_weight_decay
        momentum = self.momentum
        optimizer = fluid.optimizer.Momentum(
            learning_rate=fluid.layers.piecewise_decay(
                boundaries=bd, values=lr),
            momentum=momentum,
            regularization=fluid.regularizer.L2Decay(l2_weight_decay))

        return optimizer

    def loss(self):
        """Build and return the loss from the network outputs and targets."""
        assert self.mode != 'infer', "invalid loss calculationg in infer mode"
        self.loss_ = self.videomodel.loss(self.network_outputs[0],
                                          self.network_outputs[1],
                                          self.loc_targets, self.cls_targets)
        return self.loss_

    def outputs(self):
        """Return [loc_preds, softmax(cls_preds)]."""
        loc_preds = self.network_outputs[0]
        cls_preds = fluid.layers.softmax(self.network_outputs[1])
        return [loc_preds, cls_preds]

    def feeds(self):
        """Return the feed variables for the current mode.

        Raises:
            NotImplementedError: for an unknown mode.
        """
        if (self.mode == 'train') or (self.mode == 'valid'):
            return self.feature_input + [self.loc_targets, self.cls_targets]
        elif self.mode == 'test':
            return self.feature_input + [
                self.loc_targets, self.cls_targets, self.fileid
            ]
        elif self.mode == 'infer':
            return self.feature_input + [self.fileid]
        else:
            # NotImplemented is a comparison sentinel, not an exception.
            raise NotImplementedError('mode {} not implemented'.format(
                self.mode))

    def pretrain_info(self):
        """No pretrained weights are registered for this model."""
        return (None, None)

    def weights_info(self):
        """No released weights are registered for this model."""
        return (None, None)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
import numpy as np
from paddle.fluid.initializer import Uniform
# This file includes initializer, box encode, box decode
# initializer
def get_ctcn_conv_initializer(x, filter_size):
    """Return a Uniform(-std, std) initializer with std = sqrt(1 / fan_in).

    fan_in is x.shape[1] times the filter area; ``filter_size`` is either an
    int (square filter) or an indexable pair.
    """
    if isinstance(filter_size, int):
        kernel_area = filter_size * filter_size
    else:
        kernel_area = filter_size[0] * filter_size[1]
    fan_in = x.shape[1] * kernel_area
    bound = np.sqrt(1.0 / fan_in)
    return Uniform(0. - bound, bound)
#box tools
def box_clamp1D(boxes, xmin, xmax):
    '''Clamp 1-D boxes in place.

    Args:
      boxes: (ndarray) boxes of (xmin, xmax), sized [N,2]; modified in place.
      xmin: (number) lower bound.
      xmax: (number) upper bound.

    Returns:
      The same ``boxes`` array, with both columns clipped to [xmin, xmax].
    '''
    for col in (0, 1):
        column = boxes[:, col]  # view into boxes, so clipping is in place
        np.clip(column, xmin, xmax, out=column)
    return boxes
def box_iou1D(box1, box2):
    '''Pairwise intersection over union of two sets of 1-D boxes.

    The box order must be (xmin, xmax).

    Args:
      box1: boxes sized [N,2].
      box2: boxes sized [M,2].

    Return:
      (ndarray) iou, sized [N,M].
    '''
    first = np.array(box1)
    second = np.array(box2)
    # Broadcast [N,1] against [M] to get every pair.
    overlap_start = np.maximum(first[:, None, 0], second[:, 0])
    overlap_end = np.minimum(first[:, None, 1], second[:, 1])
    inter = (overlap_end - overlap_start).clip(min=0)
    len_first = np.abs(first[:, 0] - first[:, 1])
    len_second = np.abs(second[:, 0] - second[:, 1])
    union = len_first[:, None] + len_second - inter
    return inter / union
def change_box_order(boxes, order):
    """Convert [N,2] boxes between (start, end) and (center, length).

    Args:
      boxes: array sized [N,2].
      order: 'yy2yh' for (start, end) -> (center, length), or
             'yh2yy' for (center, length) -> (start, end).
    """
    assert order in ['yy2yh', 'yh2yy']
    first = boxes[:, 0, None]
    second = boxes[:, 1, None]
    if order == 'yy2yh':
        columns = ((first + second) / 2, second - first)
    else:
        half = second / 2
        columns = (first - half, first + half)
    return np.concatenate(columns, axis=1)
def box_nms(bboxes, scores, threshold=0.5, mode='union'):
    '''Greedy non maximum suppression for 1-D boxes.

    Args:
      bboxes: boxes sized [N,2].
      scores: confidence scores, sized [N,].
      threshold: (float) boxes overlapping a kept box by more than this are
        dropped.
      mode: (str) 'union' (IoU) or 'min' (intersection over the smaller
        length).

    Returns:
      int64 array of kept indices, highest score first.

    Raises:
      TypeError: for an unknown ``mode``.

    Reference:
      https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py
    '''
    starts = bboxes[:, 0]
    ends = bboxes[:, 1]
    lengths = ends - starts
    remaining = np.argsort(-scores, axis=0)
    selected = []
    while remaining.size > 0:
        best = remaining[0]
        selected.append(best)
        if remaining.size == 1:
            break
        rest = remaining[1:]
        clipped_start = np.clip(starts[rest], starts[best], None)
        clipped_end = np.clip(ends[rest], None, ends[best])
        inter = np.clip(clipped_end - clipped_start, 0, None)
        if mode == 'union':
            ratio = inter / (lengths[best] + lengths[rest] - inter)
        elif mode == 'min':
            ratio = inter / np.clip(lengths[rest], None, lengths[best])
        else:
            raise TypeError('Unknown nms mode: %s.' % mode)
        survivors = (ratio <= threshold).nonzero()[0]
        if survivors.size == 0:
            break
        remaining = rest[survivors]
    return np.array(selected, dtype='int64')
def soft_nms(props, method=0, sigma=1., Nt=0.7, threshold=0.001):
    '''Soft non maximum suppression on 1-D proposals.

    Args:
      props: array sized [N,3] with columns (start, end, score); it is
        reordered and rescored in place.
      method: 1 = linear decay above Nt, 2 = gaussian decay with ``sigma``,
        anything else = hard suppression above Nt.
      sigma: gaussian decay parameter (method 2).
      Nt: overlap threshold (method 1 and hard suppression).
      threshold: proposals whose decayed score drops below this are removed.

    Returns:
      A copy of the surviving rows of ``props``.
    '''
    total = props.shape[0]
    alive = total  # number of leading rows still in play
    for head in range(total):
        # Selection step: move the best-scoring remaining row to `head`.
        best = head
        for cand in range(head + 1, alive):
            if props[best, 2] < props[cand, 2]:
                best = cand
        props[[head, best], :3] = props[[best, head], :3]

        ref_start = props[head, 0]
        ref_end = props[head, 1]

        # Decay every later row according to its overlap with `head`.
        pos = head + 1
        while pos < alive:
            start = props[pos, 0]
            end = props[pos, 1]
            inter = max(0.0, min(end, ref_end) - max(start, ref_start))
            overlap = inter / (end - start + ref_end - ref_start - inter)
            if method == 1:
                weight = 1 - overlap if overlap > Nt else 1
            elif method == 2:
                weight = np.exp(-(overlap**2) / sigma)
            else:
                weight = 0 if overlap > Nt else 1
            props[pos, 2] = weight * props[pos, 2]
            if props[pos, 2] < threshold:
                # Overwrite with the last live row and re-examine this slot.
                props[pos, :3] = props[alive - 1, :3]
                alive -= 1
            else:
                pos += 1
    return props[np.arange(alive)]
# box encode and decode
class BoxCoder():
    """Encodes ground-truth segments into per-anchor regression/class
    targets and decodes network predictions back into segments."""

    def __init__(self):
        self.steps = (4, 8, 16, 32, 64, 128, 256, 512)
        self.fm_sizes = (128, 64, 32, 16, 8, 4, 2, 1)
        self.anchor_num = 3
        self.default_boxes = self._get_default_boxes()

    def _get_default_boxes(self):
        """Build the (center, length) anchors for every feature-map cell.

        Each cell gets 1 + 2 * anchor_num anchors: one of length ``step``
        plus, for each p, one shorter anchor and one longer anchor. On the
        level whose step is 512 the second anchor is made shorter still
        instead of longer.
        """
        boxes = []
        for step, fm_size in zip(self.steps, self.fm_sizes):
            unit = step * 4.5 / 15.0
            for h in range(fm_size):
                cy = (h + 0.5) * step
                boxes.append((cy, step))
                for p in range(self.anchor_num):
                    s = unit * (1.0 + p) / self.anchor_num
                    boxes.append((cy, step - s))
                    if step == 512:
                        step_s = unit / (2 * self.anchor_num)
                        boxes.append((cy, step - s - step_s))
                    else:
                        boxes.append((cy, step + s))
        return np.array(boxes)

    def encode(self, boxes, labels):
        """Assign ground-truth segments to anchors and build targets.

        Args:
            boxes: [#obj, 2] (start, end) segments.
            labels: [#obj] class labels.

        Returns:
            loc_targets: [#anchors, 2] encoded (center, length) offsets.
            cls_targets: [#anchors] labels; 0 for unmatched anchors.
        """

        def argmax(x):
            # (row, col) of the global maximum of a 2-D array.
            col_max = x.max(0)
            row_of_col_max = np.argmax(x, 0)
            col = np.argmax(col_max, 0)
            return (row_of_col_max[col], col)

        boxes = np.asarray(boxes)  # callers may hand in plain lists
        labels = np.array(labels)
        default_boxes = change_box_order(self.default_boxes, 'yh2yy')

        ious = box_iou1D(default_boxes, boxes)  # [#anchors, #obj]
        index = np.full(len(default_boxes), fill_value=-1, dtype='int64')

        # Greedy one-to-one matching: repeatedly take the best remaining
        # (anchor, object) pair and retire its row and column.
        masked_ious = ious.copy()
        while True:
            i, j = argmax(masked_ious)
            if masked_ious[i, j] < 1e-6:
                break
            index[i] = j
            masked_ious[i, :] = 0
            masked_ious[:, j] = 0

        # Remaining anchors with IoU >= 0.5 take their best object. Boolean
        # indexing handles any number of such anchors, including exactly one.
        mask = (index < 0) & (ious.max(1) >= 0.5)
        if mask.any():
            index[mask] = np.argmax(ious[mask], 1)

        boxes = boxes[np.clip(index, a_min=0, a_max=None)]
        boxes = change_box_order(boxes, 'yy2yh')
        default_boxes = change_box_order(default_boxes, 'yy2yh')

        variances = (0.1, 0.2)
        loc_xy = (boxes[:, 0, None] - default_boxes[:, 0, None]
                  ) / default_boxes[:, 1, None] / variances[0]
        loc_wh = (
            boxes[:, 1, None] / default_boxes[:, 1, None] - 1.0) / variances[1]
        loc_targets = np.concatenate((loc_xy, loc_wh), axis=1)
        cls_targets = labels[index.clip(0, None)]
        cls_targets[index < 0] = 0
        return loc_targets, cls_targets

    def decode(self,
               loc_preds,
               cls_preds,
               score_thresh=0.6,
               nms_thresh=0.45,
               sigma_thresh=1.0,
               soft_thresh=0.01):
        '''Decode predicted loc/cls back to real box locations and class labels.

        Args:
          loc_preds: predicted loc, sized [#anchors,2].
          cls_preds: predicted conf, sized [#anchors,#classes]; column 0 is
            skipped.
          score_thresh: (float) threshold for object confidence score.
          nms_thresh: (float) threshold for box nms.
          sigma_thresh: (float) sigma passed to soft_nms.
          soft_thresh: (float) score threshold passed to soft_nms.

        Returns:
          boxes: box locations, sized [#obj,2].
          labels: class labels (cls_preds column index minus 1), sized [#obj,].
          scores: confidence scores, sized [#obj,].
          When nothing passes ``score_thresh`` a single dummy detection
          ([0, 1], label 1, score 1) is returned.
        '''
        variances = (0.1, 0.2)
        y = loc_preds[:, 0, None] * variances[
            0] * self.default_boxes[:, 1, None] + self.default_boxes[:, 0, None]
        h = (loc_preds[:, 1, None] * variances[1] + 1.0
             ) * self.default_boxes[:, 1, None]
        box_preds = np.concatenate((y - h / 2.0, y + h / 2.0), axis=1)

        boxes = []
        labels = []
        scores = []
        num_classes = cls_preds.shape[1]
        for i in range(num_classes - 1):
            score = cls_preds[:, i + 1]
            mask = score > score_thresh
            if not mask.any():
                continue
            box = box_preds[mask]
            score = score[mask]
            keep = box_nms(box, score, nms_thresh)
            box = box[keep]
            score = score[keep]
            now_vector = np.concatenate((box, score[:, None]), axis=1)
            res = soft_nms(
                now_vector, method=2, sigma=sigma_thresh, threshold=soft_thresh)
            final_box = res[:, :2]
            final_score = res[:, 2]
            boxes.append(final_box)
            labels.append(np.full(len(final_box), fill_value=i, dtype='int64'))
            scores.append(final_score)

        if len(boxes) == 0:
            boxes.append(np.array([[0, 1.0]], dtype='float32'))
            labels.append(np.full(1, fill_value=1, dtype='int64'))
            scores.append(np.full(1, fill_value=1, dtype='float32'))

        boxes = np.concatenate(boxes, 0)
        labels = np.concatenate(labels, 0)
        scores = np.concatenate(scores, 0)
        return boxes, labels, scores
#coding=UTF-8
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
import numpy as np
from .ctcn_utils import get_ctcn_conv_initializer as get_init
# Floating-point dtype that FPNCTCN casts its masks and sort indices to.
DATATYPE = 'float32'
class FPNCTCN(object):
    """Feature-pyramid network with shared location/class prediction heads.

    NOTE: the order in which layers are created is kept deliberately; fluid
    derives the names of unnamed parameters from creation order, so
    reordering calls would change parameter names.
    """

    def __init__(self, num_anchors, concept_size, num_classes, mode='train'):
        """Store the head configuration.

        Args:
            num_anchors: number of anchors predicted per position.
            concept_size: width of the kernel used by the final head convs.
            num_classes: number of classification outputs per anchor.
            mode: ``'train'`` enables dropout and batch-norm training
                behaviour; any other value runs the layers in test mode.
        """
        self.num_anchors = num_anchors
        self.concept_size = concept_size
        self.num_classes = num_classes
        self.is_training = (mode == 'train')

    def conv_bn_layer(self,
                      input,
                      ch_out,
                      filter_size,
                      stride=1,
                      padding=0,
                      act='relu'):
        """Apply a bias-free conv2d followed by batch norm with ``act``."""
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            act=None,
            param_attr=ParamAttr(initializer=get_init(input, filter_size)),
            bias_attr=False)
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            is_test=(not self.is_training), )

    def shortcut(self, input, planes, stride):
        """Return the residual branch for a bottleneck block.

        The input is passed through unchanged when it already has
        ``planes * 4`` channels and ``stride`` is 1; otherwise it is
        projected with a 1x1 conv + batch norm (no activation).
        """
        if (input.shape[1] == planes * 4) and (stride == 1):
            return input
        else:
            return self.conv_bn_layer(input, planes * 4, 1, stride, act=None)

    def bottleneck_block(self, input, planes, stride=1):
        """Build one bottleneck block: 1x1 -> (3, 1) -> 1x1 plus shortcut."""
        conv0 = self.conv_bn_layer(input, planes, filter_size=1)
        conv1 = self.conv_bn_layer(
            conv0, planes, filter_size=(3, 1), stride=stride, padding=(1, 0))
        conv2 = self.conv_bn_layer(conv1, planes * 4, filter_size=1, act=None)
        short = self.shortcut(input, planes, stride)
        return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')

    def layer_warp(self, input, planes, num_blocks, stride):
        """Stack ``num_blocks`` bottleneck blocks.

        Only the first block uses ``stride``; the remaining ones use 1.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        for block_stride in strides:
            input = self.bottleneck_block(input, planes, block_stride)
        return input

    def upsample_add(self, x, y):
        """Bilinearly resize ``x`` to the spatial size of ``y`` and add them."""
        _, _, H, W = y.shape
        upsample = fluid.layers.image_resize(
            x, out_shape=[H, W], resample='BILINEAR')
        return upsample + y

    def extractor(self, input):
        """Build the backbone and pyramid; return ``(p2, ..., p9)``.

        All strided convolutions use a stride of ``(2, 1)``, i.e. only the
        first spatial axis is downsampled.
        """
        num_blocks = [3, 4, 6, 3]
        # Stem: two strided (7, 1) convolutions.
        c1 = self.conv_bn_layer(
            input, ch_out=32, filter_size=(7, 1), stride=(2, 1), padding=(3, 0))
        c1 = self.conv_bn_layer(
            c1, ch_out=64, filter_size=(7, 1), stride=(2, 1), padding=(3, 0))
        # Residual stages.
        c2 = self.layer_warp(c1, 64, num_blocks[0], 1)
        c3 = self.layer_warp(c2, 128, num_blocks[1], (2, 1))
        c4 = self.layer_warp(c3, 256, num_blocks[2], (2, 1))
        c5 = self.layer_warp(c4, 512, num_blocks[3], (2, 1))

        # feature pyramid: extra levels, each a strided conv of the previous
        # level (with a ReLU in between).
        p6 = fluid.layers.conv2d(
            c5,
            num_filters=512,
            filter_size=(3, 1),
            stride=(2, 1),
            padding=(1, 0),
            param_attr=ParamAttr(initializer=get_init(c5, (3, 1))))
        p7 = fluid.layers.relu(p6)
        p7 = fluid.layers.conv2d(
            p7,
            num_filters=512,
            filter_size=(3, 1),
            stride=(2, 1),
            padding=(1, 0),
            param_attr=ParamAttr(initializer=get_init(p7, (3, 1))))
        p8 = fluid.layers.relu(p7)
        p8 = fluid.layers.conv2d(
            p8,
            num_filters=512,
            filter_size=(3, 1),
            stride=(2, 1),
            padding=(1, 0),
            param_attr=ParamAttr(initializer=get_init(p8, (3, 1))))
        p9 = fluid.layers.relu(p8)
        p9 = fluid.layers.conv2d(
            p9,
            num_filters=512,
            filter_size=(3, 1),
            stride=(2, 1),
            padding=(1, 0),
            param_attr=ParamAttr(initializer=get_init(p9, (3, 1))))

        # top_down: 1x1 lateral convs merged with the upsampled coarser level.
        p5 = fluid.layers.conv2d(
            c5,
            512,
            1,
            1,
            0,
            param_attr=ParamAttr(initializer=get_init(c5, 1)), )
        p4 = self.upsample_add(
            p5,
            fluid.layers.conv2d(
                c4,
                512,
                1,
                1,
                0,
                param_attr=ParamAttr(initializer=get_init(c4, 1)), ))
        p3 = self.upsample_add(
            p4,
            fluid.layers.conv2d(
                c3,
                512,
                1,
                1,
                0,
                param_attr=ParamAttr(initializer=get_init(c3, 1)), ))
        p2 = self.upsample_add(
            p3,
            fluid.layers.conv2d(
                c2,
                512,
                1,
                1,
                0,
                param_attr=ParamAttr(initializer=get_init(c2, 1))))

        # smooth: (3, 1) conv on each merged level.
        p4 = fluid.layers.conv2d(
            p4,
            num_filters=512,
            filter_size=(3, 1),
            stride=1,
            padding=(1, 0),
            param_attr=ParamAttr(initializer=get_init(p4, (3, 1))), )
        p3 = fluid.layers.conv2d(
            p3,
            num_filters=512,
            filter_size=(3, 1),
            stride=1,
            padding=(1, 0),
            param_attr=ParamAttr(initializer=get_init(p3, (3, 1))), )
        p2 = fluid.layers.conv2d(
            p2,
            num_filters=512,
            filter_size=(3, 1),
            stride=1,
            padding=(1, 0),
            param_attr=ParamAttr(initializer=get_init(p2, (3, 1))), )
        return p2, p3, p4, p5, p6, p7, p8, p9

    def net(self, input):
        """Build the full network and return ``(loc_preds, cls_preds)``.

        Every pyramid level goes through the same two heads; the head
        parameters carry explicit names, so all levels reuse the same
        weights.  Per level the location output is reshaped to
        ``[batch, -, 2]`` and the class output to
        ``[batch, -, num_classes]``; the levels are then concatenated along
        axis 1.
        """
        # Width of the final head kernels.
        # NOTE(review): presumably equals the width of the feature maps so
        # that the conv collapses that axis -- confirm against the reader.
        kernel_width = self.concept_size  # 402
        num_anchors = self.num_anchors  # 7
        loc_preds = []
        cls_preds = []

        # build fpn network
        xs = self.extractor(input)

        # build predict head
        for x in xs:
            loc_pred = fluid.layers.dropout(
                x, dropout_prob=0.5, is_test=(not self.is_training))
            loc_pred = fluid.layers.conv2d(
                loc_pred,
                num_filters=256,
                filter_size=(3, 1),
                stride=1,
                padding=(1, 0),
                param_attr=ParamAttr(
                    name='loc_pred_conv1_weights',
                    initializer=get_init(loc_pred, (3, 1))),
                bias_attr=ParamAttr(
                    name='loc_pred_conv1_bias', ))
            loc_pred = fluid.layers.conv2d(
                loc_pred,
                num_filters=num_anchors * 2,
                filter_size=(1, kernel_width),
                stride=1,
                padding=0,
                param_attr=ParamAttr(
                    name='loc_pred_conv2_weights',
                    initializer=get_init(loc_pred, (1, kernel_width))),
                bias_attr=ParamAttr(
                    name='loc_pred_conv2_bias', ))
            # Squash the offsets into the open interval (-5, 5).
            loc_pred = 10.0 * fluid.layers.sigmoid(loc_pred) - 5.0
            # Channels last, then split the channel axis into pairs.
            loc_pred = fluid.layers.transpose(loc_pred, perm=[0, 2, 3, 1])
            tmp_size1 = loc_pred.shape[1] * loc_pred.shape[2] * loc_pred.shape[
                3] // 2
            loc_pred = fluid.layers.reshape(
                x=loc_pred, shape=[loc_pred.shape[0], tmp_size1, 2])
            loc_preds.append(loc_pred)

            cls_pred = fluid.layers.dropout(
                x, dropout_prob=0.5, is_test=(not self.is_training))
            cls_pred = fluid.layers.conv2d(
                cls_pred,
                num_filters=512,
                filter_size=(3, 1),
                stride=1,
                padding=(1, 0),
                param_attr=ParamAttr(
                    name='cls_pred_conv1_weights',
                    initializer=get_init(cls_pred, (3, 1))),
                bias_attr=ParamAttr(
                    name='cls_pred_conv1_bias', ))
            cls_pred = fluid.layers.conv2d(
                cls_pred,
                num_filters=num_anchors * self.num_classes,
                filter_size=(1, kernel_width),
                stride=1,
                padding=0,
                param_attr=ParamAttr(
                    name='cls_pred_conv2_weights',
                    initializer=get_init(cls_pred, (1, kernel_width))),
                bias_attr=ParamAttr(
                    name='cls_pred_conv2_bias', ))
            cls_pred = fluid.layers.transpose(cls_pred, perm=[0, 2, 3, 1])
            tmp_size2 = cls_pred.shape[1] * cls_pred.shape[2] * cls_pred.shape[
                3] // self.num_classes
            cls_pred = fluid.layers.reshape(
                x=cls_pred,
                shape=[cls_pred.shape[0], tmp_size2, self.num_classes])
            cls_preds.append(cls_pred)

        loc_preds = fluid.layers.concat(input=loc_preds, axis=1)
        cls_preds = fluid.layers.concat(input=cls_preds, axis=1)
        return loc_preds, cls_preds

    def hard_negative_mining(self, cls_loss, pos_bool):
        """Select the hardest negatives, three per positive in each sample.

        Args:
            cls_loss: per-anchor classification loss.
            pos_bool: boolean mask of positive anchors, same layout.

        Returns:
            Boolean mask marking the anchors whose rank among the sorted
            (negated) losses is below ``3 * #positives`` of their sample.
        """
        pos = fluid.layers.cast(pos_bool, dtype=DATATYPE)
        # Positives become 0 and negatives become -loss, so the sort below
        # puts the negatives with the largest loss first.
        cls_loss = cls_loss * (pos - 1)
        _, indices = fluid.layers.argsort(cls_loss, axis=1)
        indices = fluid.layers.cast(indices, dtype=DATATYPE)
        # Sorting the sort indices yields each anchor's rank.
        _, rank = fluid.layers.argsort(indices, axis=1)
        num_neg = 3 * fluid.layers.reduce_sum(pos, dim=1)
        num_neg = fluid.layers.reshape(x=num_neg, shape=[-1, 1])
        neg = rank < num_neg
        return neg

    def loss(self, loc_preds, cls_preds, loc_targets, cls_targets):
        """Compute the detection loss.

        param loc_targets: [N, 1785,2]
        param cls_targets: [N, 1785]

        Targets greater than 0 are positives; negative targets are ignored
        in the classification loss.  The location loss (smooth L1) only
        counts positives, the classification loss counts positives plus the
        mined hard negatives.

        Returns:
            ``(total, loc, cls)`` where ``loc`` is ``2 * loc_loss`` and all
            three are divided by the number of positives.  No guard is
            applied for a batch without positives.
        """
        loc_targets.stop_gradient = True
        cls_targets.stop_gradient = True

        pos = cls_targets > 0
        pos_bool = pos
        pos = fluid.layers.cast(pos, dtype=DATATYPE)
        num_pos = fluid.layers.reduce_sum(pos)
        # Broadcast the positive mask over both location coordinates.
        pos = fluid.layers.unsqueeze(pos, axes=[2])
        mask = fluid.layers.expand(pos, expand_times=[1, 1, 2])
        mask.stop_gradient = True
        loc_loss = fluid.layers.smooth_l1(
            loc_preds, loc_targets, inside_weight=mask, outside_weight=mask)
        loc_loss = fluid.layers.reduce_sum(loc_loss)

        cls_loss = fluid.layers.softmax_with_cross_entropy(
            logits=fluid.layers.reshape(
                cls_preds, shape=[-1, self.num_classes]),
            label=fluid.layers.reshape(
                cls_targets, shape=[-1, 1]),
            numeric_stable_mode=True)
        cls_loss = fluid.layers.reshape(
            cls_loss, shape=[-1, loc_targets.shape[1]])

        # Zero the loss of anchors whose target is negative (ignored).
        not_ignore = cls_targets >= 0
        not_ignore = fluid.layers.cast(not_ignore, dtype=DATATYPE)
        not_ignore.stop_gradient = True
        cls_loss = cls_loss * not_ignore

        # Keep only positives and mined hard negatives.
        neg = self.hard_negative_mining(cls_loss, pos_bool)
        neg = fluid.layers.cast(neg, dtype='bool')
        pos_bool = fluid.layers.cast(pos_bool, dtype='bool')
        selects = fluid.layers.logical_or(pos_bool, neg)
        selects = fluid.layers.cast(selects, dtype=DATATYPE)
        selects.stop_gradient = True
        cls_loss = cls_loss * selects
        cls_loss = fluid.layers.reduce_sum(cls_loss)

        # Weight of the location term relative to the classification term.
        alpha = 2.0
        loss = (alpha * loc_loss + cls_loss) / num_pos
        num_pos.stop_gradient = True
        return loss, alpha * loc_loss / num_pos, cls_loss / num_pos
......@@ -20,8 +20,6 @@ except:
from ConfigParser import ConfigParser
import paddle.fluid as fluid
from datareader import get_reader
from metrics import get_metrics
from .utils import download, AttrDict
WEIGHT_DIR = os.path.expanduser("~/.paddle/weights")
......@@ -68,7 +66,6 @@ class ModelBase(object):
self.cfg = cfg
self.py_reader = None
def build_model(self):
    """Build the model structure.

    This base implementation only raises ``NotImplementError``, passing the
    instance and this bound method as arguments.
    """
    raise NotImplementError(self, self.build_model)
......
python test.py --model_name="CTCN" --config=./configs/ctcn.txt \
--log_interval=10 --weights=./checkpoints/CTCN_epoch0
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
#export CUDA_VISIBLE_DEVICES=0
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=1.0
python train.py --model_name="CTCN" --config=./configs/ctcn.txt --epoch=35 \
--valid_interval=1 --log_interval=1
......@@ -37,13 +37,26 @@ def test_without_pyreader(test_exe,
test_feeder,
test_fetch_list,
test_metrics,
log_interval=0):
log_interval=0,
save_model_name=''):
test_metrics.reset()
for test_iter, data in enumerate(test_reader()):
test_outs = test_exe.run(test_fetch_list, feed=test_feeder.feed(data))
loss = np.array(test_outs[0])
pred = np.array(test_outs[1])
label = np.array(test_outs[-1])
if save_model_name in ['CTCN']:
# for detection
total_loss = np.array(test_outs[0])
loc_loss = np.array(test_outs[1])
cls_loss = np.array(test_outs[2])
loc_preds = np.array(test_outs[3])
cls_preds = np.array(test_outs[4])
label = np.array(test_outs[-1])
loss = [total_loss, loc_loss, cls_loss]
pred = [loc_preds, cls_preds]
else:
# for classification
loss = np.array(test_outs[0])
pred = np.array(test_outs[1])
label = np.array(test_outs[-1])
test_metrics.accumulate(loss, pred, label)
if log_interval > 0 and test_iter % log_interval == 0:
test_metrics.calculate_and_log_out(loss, pred, label, \
......@@ -55,7 +68,8 @@ def test_with_pyreader(test_exe,
test_pyreader,
test_fetch_list,
test_metrics,
log_interval=0):
log_interval=0,
save_model_name=''):
if not test_pyreader:
logger.error("[TEST] get pyreader failed.")
test_pyreader.start()
......@@ -64,9 +78,20 @@ def test_with_pyreader(test_exe,
try:
while True:
test_outs = test_exe.run(fetch_list=test_fetch_list)
loss = np.array(test_outs[0])
pred = np.array(test_outs[1])
label = np.array(test_outs[-1])
if save_model_name in ['CTCN']:
# for detection
total_loss = np.array(test_outs[0])
loc_loss = np.array(test_outs[1])
cls_loss = np.array(test_outs[2])
loc_preds = np.array(test_outs[3])
cls_preds = np.array(test_outs[4])
label = np.array(test_outs[-1])
loss = [total_loss, loc_loss, cls_loss]
pred = [loc_preds, cls_preds]
else:
loss = np.array(test_outs[0])
pred = np.array(test_outs[1])
label = np.array(test_outs[-1])
test_metrics.accumulate(loss, pred, label)
if log_interval > 0 and test_iter % log_interval == 0:
test_metrics.calculate_and_log_out(loss, pred, label, \
......@@ -92,9 +117,21 @@ def train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feede
feed=train_feeder.feed(data))
period = time.time() - cur_time
epoch_periods.append(period)
loss = np.array(train_outs[0])
pred = np.array(train_outs[1])
label = np.array(train_outs[-1])
if save_model_name in ['CTCN']:
# detection model
total_loss = np.array(train_outs[0])
loc_loss = np.array(train_outs[1])
cls_loss = np.array(train_outs[2])
loc_preds = np.array(train_outs[3])
cls_preds = np.array(train_outs[4])
label = np.array(train_outs[-1])
loss = [total_loss, loc_loss, cls_loss]
pred = [loc_preds, cls_preds]
else:
# classification model
loss = np.array(train_outs[0])
pred = np.array(train_outs[1])
label = np.array(train_outs[-1])
if log_interval > 0 and (train_iter % log_interval == 0):
# eval here
train_metrics.calculate_and_log_out(loss, pred, label, \
......@@ -107,8 +144,8 @@ def train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feede
if test_exe and valid_interval > 0 and (epoch + 1
) % valid_interval == 0:
test_without_pyreader(test_exe, test_reader, test_feeder,
test_fetch_list, test_metrics, log_interval)
test_fetch_list, test_metrics, log_interval,
save_model_name)
def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \
......@@ -133,9 +170,21 @@ def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \
train_outs = train_exe.run(fetch_list=train_fetch_list)
period = time.time() - cur_time
epoch_periods.append(period)
loss = np.array(train_outs[0])
pred = np.array(train_outs[1])
label = np.array(train_outs[-1])
if save_model_name in ['CTCN']:
# for detection
total_loss = np.array(train_outs[0])
loc_loss = np.array(train_outs[1])
cls_loss = np.array(train_outs[2])
loc_preds = np.array(train_outs[3])
cls_preds = np.array(train_outs[4])
label = np.array(train_outs[-1])
loss = [total_loss, loc_loss, cls_loss]
pred = [loc_preds, cls_preds]
else:
# for classification
loss = np.array(train_outs[0])
pred = np.array(train_outs[1])
label = np.array(train_outs[-1])
if log_interval > 0 and (train_iter % log_interval == 0):
# eval here
train_loss = train_metrics.calculate_and_log_out(loss, pred, label, \
......@@ -150,7 +199,7 @@ def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \
if test_exe and valid_interval > 0 and (epoch + 1
) % valid_interval == 0:
test_with_pyreader(test_exe, test_pyreader, test_fetch_list,
test_metrics, log_interval)
test_metrics, log_interval, save_model_name)
finally:
epoch_period = []
train_pyreader.reset()
......
......@@ -130,11 +130,20 @@ def train(args):
train_feeds = train_model.feeds()
train_feeds[-1].persistable = True
# for the output of classification model, has the form [pred]
# for the output of detection model, has the form [loc_pred, cls_pred]
train_outputs = train_model.outputs()
for output in train_outputs:
output.persistable = True
train_loss = train_model.loss()
train_loss.persistable = True
train_losses = train_model.loss()
if isinstance(train_losses, list) or isinstance(train_losses,
tuple):
# for detection model, train_losses has the form [total_loss, loc_loss, cls_loss]
train_loss = train_losses[0]
for item in train_losses:
item.persistable = True
else:
train_loss = train_losses
train_loss.persistable = True
# outputs, loss, label should be fetched, so set persistable to be true
optimizer = train_model.optimizer()
optimizer.minimize(train_loss)
......@@ -146,8 +155,10 @@ def train(args):
valid_model.build_input(not args.no_use_pyreader)
valid_model.build_model()
valid_feeds = valid_model.feeds()
# for the output of classification model, has the form [pred]
# for the output of detection model, has the form [loc_pred, cls_pred]
valid_outputs = valid_model.outputs()
valid_loss = valid_model.loss()
valid_losses = valid_model.loss()
valid_pyreader = valid_model.pyreader()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
......@@ -175,6 +186,8 @@ def train(args):
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True
if args.model_name in ['CTCN']:
build_strategy.enable_sequential_execution = True
#build_strategy.memory_optimize = True
train_exe = fluid.ParallelExecutor(
......@@ -202,10 +215,20 @@ def train(args):
train_metrics = get_metrics(args.model_name.upper(), 'train', train_config)
valid_metrics = get_metrics(args.model_name.upper(), 'valid', valid_config)
train_fetch_list = [train_loss.name] + [x.name for x in train_outputs
] + [train_feeds[-1].name]
valid_fetch_list = [valid_loss.name] + [x.name for x in valid_outputs
] + [valid_feeds[-1].name]
if isinstance(train_losses, tuple) or isinstance(train_losses, list):
# for detection
train_fetch_list = [item.name for item in train_losses] + \
[x.name for x in train_outputs] + [train_feeds[-1].name]
valid_fetch_list = [item.name for item in valid_losses] + \
[x.name for x in valid_outputs] + [valid_feeds[-1].name]
else:
# for classification
train_fetch_list = [train_losses.name] + [
x.name for x in train_outputs
] + [train_feeds[-1].name]
valid_fetch_list = [valid_losses.name] + [
x.name for x in valid_outputs
] + [valid_feeds[-1].name]
epochs = args.epoch or train_model.epoch_num()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册