diff --git a/deep_speech_2/README.md b/deep_speech_2/README.md index 9d39903b53844a9143fa83fd3d39315bf50ea126..9c2a0872bffd70006885c79376963146fda16b27 100644 --- a/deep_speech_2/README.md +++ b/deep_speech_2/README.md @@ -2,13 +2,6 @@ ## Installation -### Prerequisites - - - **Python = 2.7** only supported; - - **cuDNN >= 6.0** is required to utilize NVIDIA GPU platform in the installation of PaddlePaddle, and the **CUDA toolkit** with proper version suitable for cuDNN. The cuDNN library below 6.0 is found to yield a fatal error in batch normalization when handling utterances with long duration in inference. - -### Setup - ``` sh setup.sh export LD_LIBRARY_PATH=$PADDLE_INSTALL_DIR/Paddle/third_party/install/warpctc/lib:$LD_LIBRARY_PATH diff --git a/deep_speech_2/data_utils/featurizer/audio_featurizer.py b/deep_speech_2/data_utils/featurizer/audio_featurizer.py index 00f0e8a35bc8e67ab285b7d509a0992c02dc54ca..f0d223cfbe8bbae039de84fbbffcf0cd3975b790 100644 --- a/deep_speech_2/data_utils/featurizer/audio_featurizer.py +++ b/deep_speech_2/data_utils/featurizer/audio_featurizer.py @@ -159,24 +159,27 @@ class AudioFeaturizer(object): if max_freq is None: max_freq = sample_rate / 2 if max_freq > sample_rate / 2: - raise ValueError("max_freq must be greater than half of " + raise ValueError("max_freq must not be greater than half of " "sample rate.") if stride_ms > window_ms: raise ValueError("Stride size must not be greater than " "window size.") - # compute 13 cepstral coefficients, and the first one is replaced + # compute the 13 cepstral coefficients, and the first one is replaced # by log(frame energy) - mfcc_feat = np.transpose( - mfcc( - signal=samples, - samplerate=sample_rate, - winlen=0.001 * window_ms, - winstep=0.001 * stride_ms, - highfreq=max_freq)) + mfcc_feat = mfcc( + signal=samples, + samplerate=sample_rate, + winlen=0.001 * window_ms, + winstep=0.001 * stride_ms, + highfreq=max_freq) # Deltas d_mfcc_feat = delta(mfcc_feat, 2) # Deltas-Deltas dd_mfcc_feat = delta(d_mfcc_feat, 2) + # transpose + mfcc_feat = np.transpose(mfcc_feat) + d_mfcc_feat = np.transpose(d_mfcc_feat) + dd_mfcc_feat = np.transpose(dd_mfcc_feat) # concat above three features concat_mfcc_feat = np.concatenate( (mfcc_feat, d_mfcc_feat, dd_mfcc_feat)) diff --git a/ssd/README.md b/ssd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9467c4385b02bc9e11058e55568903808fc48473 --- /dev/null +++ b/ssd/README.md @@ -0,0 +1,226 @@ +# SSD目标检测 +## 概述 +SSD全称为Single Shot MultiBox Detector,是目标检测领域较新且效果较好的检测算法之一,具体参见论文\[[1](#引用)\]。SSD算法主要特点是检测速度快且检测精度高。PaddlePaddle已集成SSD算法,本示例旨在介绍如何使用PaddlePaddle中的SSD模型进行目标检测。下文展开顺序为:首先简要介绍SSD原理,然后介绍示例包含文件及作用,接着介绍如何在PASCAL VOC数据集上训练、评估及检测,最后简要介绍如何在自有数据集上使用SSD。 +## SSD原理 +SSD使用一个卷积神经网络实现“端到端”的检测,所谓“端到端”指输入为原始图像,输出为检测结果,无需借助外部工具或流程进行特征提取、候选框生成等。论文中SSD的基础模型为VGG16\[[2](#引用)\],不同于原始VGG16网络模型,SSD做了一些改变: + +1. 将最后的fc6、fc7全连接层变为卷积层,卷积层参数通过对原始fc6、fc7参数采样得到。 +2. 将pool5层的参数由2x2-s2(kernel大小为2x2,stride size为2)更改为3x3-s1-p1(kernel大小为3x3,stride size为1,padding size为1)。 +3. 在conv4\_3、conv7、conv8\_2、conv9\_2、conv10\_2及pool11层后面接了priorbox层,priorbox层的主要目的是根据输入的特征图(feature map)生成一系列的矩形候选框。关于SSD的更详细的介绍可以参考论文\[[1](#引用)\]。 + +下图为模型(300x300)的总体结构: + +

+
+图1. SSD网络结构 +

+ +图中每个矩形盒子代表一个卷积层,最后的两个矩形框分别表示汇总各卷积层输出结果和后处理阶段。具体地,在预测阶段网络会输出一组候选矩形框,每个矩形包含两类信息:位置和类别得分,图中倒数第二个矩形框即表示网络的检测结果的汇总处理,由于候选矩形框数量较多且很多矩形框重叠严重,这时需要经过后处理来筛选出质量较高的少数矩形框,这里的后处理主要指非极大值抑制(Non-maximum Suppression)。 + +从SSD的网络结构可以看出,候选矩形框在多个特征图(feature map上)生成,不同的feature map具有的感受野不同,这样可以在不同尺度扫描图像,相对于其他检测方法可以生成更丰富的候选框,从而提高检测精度;另一方面SSD对VGG16的扩展部分以较小的代价实现对候选框的位置和类别得分的计算,整个过程只需要一个卷积神经网络完成,所以速度较快。 + +## 示例总览 +本示例共包含如下文件: + + + + + + + + + + + + + +
表1. 示例文件
文件用途
train.py训练脚本
eval.py评估脚本,用于评估训好模型
infer.py检测脚本,给定图片及模型,实施检测
visual.py检测结果可视化
image_util.py图像预处理所需公共函数
data_provider.py数据处理脚本,生成训练、评估或检测所需数据
config/pascal_voc_conf.py神经网络超参数配置文件
data/label_list类别列表
data/prepare_voc_data.py准备训练PASCAL VOC数据列表
+ +训练阶段需要对数据做预处理,包括裁剪、采样等,这部分操作在```image_util.py```和```data_provider.py```中完成。值得注意的是,```config/vgg_config.py```为参数配置文件,包括训练参数、神经网络参数等,本配置文件包含参数是针对PASCAL VOC数据配置的,当训练自有数据时,需要仿照该文件配置新的参数。```data/prepare_voc_data.py```脚本用来生成文件列表,包括切分训练集和测试集,使用时需要用户事先下载并解压数据,默认采用VOC2007和VOC2012。 + +## PASCAL VOC数据集 +### 数据准备 +首先需要下载数据集,VOC2007\[[3](#引用)\]和VOC2012\[[4](#引用)\],VOC2007包含训练集和测试集,VOC2012只包含训练集,将下载好的数据解压,目录结构为```data/VOCdevkit/VOC2007```和```data/VOCdevkit/VOC2012```。进入```data```目录,运行```python prepare_voc_data.py```即可生成```trainval.txt```和```test.txt```。核心函数为: + +```python +def prepare_filelist(devkit_dir, years, output_dir): + trainval_list = [] + test_list = [] + for year in years: + trainval, test = walk_dir(devkit_dir, year) + trainval_list.extend(trainval) + test_list.extend(test) + random.shuffle(trainval_list) + with open(osp.join(output_dir, 'trainval.txt'), 'w') as ftrainval: + for item in trainval_list: + ftrainval.write(item[0] + ' ' + item[1] + '\n') + + with open(osp.join(output_dir, 'test.txt'), 'w') as ftest: + for item in test_list: + ftest.write(item[0] + ' ' + item[1] + '\n') +``` + +该函数首先对每一年(year)的数据进行处理,然后将训练图像的文件路径列表进行随机打乱,最后保存训练文件列表和测试文件列表。默认```prepare_voc_data.py```和```VOCdevkit```在相同目录下,且生成的文件列表也在该目录。需注意```trainval.txt```既包含VOC2007的训练数据,也包含VOC2012的训练数据,```test.txt```只包含VOC2007的测试数据。我们这里提供```trainval.txt```前几行输入作为样例: + +``` +VOCdevkit/VOC2007/JPEGImages/000005.jpg VOCdevkit/VOC2007/Annotations/000005.xml +VOCdevkit/VOC2007/JPEGImages/000007.jpg VOCdevkit/VOC2007/Annotations/000007.xml +VOCdevkit/VOC2007/JPEGImages/000009.jpg VOCdevkit/VOC2007/Annotations/000009.xml +``` + +文件共两个字段,第一个字段为图像文件的相对路径,第二个字段为对应标注文件的相对路径。 + +### 预训练模型准备 +下载预训练的VGG-16模型,我们提供了一个转换好的模型,具体下载地址为:http://paddlepaddle.bj.bcebos.com/model_zoo/detection/ssd_model/vgg_model.tar.gz ,下载好模型后,放置路径为```vgg/vgg_model.tar.gz```。 +### 模型训练 +直接执行```python train.py```即可进行训练。需要注意本示例仅支持CUDA GPU环境,无法在CPU上训练,主要因为使用CPU训练速度很慢,实践中一般使用GPU来处理图像任务,这里实现采用硬编码方式使用cuDNN,不提供CPU版本。```train.py```的一些关键执行逻辑: + +```python +paddle.init(use_gpu=True, trainer_count=4) +data_args = data_provider.Settings( + data_dir='./data', + label_file='label_list', + resize_h=cfg.IMG_HEIGHT, + resize_w=cfg.IMG_WIDTH, + mean_value=[104,117,124]) +train(train_file_list='./data/trainval.txt', + dev_file_list='./data/test.txt', + data_args=data_args, + init_model_path='./vgg/vgg_model.tar.gz') +``` + +主要包括: + +1. 调用```paddle.init```指定使用4卡GPU训练。 +2. 调用```data_provider.Settings```配置数据预处理所需参数,其中```cfg.IMG_HEIGHT```和```cfg.IMG_WIDTH```在配置文件```config/vgg_config.py```中设置,这里均为300,300x300是一个典型配置,兼顾效率和检测精度,也可以通过修改配置文件扩展到512x512。 +3. 调用```train```执行训练,其中```train_file_list```指定训练数据列表,```dev_file_list```指定评估数据列表,```init_model_path```指定预训练模型位置。 +4. 训练过程中会打印一些日志信息,每训练1个batch会输出当前的轮数、当前batch的cost及mAP(mean Average Precision,平均精度均值),每训练一个pass,会保存一次模型,默认保存在```checkpoints```目录下(注:需事先创建)。 + +下面给出SDD300x300在VOC数据集(train包括07+12,test为07)上的mAP曲线,迭代140轮mAP可达到71.52%。 + +

+
+图2. SSD300x300 mAP收敛曲线 +

+ + +### 模型评估 +执行```python eval.py```即可对模型进行评估,```eval.py```的关键执行逻辑如下: + +```python +paddle.init(use_gpu=True, trainer_count=4) # use 4 gpus + +data_args = data_provider.Settings( + data_dir='./data', + label_file='label_list', + resize_h=cfg.IMG_HEIGHT, + resize_w=cfg.IMG_WIDTH, + mean_value=[104, 117, 124]) + +eval( + eval_file_list='./data/test.txt', + batch_size=4, + data_args=data_args, + model_path='models/pass-00000.tar.gz') +``` + +调用```paddle.init```指定使用4卡GPU评估;```data_provider.Settings```参见训练阶段的配置;调用```eval```执行评估,其中```eval_file_list```指定评估数据列表,```batch_size```指定评估时batch size的大小,```model_path ```指定模型位置。评估结束会输出```loss```信息和```mAP```信息。 + +### 图像检测 +执行```python infer.py```即可使用训练好的模型对图片实施检测,```infer.py```关键逻辑如下: + +```python +infer( + eval_file_list='./data/infer.txt', + save_path='infer.res', + data_args=data_args, + batch_size=4, + model_path='models/pass-00000.tar.gz', + threshold=0.3) +``` + +其中```eval_file_list```指定图像路径列表;```save_path```指定预测结果保存路径;```data_args```如上;```batch_size```为每多少样本预测一次;```model_path```指模型的位置;```threshold```为置信度阈值,只有得分大于或等于该值的才会输出。下面给出```infer.res```的一些输出样例: + +``` +VOCdevkit/VOC2007/JPEGImages/006936.jpg 12 0.997844 131.255611777 162.271582842 396.475315094 334.0 +VOCdevkit/VOC2007/JPEGImages/006936.jpg 14 0.998557 229.160234332 49.5991278887 314.098775387 312.913876176 +VOCdevkit/VOC2007/JPEGImages/006936.jpg 14 0.372522 187.543615699 133.727034628 345.647156239 327.448492289 +... +``` + +一共包含4个字段,以tab分割,第一个字段是检测图像路径,第二字段为检测矩形框内类别,第三个字段是置信度,第四个字段是4个坐标值(以空格分割)。 + +示例还提供了一个可视化脚本,直接运行```python visual.py```即可,须指定输出检测结果路径及输出目录,默认可视化后图像保存在```./visual_res```,下面是用训练好的模型infer部分图像并可视化的效果: + +

+ + + +
+图3. SSD300x300 检测可视化示例 +

+ + +## 自有数据集 +在自有数据上训练PaddlePaddle SSD需要完成两个关键准备,首先需要适配网络可以接受的输入格式,这里提供一个推荐的结构,以```train.txt```为例 + +``` +image00000_file_path image00000_annotation_file_path +image00001_file_path image00001_annotation_file_path +image00002_file_path image00002_annotation_file_path +... +``` + +文件共两列,以空白符分割,第一列为图像文件的路径,第二列为对应标注数据的文件路径。对图像文件的读取比较直接,略微复杂的是对标注数据的解析,本示例中标注数据使用xml文件存储,所以需要在```data_provider.py```中对xml解析,核心逻辑如下: + +```python +bbox_labels = [] +root = xml.etree.ElementTree.parse(label_path).getroot() +for object in root.findall('object'): + bbox_sample = [] + # start from 1 + bbox_sample.append(float(settings.label_list.index( + object.find('name').text))) + bbox = object.find('bndbox') + difficult = float(object.find('difficult').text) + bbox_sample.append(float(bbox.find('xmin').text)/img_width) + bbox_sample.append(float(bbox.find('ymin').text)/img_height) + bbox_sample.append(float(bbox.find('xmax').text)/img_width) + bbox_sample.append(float(bbox.find('ymax').text)/img_height) + bbox_sample.append(difficult) + bbox_labels.append(bbox_sample) +``` + +这里一条标注数据包括:label、xmin、ymin、xmax、ymax和is\_difficult,is\_difficult表示该object是否为难例,实际中如果不需要,只需把该字段置零即可。自有数据也需要提供对应的解析逻辑,假设标注数据(比如image00000\_annotation\_file\_path)存储格式如下: + +``` +label1 xmin1 ymin1 xmax1 ymax1 +label2 xmin2 ymin2 xmax2 ymax2 +... +``` + +每行对应一个物体,共5个字段,第一个为label(注背景为0,需从1编号),剩余4个为坐标,对应的解析逻辑可更改为如下: + +``` +bbox_labels = [] +with open(label_path) as flabel: + for line in flabel: + bbox_sample = [] + bbox = [float(i) for i in line.strip().split()] + label = bbox[0] + bbox_sample.append(label) + bbox_sample.append(bbox[1]/float(img_width)) + bbox_sample.append(bbox[2]/float(img_height)) + bbox_sample.append(bbox[3]/float(img_width)) + bbox_sample.append(bbox[4]/float(img_height)) + bbox_sample.append(0.0) + bbox_labels.append(bbox_sample) +``` + +另一个重要的事情就是根据图像大小及检测物体的大小等更改网络结构的配置,主要是仿照```config/vgg_config.py```创建自己的配置文件,参数设置经验请参照论文\[[1](#引用)\]。 + +## 引用 +1. Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, Alexander C. Berg. [SSD: Single shot multibox detector](https://arxiv.org/abs/1512.02325). European conference on computer vision. Springer, Cham, 2016. +2. Simonyan, Karen, and Andrew Zisserman. [Very deep convolutional networks for large-scale image recognition](https://arxiv.org/abs/1409.1556). arXiv preprint arXiv:1409.1556 (2014). +3. [The PASCAL Visual Object Classes Challenge 2007](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html) +4. [Visual Object Classes Challenge 2012 (VOC2012)](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html) diff --git a/ssd/config/__init__.py b/ssd/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ssd/config/pascal_voc_conf.py b/ssd/config/pascal_voc_conf.py new file mode 100644 index 0000000000000000000000000000000000000000..318b9ae6798be4855b9fdabdee85f690eb3139bc --- /dev/null +++ b/ssd/config/pascal_voc_conf.py @@ -0,0 +1,91 @@ +from easydict import EasyDict as edict +import numpy as np + +__C = edict() +cfg = __C + +__C.TRAIN = edict() + +__C.IMG_WIDTH = 300 +__C.IMG_HEIGHT = 300 +__C.IMG_CHANNEL = 3 +__C.CLASS_NUM = 21 +__C.BACKGROUND_ID = 0 + +# training settings +__C.TRAIN.LEARNING_RATE = 0.001 / 4 +__C.TRAIN.MOMENTUM = 0.9 +__C.TRAIN.BATCH_SIZE = 32 +__C.TRAIN.NUM_PASS = 200 +__C.TRAIN.L2REGULARIZATION = 0.0005 * 4 +__C.TRAIN.LEARNING_RATE_DECAY_A = 0.1 +__C.TRAIN.LEARNING_RATE_DECAY_B = 16551 * 80 +__C.TRAIN.LEARNING_RATE_SCHEDULE = 'discexp' + +__C.NET = edict() + +# configuration for multibox_loss_layer +__C.NET.MBLOSS = edict() +__C.NET.MBLOSS.OVERLAP_THRESHOLD = 0.5 +__C.NET.MBLOSS.NEG_POS_RATIO = 3.0 +__C.NET.MBLOSS.NEG_OVERLAP = 0.5 + +# configuration for detection_map +__C.NET.DETMAP = edict() +__C.NET.DETMAP.OVERLAP_THRESHOLD = 0.5 +__C.NET.DETMAP.EVAL_DIFFICULT = False +__C.NET.DETMAP.AP_TYPE = "11point" + +# configuration for detection_output_layer +__C.NET.DETOUT = edict() +__C.NET.DETOUT.CONFIDENCE_THRESHOLD = 0.01 +__C.NET.DETOUT.NMS_THRESHOLD = 0.45 +__C.NET.DETOUT.NMS_TOP_K = 400 +__C.NET.DETOUT.KEEP_TOP_K = 200 + +# configuration for priorbox_layer from conv4_3 +__C.NET.CONV4 = edict() +__C.NET.CONV4.PB = edict() +__C.NET.CONV4.PB.MIN_SIZE = [30] +__C.NET.CONV4.PB.ASPECT_RATIO = [2.] +__C.NET.CONV4.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2] + +# configuration for priorbox_layer from fc7 +__C.NET.FC7 = edict() +__C.NET.FC7.PB = edict() +__C.NET.FC7.PB.MIN_SIZE = [60] +__C.NET.FC7.PB.MAX_SIZE = [114] +__C.NET.FC7.PB.ASPECT_RATIO = [2., 3.] +__C.NET.FC7.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2] + +# configuration for priorbox_layer from conv6_2 +__C.NET.CONV6 = edict() +__C.NET.CONV6.PB = edict() +__C.NET.CONV6.PB.MIN_SIZE = [114] +__C.NET.CONV6.PB.MAX_SIZE = [168] +__C.NET.CONV6.PB.ASPECT_RATIO = [2., 3.] +__C.NET.CONV6.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2] + +# configuration for priorbox_layer from conv7_2 +__C.NET.CONV7 = edict() +__C.NET.CONV7.PB = edict() +__C.NET.CONV7.PB.MIN_SIZE = [168] +__C.NET.CONV7.PB.MAX_SIZE = [222] +__C.NET.CONV7.PB.ASPECT_RATIO = [2., 3.] +__C.NET.CONV7.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2] + +# configuration for priorbox_layer from conv8_2 +__C.NET.CONV8 = edict() +__C.NET.CONV8.PB = edict() +__C.NET.CONV8.PB.MIN_SIZE = [222] +__C.NET.CONV8.PB.MAX_SIZE = [276] +__C.NET.CONV8.PB.ASPECT_RATIO = [2., 3.] +__C.NET.CONV8.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2] + +# configuration for priorbox_layer from pool6 +__C.NET.POOL6 = edict() +__C.NET.POOL6.PB = edict() +__C.NET.POOL6.PB.MIN_SIZE = [276] +__C.NET.POOL6.PB.MAX_SIZE = [330] +__C.NET.POOL6.PB.ASPECT_RATIO = [2., 3.] +__C.NET.POOL6.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2] diff --git a/ssd/data/label_list b/ssd/data/label_list new file mode 100644 index 0000000000000000000000000000000000000000..87df23ce0aebcd5ab96fc91c868598c3333da59c --- /dev/null +++ b/ssd/data/label_list @@ -0,0 +1,21 @@ +background +aeroplane +bicycle +bird +boat +bottle +bus +car +cat +chair +cow +diningtable +dog +horse +motorbike +person +pottedplant +sheep +sofa +train +tvmonitor diff --git a/ssd/data/prepare_voc_data.py b/ssd/data/prepare_voc_data.py new file mode 100644 index 0000000000000000000000000000000000000000..a652956e91ab8277bc6670d4dc85905fc52a3203 --- /dev/null +++ b/ssd/data/prepare_voc_data.py @@ -0,0 +1,63 @@ +import os +import os.path as osp +import re +import random + +devkit_dir = './VOCdevkit' +years = ['2007', '2012'] + + +def get_dir(devkit_dir, year, type): + return osp.join(devkit_dir, 'VOC' + year, type) + + +def walk_dir(devkit_dir, year): + filelist_dir = get_dir(devkit_dir, year, 'ImageSets/Main') + annotation_dir = get_dir(devkit_dir, year, 'Annotations') + img_dir = get_dir(devkit_dir, year, 'JPEGImages') + trainval_list = [] + test_list = [] + added = set() + + for _, _, files in os.walk(filelist_dir): + for fname in files: + img_ann_list = [] + if re.match('[a-z]+_trainval\.txt', fname): + img_ann_list = trainval_list + elif re.match('[a-z]+_test\.txt', fname): + img_ann_list = test_list + else: + continue + fpath = osp.join(filelist_dir, fname) + for line in open(fpath): + name_prefix = line.strip().split()[0] + if name_prefix in added: + continue + added.add(name_prefix) + ann_path = osp.join(annotation_dir, name_prefix + '.xml') + img_path = osp.join(img_dir, name_prefix + '.jpg') + assert os.path.isfile(ann_path), 'file %s not found.' % ann_path + assert os.path.isfile(img_path), 'file %s not found.' % img_path + img_ann_list.append((img_path, ann_path)) + + return trainval_list, test_list + + +def prepare_filelist(devkit_dir, years, output_dir): + trainval_list = [] + test_list = [] + for year in years: + trainval, test = walk_dir(devkit_dir, year) + trainval_list.extend(trainval) + test_list.extend(test) + random.shuffle(trainval_list) + with open(osp.join(output_dir, 'trainval.txt'), 'w') as ftrainval: + for item in trainval_list: + ftrainval.write(item[0] + ' ' + item[1] + '\n') + + with open(osp.join(output_dir, 'test.txt'), 'w') as ftest: + for item in test_list: + ftest.write(item[0] + ' ' + item[1] + '\n') + + +prepare_filelist(devkit_dir, years, '.') diff --git a/ssd/data_provider.py b/ssd/data_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..e59d324b497977ec02c1f728cb49a432f864382c --- /dev/null +++ b/ssd/data_provider.py @@ -0,0 +1,175 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import image_util +from paddle.utils.image_util import * +import random +from PIL import Image +import numpy as np +import xml.etree.ElementTree +import os + + +class Settings(object): + def __init__(self, data_dir, label_file, resize_h, resize_w, mean_value): + self._data_dir = data_dir + self._label_list = [] + label_fpath = os.path.join(data_dir, label_file) + for line in open(label_fpath): + self._label_list.append(line.strip()) + + self._resize_height = resize_h + self._resize_width = resize_w + self._img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype( + 'float32') + + @property + def data_dir(self): + return self._data_dir + + @property + def label_list(self): + return self._label_list + + @property + def resize_h(self): + return self._resize_height + + @property + def resize_w(self): + return self._resize_width + + @property + def img_mean(self): + return self._img_mean + + +def _reader_creator(settings, file_list, mode, shuffle): + def reader(): + with open(file_list) as flist: + lines = [line.strip() for line in flist] + if shuffle: + random.shuffle(lines) + for line in lines: + if mode == 'train' or mode == 'test': + img_path, label_path = line.split() + img_path = os.path.join(settings.data_dir, img_path) + label_path = os.path.join(settings.data_dir, label_path) + elif mode == 'infer': + img_path = os.path.join(settings.data_dir, line) + + img = Image.open(img_path) + img_width, img_height = img.size + img = np.array(img) + + # layout: label | xmin | ymin | xmax | ymax | difficult + if mode == 'train' or mode == 'test': + bbox_labels = [] + root = xml.etree.ElementTree.parse(label_path).getroot() + for object in root.findall('object'): + bbox_sample = [] + # start from 1 + bbox_sample.append( + float( + settings.label_list.index( + object.find('name').text))) + bbox = object.find('bndbox') + difficult = float(object.find('difficult').text) + bbox_sample.append( + float(bbox.find('xmin').text) / img_width) + bbox_sample.append( + float(bbox.find('ymin').text) / img_height) + bbox_sample.append( + float(bbox.find('xmax').text) / img_width) + bbox_sample.append( + float(bbox.find('ymax').text) / img_height) + bbox_sample.append(difficult) + bbox_labels.append(bbox_sample) + + sample_labels = bbox_labels + if mode == 'train': + batch_sampler = [] + # hard-code here + batch_sampler.append( + image_util.sampler(1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, + 0.0)) + batch_sampler.append( + image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, + 0.0)) + batch_sampler.append( + image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, + 0.0)) + batch_sampler.append( + image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, + 0.0)) + batch_sampler.append( + image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, + 0.0)) + batch_sampler.append( + image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, + 0.0)) + batch_sampler.append( + image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, + 1.0)) + """ random crop """ + sampled_bbox = image_util.generate_batch_samples( + batch_sampler, bbox_labels, img_width, img_height) + + if len(sampled_bbox) > 0: + idx = int(random.uniform(0, len(sampled_bbox))) + img, sample_labels = image_util.crop_image( + img, bbox_labels, sampled_bbox[idx], img_width, + img_height) + + img = Image.fromarray(img) + img = img.resize((settings.resize_w, settings.resize_h), + Image.ANTIALIAS) + img = np.array(img) + + if mode == 'train': + mirror = int(random.uniform(0, 2)) + if mirror == 1: + img = img[:, ::-1, :] + for i in xrange(len(sample_labels)): + tmp = sample_labels[i][1] + sample_labels[i][1] = 1 - sample_labels[i][3] + sample_labels[i][3] = 1 - tmp + + if len(img.shape) == 3: + img = np.swapaxes(img, 1, 2) + img = np.swapaxes(img, 1, 0) + + img = img.astype('float32') + img -= settings.img_mean + img = img.flatten() + + if mode == 'train' or mode == 'test': + if mode == 'train' and len(sample_labels) == 0: continue + yield img.astype('float32'), sample_labels + elif mode == 'infer': + yield img.astype('float32') + + return reader + + +def train(settings, file_list, shuffle=True): + return _reader_creator(settings, file_list, 'train', shuffle) + + +def test(settings, file_list): + return _reader_creator(settings, file_list, 'test', False) + + +def infer(settings, file_list): + return _reader_creator(settings, file_list, 'infer', False) diff --git a/ssd/eval.py b/ssd/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..345e46f98b098480877a54dac842bd576112b1a3 --- /dev/null +++ b/ssd/eval.py @@ -0,0 +1,48 @@ +import paddle.v2 as paddle +import data_provider +import vgg_ssd_net +import os, sys +import gzip +from config.pascal_voc_conf import cfg + + +def eval(eval_file_list, batch_size, data_args, model_path): + cost, detect_out = vgg_ssd_net.net_conf(mode='eval') + + assert os.path.isfile(model_path), 'Invalid model.' + parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path)) + + optimizer = paddle.optimizer.Momentum() + + trainer = paddle.trainer.SGD( + cost=cost, + parameters=parameters, + extra_layers=[detect_out], + update_equation=optimizer) + + feeding = {'image': 0, 'bbox': 1} + + reader = paddle.batch( + data_provider.test(data_args, eval_file_list), batch_size=batch_size) + + result = trainer.test(reader=reader, feeding=feeding) + + print "TestCost: %f, Detection mAP=%g" % \ + (result.cost, result.metrics['detection_evaluator']) + + +if __name__ == "__main__": + paddle.init(use_gpu=True, trainer_count=4) # use 4 gpus + + data_args = data_provider.Settings( + data_dir='./data', + label_file='label_list', + resize_h=cfg.IMG_HEIGHT, + resize_w=cfg.IMG_WIDTH, + mean_value=[104, 117, 124]) + + eval( + eval_file_list='./data/test.txt', + batch_size=4, + data_args=data_args, + model_path='models/pass-00000.tar.gz') diff --git a/ssd/image_util.py b/ssd/image_util.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8744eda0a078acd38cad9b10ca7511185efc43 --- /dev/null +++ b/ssd/image_util.py @@ -0,0 +1,161 @@ +from PIL import Image +import numpy as np +import random +import math + + +class sampler(): + def __init__(self, max_sample, max_trial, min_scale, max_scale, + min_aspect_ratio, max_aspect_ratio, min_jaccard_overlap, + max_jaccard_overlap): + self.max_sample = max_sample + self.max_trial = max_trial + self.min_scale = min_scale + self.max_scale = max_scale + self.min_aspect_ratio = min_aspect_ratio + self.max_aspect_ratio = max_aspect_ratio + self.min_jaccard_overlap = min_jaccard_overlap + self.max_jaccard_overlap = max_jaccard_overlap + + +class bbox(): + def __init__(self, xmin, ymin, xmax, ymax): + self.xmin = xmin + self.ymin = ymin + self.xmax = xmax + self.ymax = ymax + + +def bbox_area(src_bbox): + width = src_bbox.xmax - src_bbox.xmin + height = src_bbox.ymax - src_bbox.ymin + return width * height + + +def generate_sample(sampler): + scale = random.uniform(sampler.min_scale, sampler.max_scale) + min_aspect_ratio = max(sampler.min_aspect_ratio, (scale**2.0)) + max_aspect_ratio = min(sampler.max_aspect_ratio, 1 / (scale**2.0)) + aspect_ratio = random.uniform(min_aspect_ratio, max_aspect_ratio) + bbox_width = scale * (aspect_ratio**0.5) + bbox_height = scale / (aspect_ratio**0.5) + xmin_bound = 1 - bbox_width + ymin_bound = 1 - bbox_height + xmin = random.uniform(0, xmin_bound) + ymin = random.uniform(0, ymin_bound) + xmax = xmin + bbox_width + ymax = ymin + bbox_height + sampled_bbox = bbox(xmin, ymin, xmax, ymax) + return sampled_bbox + + +def jaccard_overlap(sample_bbox, object_bbox): + if sample_bbox.xmin >= object_bbox.xmax or \ + sample_bbox.xmax <= object_bbox.xmin or \ + sample_bbox.ymin >= object_bbox.ymax or \ + sample_bbox.ymax <= object_bbox.ymin: + return 0 + intersect_xmin = max(sample_bbox.xmin, object_bbox.xmin) + intersect_ymin = max(sample_bbox.ymin, object_bbox.ymin) + intersect_xmax = min(sample_bbox.xmax, object_bbox.xmax) + intersect_ymax = min(sample_bbox.ymax, object_bbox.ymax) + intersect_size = (intersect_xmax - intersect_xmin) * ( + intersect_ymax - intersect_ymin) + sample_bbox_size = bbox_area(sample_bbox) + object_bbox_size = bbox_area(object_bbox) + overlap = intersect_size / ( + sample_bbox_size + object_bbox_size - intersect_size) + return overlap + + +def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): + if sampler.min_jaccard_overlap == 0 and sampler.max_jaccard_overlap == 0: + return True + for i in range(len(bbox_labels)): + object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2], + bbox_labels[i][3], bbox_labels[i][4]) + overlap = jaccard_overlap(sample_bbox, object_bbox) + if sampler.min_jaccard_overlap != 0 and \ + overlap < sampler.min_jaccard_overlap: + continue + if sampler.max_jaccard_overlap != 0 and \ + overlap > sampler.max_jaccard_overlap: + continue + return True + return False + + +def generate_batch_samples(batch_sampler, bbox_labels, image_width, + image_height): + sampled_bbox = [] + index = [] + c = 0 + for sampler in batch_sampler: + found = 0 + for i in range(sampler.max_trial): + if found >= sampler.max_sample: + break + sample_bbox = generate_sample(sampler) + if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): + sampled_bbox.append(sample_bbox) + found = found + 1 + index.append(c) + c = c + 1 + return sampled_bbox + + +def clip_bbox(src_bbox): + src_bbox.xmin = max(min(src_bbox.xmin, 1.0), 0.0) + src_bbox.ymin = max(min(src_bbox.ymin, 1.0), 0.0) + src_bbox.xmax = max(min(src_bbox.xmax, 1.0), 0.0) + src_bbox.ymax = max(min(src_bbox.ymax, 1.0), 0.0) + return src_bbox + + +def meet_emit_constraint(src_bbox, sample_bbox): + center_x = (src_bbox.xmax + src_bbox.xmin) / 2 + center_y = (src_bbox.ymax + src_bbox.ymin) / 2 + if center_x >= sample_bbox.xmin and \ + center_x <= sample_bbox.xmax and \ + center_y >= sample_bbox.ymin and \ + center_y <= sample_bbox.ymax: + return True + return False + + +def transform_labels(bbox_labels, sample_bbox): + proj_bbox = bbox(0, 0, 0, 0) + sample_labels = [] + for i in range(len(bbox_labels)): + sample_label = [] + object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2], + bbox_labels[i][3], bbox_labels[i][4]) + if not meet_emit_constraint(object_bbox, sample_bbox): + continue + sample_width = sample_bbox.xmax - sample_bbox.xmin + sample_height = sample_bbox.ymax - sample_bbox.ymin + proj_bbox.xmin = (object_bbox.xmin - sample_bbox.xmin) / sample_width + proj_bbox.ymin = (object_bbox.ymin - sample_bbox.ymin) / sample_height + proj_bbox.xmax = (object_bbox.xmax - sample_bbox.xmin) / sample_width + proj_bbox.ymax = (object_bbox.ymax - sample_bbox.ymin) / sample_height + proj_bbox = clip_bbox(proj_bbox) + if bbox_area(proj_bbox) > 0: + sample_label.append(bbox_labels[i][0]) + sample_label.append(float(proj_bbox.xmin)) + sample_label.append(float(proj_bbox.ymin)) + sample_label.append(float(proj_bbox.xmax)) + sample_label.append(float(proj_bbox.ymax)) + sample_label.append(bbox_labels[i][5]) + sample_labels.append(sample_label) + return sample_labels + + +def crop_image(img, bbox_labels, sample_bbox, image_width, image_height): + sample_bbox = clip_bbox(sample_bbox) + xmin = int(sample_bbox.xmin * image_width) + xmax = int(sample_bbox.xmax * image_width) + ymin = int(sample_bbox.ymin * image_height) + ymax = int(sample_bbox.ymax * image_height) + sample_img = img[ymin:ymax, xmin:xmax] + sample_labels = transform_labels(bbox_labels, sample_bbox) + return sample_img, sample_labels diff --git a/ssd/images/SSD300x300_map.png b/ssd/images/SSD300x300_map.png new file mode 100644 index 0000000000000000000000000000000000000000..a40a1e028be7ba979052034c152028976bc4b715 Binary files /dev/null and b/ssd/images/SSD300x300_map.png differ diff --git a/ssd/images/ssd_network.png b/ssd/images/ssd_network.png new file mode 100644 index 0000000000000000000000000000000000000000..193caa0168a4f981506ad7b97f8b9fb35557ed20 Binary files /dev/null and b/ssd/images/ssd_network.png differ diff --git a/ssd/images/vis_1.jpg b/ssd/images/vis_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c317462ee6053df15fa8d44d0f35398e47156e8d Binary files /dev/null and b/ssd/images/vis_1.jpg differ diff --git a/ssd/images/vis_2.jpg b/ssd/images/vis_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7bc59b239cb9c123087fdecbb210ad52a3a35f10 Binary files /dev/null and b/ssd/images/vis_2.jpg differ diff --git a/ssd/images/vis_3.jpg b/ssd/images/vis_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a79598343a7e2707ba79c2e8891d7af0c24df491 Binary files /dev/null and b/ssd/images/vis_3.jpg differ diff --git a/ssd/images/vis_4.jpg b/ssd/images/vis_4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..96b2c99c9ef986cc0d4802b31c33f076fce6f965 Binary files /dev/null and b/ssd/images/vis_4.jpg differ diff --git a/ssd/index.html b/ssd/index.html new file mode 100644 index 0000000000000000000000000000000000000000..c31c21889c7309b449940edcc0323b50d453efd0 --- /dev/null +++ b/ssd/index.html @@ -0,0 +1,290 @@ + + + + + + + + + + + + + + + + + +
+
+ + + + + + + diff --git a/ssd/infer.py b/ssd/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..c0bc79189935d8bdd59f17756b9c95581870f36a --- /dev/null +++ b/ssd/infer.py @@ -0,0 +1,98 @@ +import paddle.v2 as paddle +import data_provider +import vgg_ssd_net +import os, sys +import numpy as np +import gzip +from PIL import Image +from config.pascal_voc_conf import cfg + + +def _infer(inferer, infer_data, threshold): + ret = [] + infer_res = inferer.infer(input=infer_data) + keep_inds = np.where(infer_res[:, 2] >= threshold)[0] + for idx in keep_inds: + ret.append([ + infer_res[idx][0], infer_res[idx][1] - 1, infer_res[idx][2], + infer_res[idx][3], infer_res[idx][4], infer_res[idx][5], + infer_res[idx][6] + ]) + return ret + + +def save_batch_res(ret_res, img_w, img_h, fname_list, fout): + for det_res in ret_res: + img_idx = int(det_res[0]) + label = int(det_res[1]) + conf_score = det_res[2] + xmin = det_res[3] * img_w[img_idx] + ymin = det_res[4] * img_h[img_idx] + xmax = det_res[5] * img_w[img_idx] + ymax = det_res[6] * img_h[img_idx] + fout.write(fname_list[img_idx] + '\t' + str(label) + '\t' + str( + conf_score) + '\t' + str(xmin) + ' ' + str(ymin) + ' ' + str(xmax) + + ' ' + str(ymax)) + fout.write('\n') + + +def infer(eval_file_list, save_path, data_args, batch_size, model_path, + threshold): + detect_out = vgg_ssd_net.net_conf(mode='infer') + + assert os.path.isfile(model_path), 'Invalid model.' + parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path)) + + inferer = paddle.inference.Inference( + output_layer=detect_out, parameters=parameters) + + reader = data_provider.infer(data_args, eval_file_list) + all_fname_list = [line.strip() for line in open(eval_file_list).readlines()] + + test_data = [] + fname_list = [] + img_w = [] + img_h = [] + idx = 0 + """Do inference batch by batch, + coords of bbox will be scaled based on image size + """ + with open(save_path, 'w') as fout: + for img in reader(): + test_data.append([img]) + fname_list.append(all_fname_list[idx]) + w, h = Image.open(os.path.join('./data', fname_list[-1])).size + img_w.append(w) + img_h.append(h) + if len(test_data) == batch_size: + ret_res = _infer(inferer, test_data, threshold) + save_batch_res(ret_res, img_w, img_h, fname_list, fout) + test_data = [] + fname_list = [] + img_w = [] + img_h = [] + + idx += 1 + + if len(test_data) > 0: + ret_res = _infer(inferer, test_data, threshold) + save_batch_res(ret_res, img_w, img_h, fname_list, fout) + + +if __name__ == "__main__": + paddle.init(use_gpu=True, trainer_count=1) + + data_args = data_provider.Settings( + data_dir='./data', + label_file='label_list', + resize_h=cfg.IMG_HEIGHT, + resize_w=cfg.IMG_WIDTH, + mean_value=[104, 117, 124]) + + infer( + eval_file_list='./data/infer.txt', + save_path='infer.res', + data_args=data_args, + batch_size=4, + model_path='models/pass-00000.tar.gz', + threshold=0.3) diff --git a/ssd/train.py b/ssd/train.py new file mode 100644 index 0000000000000000000000000000000000000000..783944214b67d15af31267c8ba1ded3fa48e6cb0 --- /dev/null +++ b/ssd/train.py @@ -0,0 +1,84 @@ +import paddle.v2 as paddle +import data_provider +import vgg_ssd_net +import os, sys +import gzip +import tarfile +from config.pascal_voc_conf import cfg + + +def train(train_file_list, dev_file_list, data_args, init_model_path): + optimizer = paddle.optimizer.Momentum( + momentum=cfg.TRAIN.MOMENTUM, + learning_rate=cfg.TRAIN.LEARNING_RATE, + regularization=paddle.optimizer.L2Regularization( + rate=cfg.TRAIN.L2REGULARIZATION), + learning_rate_decay_a=cfg.TRAIN.LEARNING_RATE_DECAY_A, + learning_rate_decay_b=cfg.TRAIN.LEARNING_RATE_DECAY_B, + learning_rate_schedule=cfg.TRAIN.LEARNING_RATE_SCHEDULE) + + cost, detect_out = vgg_ssd_net.net_conf('train') + + parameters = paddle.parameters.create(cost) + if not (init_model_path is None): + assert os.path.isfile(init_model_path), 'Invalid model.' + parameters.init_from_tar(gzip.open(init_model_path)) + + trainer = paddle.trainer.SGD( + cost=cost, + parameters=parameters, + extra_layers=[detect_out], + update_equation=optimizer) + + feeding = {'image': 0, 'bbox': 1} + + train_reader = paddle.batch( + data_provider.train(data_args, train_file_list), + batch_size=cfg.TRAIN.BATCH_SIZE) # generate a batch image each time + + dev_reader = paddle.batch( + data_provider.test(data_args, dev_file_list), + batch_size=cfg.TRAIN.BATCH_SIZE) + + def event_handler(event): + if isinstance(event, paddle.event.EndIteration): + if event.batch_id % 1 == 0: + print "\nPass %d, Batch %d, TrainCost %f, Detection mAP=%f" % \ + (event.pass_id, + event.batch_id, + event.cost, + event.metrics['detection_evaluator']) + else: + sys.stdout.write('.') + sys.stdout.flush() + + if isinstance(event, paddle.event.EndPass): + with gzip.open('checkpoints/params_pass_%05d.tar.gz' % \ + event.pass_id, 'w') as f: + parameters.to_tar(f) + result = trainer.test(reader=dev_reader, feeding=feeding) + print "\nTest with Pass %d, TestCost: %f, Detection mAP=%g" % \ + (event.pass_id, + result.cost, + result.metrics['detection_evaluator']) + + trainer.train( + reader=train_reader, + event_handler=event_handler, + num_passes=cfg.TRAIN.NUM_PASS, + feeding=feeding) + + +if __name__ == "__main__": + paddle.init(use_gpu=True, trainer_count=4) + data_args = data_provider.Settings( + data_dir='./data', + label_file='label_list', + resize_h=cfg.IMG_HEIGHT, + resize_w=cfg.IMG_WIDTH, + mean_value=[104, 117, 124]) + train( + train_file_list='./data/trainval.txt', + dev_file_list='./data/test.txt', + data_args=data_args, + init_model_path='./vgg/vgg_model.tar.gz') diff --git a/ssd/vgg_ssd_net.py b/ssd/vgg_ssd_net.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5c107e6fda6e58ff2b27c55bd4773639d36aab --- /dev/null +++ b/ssd/vgg_ssd_net.py @@ -0,0 +1,278 @@ +import paddle.v2 as paddle +from config.pascal_voc_conf import cfg + + +def net_conf(mode): + """Network configuration. Total three modes included 'train' 'eval' + and 'infer'. Loss and mAP evaluation layer will return if using 'train' + and 'eval'. In 'infer' mode, only detection output layer will be returned. + """ + default_l2regularization = cfg.TRAIN.L2REGULARIZATION + + default_bias_attr = paddle.attr.ParamAttr(l2_rate=0.0, learning_rate=2.0) + default_static_bias_attr = paddle.attr.ParamAttr(is_static=True) + + def get_param_attr(local_lr, regularization): + is_static = False + if local_lr == 0.0: + is_static = True + return paddle.attr.ParamAttr( + learning_rate=local_lr, l2_rate=regularization, is_static=is_static) + + def conv_group(stack_num, name_list, input, filter_size_list, num_channels, + num_filters_list, stride_list, padding_list, + common_bias_attr, common_param_attr, common_act): + conv = input + in_channels = num_channels + for i in xrange(stack_num): + conv = paddle.layer.img_conv( + name=name_list[i], + input=conv, + filter_size=filter_size_list[i], + num_channels=in_channels, + num_filters=num_filters_list[i], + stride=stride_list[i], + padding=padding_list[i], + bias_attr=common_bias_attr, + param_attr=common_param_attr, + act=common_act) + in_channels = num_filters_list[i] + return conv + + def vgg_block(idx_str, input, num_channels, num_filters, pool_size, + pool_stride, pool_pad): + layer_name = "conv%s_" % idx_str + stack_num = 3 + name_list = [layer_name + str(i + 1) for i in xrange(3)] + + conv = conv_group(stack_num, name_list, input, [3] * stack_num, + num_channels, [num_filters] * stack_num, + [1] * stack_num, [1] * stack_num, default_bias_attr, + get_param_attr(1, default_l2regularization), + paddle.activation.Relu()) + + pool = paddle.layer.img_pool( + input=conv, + pool_size=pool_size, + num_channels=num_filters, + pool_type=paddle.pooling.CudnnMax(), + stride=pool_stride, + padding=pool_pad) + return conv, pool + + def mbox_block(layer_idx, input, num_channels, filter_size, loc_filters, + conf_filters): + mbox_loc_name = layer_idx + "_mbox_loc" + mbox_loc = paddle.layer.img_conv( + name=mbox_loc_name, + input=input, + filter_size=filter_size, + num_channels=num_channels, + num_filters=loc_filters, + stride=1, + padding=1, + bias_attr=default_bias_attr, + param_attr=get_param_attr(1, default_l2regularization), + act=paddle.activation.Identity()) + + mbox_conf_name = layer_idx + "_mbox_conf" + mbox_conf = paddle.layer.img_conv( + name=mbox_conf_name, + input=input, + filter_size=filter_size, + num_channels=num_channels, + num_filters=conf_filters, + stride=1, + padding=1, + bias_attr=default_bias_attr, + param_attr=get_param_attr(1, default_l2regularization), + act=paddle.activation.Identity()) + + return mbox_loc, mbox_conf + + def ssd_block(layer_idx, input, img_shape, num_channels, num_filters1, + num_filters2, aspect_ratio, variance, min_size, max_size): + layer_name = "conv" + layer_idx + "_" + stack_num = 2 + conv1_name = layer_name + "1" + conv2_name = layer_name + "2" + conv2 = conv_group(stack_num, [conv1_name, conv2_name], input, [1, 3], + num_channels, [num_filters1, num_filters2], [1, 2], + [0, 1], default_bias_attr, + get_param_attr(1, default_l2regularization), + paddle.activation.Relu()) + + loc_filters = (len(aspect_ratio) * 2 + 1 + len(max_size)) * 4 + conf_filters = ( + len(aspect_ratio) * 2 + 1 + len(max_size)) * cfg.CLASS_NUM + mbox_loc, mbox_conf = mbox_block(conv2_name, conv2, num_filters2, 3, + loc_filters, conf_filters) + mbox_priorbox = paddle.layer.priorbox( + input=conv2, + image=img_shape, + min_size=min_size, + max_size=max_size, + aspect_ratio=aspect_ratio, + variance=variance) + + return conv2, mbox_loc, mbox_conf, mbox_priorbox + + img = paddle.layer.data( + name='image', + type=paddle.data_type.dense_vector(cfg.IMG_CHANNEL * cfg.IMG_HEIGHT * + cfg.IMG_WIDTH), + height=cfg.IMG_HEIGHT, + width=cfg.IMG_WIDTH) + + stack_num = 2 + conv1_2 = conv_group(stack_num, ['conv1_1', 'conv1_2'], img, + [3] * stack_num, 3, [64] * stack_num, [1] * stack_num, + [1] * stack_num, default_static_bias_attr, + get_param_attr(0, 0), paddle.activation.Relu()) + + pool1 = paddle.layer.img_pool( + name="pool1", + input=conv1_2, + pool_type=paddle.pooling.CudnnMax(), + pool_size=2, + num_channels=64, + stride=2) + + stack_num = 2 + conv2_2 = conv_group(stack_num, ['conv2_1', 'conv2_2'], pool1, [3] * + stack_num, 64, [128] * stack_num, [1] * stack_num, + [1] * stack_num, default_static_bias_attr, + get_param_attr(0, 0), paddle.activation.Relu()) + + pool2 = paddle.layer.img_pool( + name="pool2", + input=conv2_2, + pool_type=paddle.pooling.CudnnMax(), + pool_size=2, + num_channels=128, + stride=2) + + conv3_3, pool3 = vgg_block("3", pool2, 128, 256, 2, 2, 0) + + conv4_3, pool4 = vgg_block("4", pool3, 256, 512, 2, 2, 0) + conv4_3_mbox_priorbox = paddle.layer.priorbox( + input=conv4_3, + image=img, + min_size=cfg.NET.CONV4.PB.MIN_SIZE, + aspect_ratio=cfg.NET.CONV4.PB.ASPECT_RATIO, + variance=cfg.NET.CONV4.PB.VARIANCE) + conv4_3_norm = paddle.layer.cross_channel_norm( + name="conv4_3_norm", + input=conv4_3, + param_attr=paddle.attr.ParamAttr( + initial_mean=20, initial_std=0, is_static=False, learning_rate=1)) + conv4_3_norm_mbox_loc, conv4_3_norm_mbox_conf = \ + mbox_block("conv4_3_norm", conv4_3_norm, 512, 3, 12, 63) + + conv5_3, pool5 = vgg_block("5", pool4, 512, 512, 3, 1, 1) + + stack_num = 2 + fc7 = conv_group(stack_num, ['fc6', 'fc7'], pool5, [3, 1], 512, [1024] * + stack_num, [1] * stack_num, [1, 0], default_bias_attr, + get_param_attr(1, default_l2regularization), + paddle.activation.Relu()) + + fc7_mbox_loc, fc7_mbox_conf = mbox_block("fc7", fc7, 1024, 3, 24, 126) + fc7_mbox_priorbox = paddle.layer.priorbox( + input=fc7, + image=img, + min_size=cfg.NET.FC7.PB.MIN_SIZE, + max_size=cfg.NET.FC7.PB.MAX_SIZE, + aspect_ratio=cfg.NET.FC7.PB.ASPECT_RATIO, + variance=cfg.NET.FC7.PB.VARIANCE) + + conv6_2, conv6_2_mbox_loc, conv6_2_mbox_conf, conv6_2_mbox_priorbox = \ + ssd_block("6", fc7, img, 1024, 256, 512, + cfg.NET.CONV6.PB.ASPECT_RATIO, + cfg.NET.CONV6.PB.VARIANCE, + cfg.NET.CONV6.PB.MIN_SIZE, + cfg.NET.CONV6.PB.MAX_SIZE) + conv7_2, conv7_2_mbox_loc, conv7_2_mbox_conf, conv7_2_mbox_priorbox = \ + ssd_block("7", conv6_2, img, 512, 128, 256, + cfg.NET.CONV7.PB.ASPECT_RATIO, + cfg.NET.CONV7.PB.VARIANCE, + cfg.NET.CONV7.PB.MIN_SIZE, + cfg.NET.CONV7.PB.MAX_SIZE) + conv8_2, conv8_2_mbox_loc, conv8_2_mbox_conf, conv8_2_mbox_priorbox = \ + ssd_block("8", conv7_2, img, 256, 128, 256, + cfg.NET.CONV8.PB.ASPECT_RATIO, + cfg.NET.CONV8.PB.VARIANCE, + cfg.NET.CONV8.PB.MIN_SIZE, + cfg.NET.CONV8.PB.MAX_SIZE) + + pool6 = paddle.layer.img_pool( + name="pool6", + input=conv8_2, + pool_size=3, + num_channels=256, + stride=1, + pool_type=paddle.pooling.Avg()) + pool6_mbox_loc, pool6_mbox_conf = mbox_block("pool6", pool6, 256, 3, 24, + 126) + pool6_mbox_priorbox = paddle.layer.priorbox( + input=pool6, + image=img, + min_size=cfg.NET.POOL6.PB.MIN_SIZE, + max_size=cfg.NET.POOL6.PB.MAX_SIZE, + aspect_ratio=cfg.NET.POOL6.PB.ASPECT_RATIO, + variance=cfg.NET.POOL6.PB.VARIANCE) + + mbox_priorbox = paddle.layer.concat( + name="mbox_priorbox", + input=[ + conv4_3_mbox_priorbox, fc7_mbox_priorbox, conv6_2_mbox_priorbox, + conv7_2_mbox_priorbox, conv8_2_mbox_priorbox, pool6_mbox_priorbox + ]) + + loc_loss_input = [ + conv4_3_norm_mbox_loc, fc7_mbox_loc, conv6_2_mbox_loc, conv7_2_mbox_loc, + conv8_2_mbox_loc, pool6_mbox_loc + ] + + conf_loss_input = [ + conv4_3_norm_mbox_conf, fc7_mbox_conf, conv6_2_mbox_conf, + conv7_2_mbox_conf, conv8_2_mbox_conf, pool6_mbox_conf + ] + + detection_out = paddle.layer.detection_output( + input_loc=loc_loss_input, + input_conf=conf_loss_input, + priorbox=mbox_priorbox, + confidence_threshold=cfg.NET.DETOUT.CONFIDENCE_THRESHOLD, + nms_threshold=cfg.NET.DETOUT.NMS_THRESHOLD, + num_classes=cfg.CLASS_NUM, + nms_top_k=cfg.NET.DETOUT.NMS_TOP_K, + keep_top_k=cfg.NET.DETOUT.KEEP_TOP_K, + background_id=cfg.BACKGROUND_ID, + name="detection_output") + + if mode == 'train' or mode == 'eval': + bbox = paddle.layer.data( + name='bbox', type=paddle.data_type.dense_vector_sequence(6)) + loss = paddle.layer.multibox_loss( + input_loc=loc_loss_input, + input_conf=conf_loss_input, + priorbox=mbox_priorbox, + label=bbox, + num_classes=cfg.CLASS_NUM, + overlap_threshold=cfg.NET.MBLOSS.OVERLAP_THRESHOLD, + neg_pos_ratio=cfg.NET.MBLOSS.NEG_POS_RATIO, + neg_overlap=cfg.NET.MBLOSS.NEG_OVERLAP, + background_id=cfg.BACKGROUND_ID, + name="multibox_loss") + paddle.evaluator.detection_map( + input=detection_out, + label=bbox, + overlap_threshold=cfg.NET.DETMAP.OVERLAP_THRESHOLD, + background_id=cfg.BACKGROUND_ID, + evaluate_difficult=cfg.NET.DETMAP.EVAL_DIFFICULT, + ap_type=cfg.NET.DETMAP.AP_TYPE, + name="detection_evaluator") + return loss, detection_out + elif mode == 'infer': + return detection_out diff --git a/ssd/visual.py b/ssd/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..278fd34af1a7e817012c27f38647f9ce76f0c803 --- /dev/null +++ b/ssd/visual.py @@ -0,0 +1,33 @@ +import cv2 +import os + +data_dir = './data' +infer_file = './infer.res' +out_dir = './visual_res' + +path_to_im = dict() + +for line in open(infer_file): + img_path, _, _, _ = line.strip().split('\t') + if img_path not in path_to_im: + im = cv2.imread(os.path.join(data_dir, img_path)) + path_to_im[img_path] = im + +for line in open(infer_file): + img_path, label, conf, bbox = line.strip().split('\t') + xmin, ymin, xmax, ymax = map(float, bbox.split(' ')) + xmin = int(round(xmin)) + ymin = int(round(ymin)) + xmax = int(round(xmax)) + ymax = int(round(ymax)) + + img = path_to_im[img_path] + cv2.rectangle(img, (xmin, ymin), (xmax, ymax), + (0, (1 - xmin) * 255, xmin * 255), 2) + +for img_path in path_to_im: + im = path_to_im[img_path] + out_path = os.path.join(out_dir, os.path.basename(img_path)) + cv2.imwrite(out_path, im) + +print 'Done.'