Unverified commit 2b8f904e authored by chengjuntao, committed by GitHub

Add RRPN models for PaddleCV (#4148)

* add rrpn for models

Parent: 1853d687
# RRPN Rotated Object Detection
---
## Contents
- [Installation](#installation)
- [Introduction](#introduction)
- [Data Preparation](#data-preparation)
- [Training](#training)
- [Evaluation](#evaluation)
- [Inference and Visualization](#inference-and-visualization)
## Installation
Running the sample code in this directory requires the develop branch (or a newer version) of PaddlePaddle Fluid. If the PaddlePaddle version in your environment is lower, please update PaddlePaddle following the [installation guide](http://www.paddlepaddle.org/).
## Introduction
RRPN is a two-stage object detector extended from Faster R-CNN and can be used for text detection and rotated object detection. It generates candidate regions from an image, extracts their features, classifies them, and refines the candidate box positions.
The overall [RRPN](https://arxiv.org/abs/1703.01086) network consists of four main parts:
1. Backbone convolutional layers. As a CNN-based object detection method, RRPN first extracts feature maps from the image with a set of backbone convolutional layers. The feature maps are shared by the subsequent RPN layer and the fully connected layers. This example uses [ResNet-50](https://arxiv.org/abs/1512.03385) as the backbone.
2. Region Proposal Network (RPN). The RPN generates candidate regions (proposals). It derives a set of oriented anchors from a fixed set of sizes, aspect ratios, and angles, classifies each rotated anchor as foreground or background with softmax, and then refines the anchors with box regression to obtain accurate proposals (see the anchor sketch after this list).
3. Rotated RoI Align. This layer takes the feature maps and the oriented proposals, maps each oriented proposal onto the feature maps, and pools it into a region feature map of a fixed size, which is fed into the fully connected layers for classification.
4. Detection head. It predicts the class of each proposal from its region feature map and applies box regression once more to obtain the final, precise detection boxes.
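The sketch below illustrates how oriented anchors can be enumerated from a set of sizes, aspect ratios, and angles; the values match `anchor_sizes`, `aspect_ratio`, and `anchor_angle` in config.py, but the snippet is only an illustration, not the custom-OP anchor generator used by the model.
```
import itertools
import numpy as np

def make_rotated_anchors(sizes=(128, 256, 512),
                         ratios=(0.2, 0.5, 1.0),
                         angles=(-30.0, 0.0, 30.0, 60.0, 90.0, 120.0)):
    """Enumerate (w, h, angle) anchor shapes for a single feature-map location."""
    anchors = []
    for size, ratio, angle in itertools.product(sizes, ratios, angles):
        area = float(size) * float(size)
        w = (area / ratio) ** 0.5   # choose w, h so that h / w == ratio and w * h == size^2
        h = w * ratio
        anchors.append((w, h, angle))
    return np.array(anchors, dtype=np.float32)

print(make_rotated_anchors().shape)  # (54, 3): 3 sizes x 3 ratios x 6 angles
```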
### Compiling the custom OPs
Compile the custom OPs as follows:
enter the `ext_op/src` directory and run the build script
```
cd ext_op/src
sh make.sh ${cuda_path} ${cudnn_path} ${nccl_path}
```
where ${cuda_path}, ${cudnn_path}, and ${nccl_path} are the installation paths of CUDA, cuDNN, and NCCL, and must be given on the command line.
After a successful build, `rrpn_lib.so` is generated in the `ext_op/src` directory.
## Data Preparation
### Public dataset
Training uses the [ICDAR2015 dataset](https://rrc.cvc.uab.es/?ch=4&com=downloads); you need to register on the official website before the dataset can be downloaded.
The data directory is organized as follows:
```
dataset/icdar2015/
├── ch4_training_images
│   ├── img_143.jpg
│   ├── img_144.jpg
|   ...
├── ch4_training_localization_transcription_gt
│   ├── gt_img_143.txt
│   ├── gt_img_144.txt
|   ...
├── ch4_test_images
│   ├── img_111.jpg
│   ├── img_112.jpg
|   ...
├── ch4_test_localization_transcription_gt
│   ├── gt_img_111.txt
│   ├── gt_img_112.txt
|   ...
```
### Custom data
The original RRPN only supports binary classification. To train a multi-class model on your own data, set dataset to icdar2017 in utility.py and set class_num to the required number of classes, where class 0 is the background.
When training on custom data, the directory structure is the same as for ICDAR2015, and each annotation line has the following format (a minimal parsing sketch follows the block):
```
x1, y1, x2, y2, x3, y3, x4, y4, class_name
x1, y1, x2, y2, x3, y3, x4, y4, class_name
```
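As an illustration only (this helper is not part of the released reader, and the file and label-map names are hypothetical), an annotation file in the format above could be parsed like this:
```
def parse_annotation(gt_path, label_map):
    """Parse one annotation file into quadrilateral boxes and class ids."""
    boxes, labels = [], []
    with open(gt_path, encoding='utf-8-sig') as f:  # utf-8-sig strips a possible BOM
        for line in f:
            parts = line.strip().split(',')
            if len(parts) < 9:
                continue
            boxes.append([float(v) for v in parts[:8]])        # x1, y1, ..., x4, y4
            labels.append(label_map.get(parts[8].strip(), 1))  # 0 is reserved for background
    return boxes, labels

# Example: boxes, labels = parse_annotation('gt_img_143.txt', {'text': 1})
```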
## Training
**Download the pretrained model:** this example provides a ResNet-50 pretrained model; download it with `sh ./pretrained/download.sh`.
The pretrained model is loaded by setting `pretrained_model`. The same setting is used to load an already trained model when fine-tuning.
Please make sure the pretrained model is downloaded and loaded correctly before training, otherwise the loss may become NaN during training.
- RRPN
```
python train.py \
   --model_save_dir=output/ \
   --pretrained_model=${path_to_pretrain_model} \
   --data_dir=${path_to_data} \
```
- Set `export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` to train on 8 GPUs.
- For the optional arguments, run `python train.py --help`.
**Data reader:** the data reader is defined in reader.py. The shorter side of every image is scaled to `scales`; if the longer side then exceeds `max_size`, the image is scaled again so that the longer side equals `max_size` (see the sketch below). During training, images are randomly rotated.
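The scaling rule can be written as a small helper. This is only a sketch of the rule described above; the actual logic lives in `_resize` in data_utils.py.
```
import numpy as np

def compute_scale(h, w, target_size=800, max_size=1333):
    """Scale factor so the short side equals target_size, capped so the long side <= max_size."""
    short_side, long_side = min(h, w), max(h, w)
    scale = float(target_size) / short_side
    if np.round(scale * long_side) > max_size:
        scale = float(max_size) / long_side
    return scale

# Example: a 720x1280 image gets scale ~1.04, i.e. it is resized to roughly 750x1333.
```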
**Model settings:**
* Rotated RoI Align is used.
* During training pre_nms=12000 and post_nms=2000; during testing pre_nms=6000 and post_nms=1000. The NMS threshold is 0.7.
* When the RPN generates labels, fg_fraction=0.25, fg_thresh=0.5, bg_thresh_hi=0.5, bg_thresh_lo=0.0.
* When the RPN selects anchors, rpn_fg_fraction=0.5, rpn_positive_overlap=0.7, rpn_negative_overlap=0.3.
**Training strategy:**
* The default configuration uses 8 GPUs with a batch size of 1 per GPU.
* Training uses the momentum optimizer with momentum=0.9.
* The weight decay coefficient is 0.02. For the first 500 iterations the learning rate is increased linearly from 0.00333 to 0.01; it is then decayed with multipliers of 0.1 and 0.01 at iterations 6250 and 12500, and training runs for at most 17500 iterations. The maximum number of iterations and the learning-rate schedule can be changed via max_iter and lr_steps in config.py (a sketch of the schedule follows this list).
* The learning rate of the convolution bias in non-backbone layers is twice the overall learning rate.
* In the backbone, the affine_layers parameters and the res2 parameters are not updated.
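A minimal sketch of the warmup-plus-step-decay schedule described above, using the values from this README (the actual schedule is built from config.py during training):
```
def learning_rate(step, base_lr=0.01, warmup_start=0.00333, warmup_iters=500,
                  boundaries=(6250, 12500), multipliers=(0.1, 0.01)):
    """Linear warmup followed by step decay with the given multipliers."""
    if step < warmup_iters:
        return warmup_start + (base_lr - warmup_start) * step / warmup_iters
    lr = base_lr
    for boundary, multiplier in zip(boundaries, multipliers):
        if step >= boundary:
            lr = base_lr * multiplier
    return lr

# learning_rate(0) ~= 0.00333, learning_rate(1000) == 0.01,
# learning_rate(7000) == 0.001, learning_rate(13000) == 0.0001
```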
## Evaluation
Evaluation measures the performance metrics of a trained model. This example uses the [official ICDAR2015 evaluation](https://rrc.cvc.uab.es/?com=contestant).
`eval.py` is the main entry point of the evaluation module and is invoked as follows:
- RRPN
```
python eval.py \
    --dataset=icdar2015 \
    --pretrained_model=${path_to_trained_model}
```
- Set `--pretrained_model=${path_to_trained_model}` to the trained model; note that this is not the initialization model.
- Set `export CUDA_VISIBLE_DEVICES=0` to evaluate on a single GPU.
The table below shows the evaluation results:
RRPN
| Model | Batch size | Iterations | F1 |
| :--------------- | :------------: | :------------------: | ------: |
| [RRPN](https://paddleseg.bj.bcebos.com/deploy/temp/model_final.tar) | 8 | 17500 | 0.8048 |
## Inference and Visualization
Inference detects the objects in an image together with their classes. `infer.py` is the main entry point and is invoked as follows:
```
python infer.py \
    --pretrained_model=${path_to_trained_model} \
    --image_path=dataset/icdar2015 \
    --draw_threshold=0.6
```
Note: please set the model path `${path_to_trained_model}` and the image path correctly. The GPU is used by default; pass `--use_gpu=False` to run on the CPU. The score threshold `draw_threshold` controls how many detection boxes are drawn.
The images below show visualized predictions:
<p align="center">
<img src="image/img_120.jpg" height=576 width=1024 hspace='10'/>
<img src="image/img_119.jpg" height=576 width=1024 hspace='10'/> <br />
RRPN prediction visualization
</p>
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import errno
import os
import shutil
import time
import numpy as np
import re
import paddle.fluid as fluid
import logging
logger = logging.getLogger(__name__)
def load_params(exe, prog, path):
"""
Load model from the given path.
Args:
exe (fluid.Executor): The fluid.Executor object.
prog (fluid.Program): load weight to which Program object.
path (string): URL string or local model path.
"""
if not os.path.exists(path):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
logger.info('Loading parameters from {}...'.format(path))
def _if_exist(var):
param_exist = os.path.exists(os.path.join(path, var.name))
do_load = param_exist
if do_load:
logger.debug('load weight {}'.format(var.name))
return do_load
fluid.io.load_vars(exe, path, prog, predicate=_if_exist)
def save(exe, prog, path):
"""
Save model to the given path.
Args:
exe (fluid.Executor): The fluid.Executor object.
prog (fluid.Program): save weight from which Program object.
path (string): the path to save model.
"""
if os.path.isdir(path):
shutil.rmtree(path)
logger.info('Save model to {}.'.format(path))
fluid.io.save_persistables(exe, path, prog)
def load_and_fusebn(exe, prog, path):
"""
Fuse params of batch norm to scale and bias.
Args:
exe (fluid.Executor): The fluid.Executor object.
prog (fluid.Program): load weight to which Program object.
path (string): the path to load model from.
"""
logger.info('Loading model from {} and fusing batch norm parameters if present...'.format(path))
if not os.path.exists(path):
raise ValueError("Model path {} does not exists.".format(path))
def _if_exist(var):
b = os.path.exists(os.path.join(path, var.name))
if b:
logger.debug('load weight {}'.format(var.name))
return b
all_vars = list(filter(_if_exist, prog.list_vars()))
# Since the program uses affine-channel, there is no running mean and var
# in the program, here append running mean and var.
# NOTE, the params of batch norm should be like:
# x_scale
# x_offset
# x_mean
# x_variance
# x is any prefix
mean_variances = set()
bn_vars = []
bn_in_path = True
inner_prog = fluid.Program()
inner_start_prog = fluid.Program()
inner_block = inner_prog.global_block()
with fluid.program_guard(inner_prog, inner_start_prog):
for block in prog.blocks:
ops = list(block.ops)
if not bn_in_path:
break
for op in ops:
if op.type == 'affine_channel':
# remove 'scale' as prefix
scale_name = op.input('Scale')[0] # _scale
bias_name = op.input('Bias')[0] # _offset
prefix = scale_name[:-5]
mean_name = prefix + 'mean'
variance_name = prefix + 'variance'
if not os.path.exists(os.path.join(path, mean_name)):
bn_in_path = False
break
if not os.path.exists(os.path.join(path, variance_name)):
bn_in_path = False
break
bias = block.var(bias_name)
mean_vb = inner_block.create_var(
name=mean_name,
type=bias.type,
shape=bias.shape,
dtype=bias.dtype,
persistable=True)
variance_vb = inner_block.create_var(
name=variance_name,
type=bias.type,
shape=bias.shape,
dtype=bias.dtype,
persistable=True)
mean_variances.add(mean_vb)
mean_variances.add(variance_vb)
bn_vars.append(
[scale_name, bias_name, mean_name, variance_name])
if not bn_in_path:
fluid.io.load_vars(exe, path, prog, vars=all_vars)
logger.warning(
"There is no paramters of batch norm in model {}. "
"Skip to fuse batch norm. And load paramters done.".format(path))
return
# load running mean and running variance on cpu place into global scope.
place = fluid.CPUPlace()
exe_cpu = fluid.Executor(place)
fluid.io.load_vars(exe_cpu, path, vars=[v for v in mean_variances])
# load params on real place into global scope.
fluid.io.load_vars(exe, path, prog, vars=all_vars)
eps = 1e-5
for names in bn_vars:
scale_name, bias_name, mean_name, var_name = names
scale = fluid.global_scope().find_var(scale_name).get_tensor()
bias = fluid.global_scope().find_var(bias_name).get_tensor()
mean = fluid.global_scope().find_var(mean_name).get_tensor()
var = fluid.global_scope().find_var(var_name).get_tensor()
scale_arr = np.array(scale)
bias_arr = np.array(bias)
mean_arr = np.array(mean)
var_arr = np.array(var)
bn_std = np.sqrt(np.add(var_arr, eps))
new_scale = np.float32(np.divide(scale_arr, bn_std))
new_bias = bias_arr - mean_arr * new_scale
# fuse to scale and bias in affine_channel
scale.set(new_scale, exe.place)
bias.set(new_bias, exe.place)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from edict import AttrDict
import six
import numpy as np
_C = AttrDict()
cfg = _C
#
# Training options
#
_C.TRAIN = AttrDict()
# scales an image's shortest side
_C.TRAIN.scales = [800]
# max size of longest side
_C.TRAIN.max_size = 1333
# images per GPU in minibatch
_C.TRAIN.im_per_batch = 1
# roi minibatch size per image
_C.TRAIN.batch_size_per_im = 256
# target fraction of foreground roi minibatch
_C.TRAIN.fg_fractrion = 0.25
# overlap threshold for a foreground roi
_C.TRAIN.fg_thresh = 0.5
# overlap threshold for a background roi
_C.TRAIN.bg_thresh_hi = 0.5
_C.TRAIN.bg_thresh_lo = 0.0
# If False, only resize image and not pad, image shape is different between
# GPUs in one mini-batch. If True, image shape is the same in one mini-batch.
_C.TRAIN.padding_minibatch = False
# Snapshot period
_C.TRAIN.snapshot_iter = 1000
# number of RPN proposals to keep before NMS
_C.TRAIN.rpn_pre_nms_top_n = 12000
# number of RPN proposals to keep after NMS
_C.TRAIN.rpn_post_nms_top_n = 2000
# NMS threshold used on RPN proposals
_C.TRAIN.rpn_nms_thresh = 0.7
# min size in RPN proposals
_C.TRAIN.rpn_min_size = 0.0
# eta for adaptive NMS in RPN
_C.TRAIN.rpn_eta = 1.0
# number of RPN examples per image
_C.TRAIN.rpn_batch_size_per_im = 256
# remove anchors out of the image
_C.TRAIN.rpn_straddle_thresh = 0.
# target fraction of foreground examples per RPN minibatch
_C.TRAIN.rpn_fg_fraction = 0.5
# min overlap between anchor and gt box to be a positive example
_C.TRAIN.rpn_positive_overlap = 0.7
# max overlap between anchor and gt box to be a negative example
_C.TRAIN.rpn_negative_overlap = 0.3
# stopgrad at a specified stage
_C.TRAIN.freeze_at = 2
# min area of ground truth box
_C.TRAIN.gt_min_area = -1
#
# Inference options
#
_C.TEST = AttrDict()
# scales an image's shortest side
_C.TEST.scales = [800]
# max size of longest side
_C.TEST.max_size = 1333
# eta for adaptive NMS in RPN
_C.TEST.rpn_eta = 1.0
# min score threshold to infer
_C.TEST.score_thresh = 0.01
# overlap threshold used for NMS
_C.TEST.nms_thresh = 0.3
# number of RPN proposals to keep before NMS
_C.TEST.rpn_pre_nms_top_n = 6000
# number of RPN proposals to keep after NMS
_C.TEST.rpn_post_nms_top_n = 1000
# min size in RPN proposals
_C.TEST.rpn_min_size = 0.0
# max number of detections
_C.TEST.detections_per_im = 300
# NMS threshold used on RPN proposals
_C.TEST.rpn_nms_thresh = 0.7
#
# Model options
#
# Whether use mask rcnn head
_C.MASK_ON = True
# weight for bbox regression targets
_C.bbox_reg_weights = [10.0, 10.0, 5.0, 5.0, 1.0]
# RPN anchor sizes
_C.anchor_sizes = [128, 256, 512]
# RPN anchor ratio
_C.aspect_ratio = [0.2, 0.5, 1.0]
# RPN anchor angle
_C.anchor_angle = [-30.0, 0.0, 30.0, 60.0, 90.0, 120.0]
# variance of anchors
_C.variances = [1., 1., 1., 1., 1.]
# stride of feature map
_C.rpn_stride = [16.0, 16.0]
# pooled width and pooled height
_C.roi_resolution = 14
# spatial scale
_C.spatial_scale = 1. / 16.
# resolution to represent rotated roi align
_C.resolution = 14
#
# SOLVER options
#
# base learning rate used to derive the final learning rate
_C.learning_rate = 0.01
# maximum number of iterations
_C.max_iter = 140000
# learning rate warm-up: number of iterations and start factor
_C.warm_up_iter = 500
_C.start_factor = 1. / 3
# lr steps_with_decay
_C.lr_steps = [6250, 12500]
_C.lr_gamma = 0.1
# L2 regularization hyperparameter
_C.weight_decay = 0.0001
# momentum with SGD
_C.momentum = 0.9
#
# ENV options
#
# support both CPU and GPU
_C.use_gpu = True
# Whether use parallel
_C.parallel = True
# Class number
_C.class_num = 81
# support pyreader
_C.use_pyreader = True
_C.TRAIN.min_size = 800
_C.TRAIN.max_size = 1333
_C.TEST.min_size = 1000
# pixel mean values
_C.pixel_means = [0.485, 0.456, 0.406]
_C.pixel_std = [0.229, 0.224, 0.225]
# clip box to prevent overflowing
_C.bbox_clip = np.log(1000. / 16.)
def merge_cfg_from_args(args, mode):
"""Merge config keys, values in args into the global config."""
if mode == 'train':
sub_d = _C.TRAIN
else:
sub_d = _C.TEST
for k, v in sorted(six.iteritems(vars(args))):
d = _C
try:
value = eval(v)
except:
value = v
if k in sub_d:
sub_d[k] = value
else:
d[k] = value
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# Detectron
# Copyright (c) 2017-present, Facebook, Inc.
# Licensed under the Apache License, Version 2.0;
# Written by Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import cv2
import numpy as np
from config import cfg
import os
from PIL import Image
class DatasetPath(object):
def __init__(self, mode, dataset_name):
self.mode = mode
self.data_dir = dataset_name
def get_data_dir(self):
if self.mode == 'train':
return os.path.join(self.data_dir, 'ch4_training_images')
elif self.mode == 'val':
return os.path.join(self.data_dir, 'ch4_test_images')
def get_file_list(self):
if self.mode == 'train':
return os.path.join(self.data_dir,
'ch4_training_localization_transcription_gt')
elif self.mode == 'val':
return os.path.join(self.data_dir,
'ch4_test_localization_transcription_gt')
def get_image_blob(roidb, mode):
"""Builds an input blob from the images in the roidb at the specified
scales.
"""
if mode == 'train' or mode == 'val':
with open(roidb['image'], 'rb') as f:
data = f.read()
data = np.frombuffer(data, dtype='uint8')
img = cv2.imdecode(data, 1)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
gt_boxes = roidb['boxes']
gt_label = roidb['gt_classes']
# resize
if mode == 'train':
img, im_scale = _resize(img, target_size=800, max_size=1333)
need_gt_boxes = gt_boxes.copy()
need_gt_boxes[:, :4] *= im_scale
img, need_gt_boxes, need_gt_label = _rotation(
img, need_gt_boxes, gt_label, prob=1.0, gt_margin=1.4)
else:
img, im_scale = _resize(img, target_size=1000, max_size=1778)
need_gt_boxes = gt_boxes
need_gt_label = gt_label
img = img.astype(np.float32, copy=False)
img = img / 255.0
mean = np.array(cfg.pixel_means)[np.newaxis, np.newaxis, :]
std = np.array(cfg.pixel_std)[np.newaxis, np.newaxis, :]
img -= mean
img /= std
img = img.transpose((2, 0, 1))
return img, im_scale, need_gt_boxes, need_gt_label
def _get_size_scale(w, h, min_size, max_size=None):
size = min_size
scale = 1.0
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = int(round(max_size * min_original_size / max_original_size))
if (w <= h and w == size) or (h <= w and h == size):
return (h, w), scale
if w < h:
ow = size
oh = int(size * h / w)
scale = size / w
else:
oh = size
ow = int(size * w / h)
scale = size / h
scale = ow / w
return (oh, ow), scale
def _resize(im, target_size=800, max_size=1333):
if not isinstance(im, np.ndarray):
raise TypeError("{}: image type is not numpy.")
if len(im.shape) != 3:
raise ValueError('image is not 3-dimensional.')
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
selected_size = target_size
if float(im_size_min) == 0:
raise ZeroDivisionError('min size of image is 0')
if max_size != 0:
im_scale = float(selected_size) / float(im_size_min)
# Prevent the biggest axis from being more than max_size
if np.round(im_scale * im_size_max) > max_size:
im_scale = float(max_size) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
resize_w = np.round(im_scale_x * float(im_shape[1]))
resize_h = np.round(im_scale_y * float(im_shape[0]))
im_info = [resize_h, resize_w, im_scale]
else:
im_scale_x = float(selected_size) / float(im_shape[1])
im_scale_y = float(selected_size) / float(im_shape[0])
resize_w = selected_size
resize_h = selected_size
im = Image.fromarray(im)
im = im.resize((int(resize_w), int(resize_h)), 2)
im = np.array(im)
return im, im_scale_x
def _rotation(image,
gt_boxes,
gt_label,
prob,
fixed_angle=-1,
r_range=(360, 0),
gt_margin=1.4):
rotate_range = r_range[0]
shift = r_range[1]
angle = np.array([np.max([0, fixed_angle])])
if np.random.rand() <= prob:
angle = np.array(
np.random.rand(1) * rotate_range - shift, dtype=np.int16)
'''
rotate image
'''
image = np.array(image)
(h, w) = image.shape[:2]
scale = 1.0
# set the rotation center
center = (w / 2, h / 2)
# anti-clockwise angle in the function
M = cv2.getRotationMatrix2D(center, angle, scale)
image = cv2.warpAffine(image, M, (w, h))
# back to PIL image
im_width, im_height = w, h
'''
rotate boxes
'''
need_gt_boxes = gt_boxes.copy()
origin_gt_boxes = need_gt_boxes
rotated_gt_boxes = np.empty((len(need_gt_boxes), 5), dtype=np.float32)
# anti-clockwise to clockwise arc
cos_cita = np.cos(np.pi / 180 * angle)
sin_cita = np.sin(np.pi / 180 * angle)
# clockwise matrix
rotation_matrix = np.array([[cos_cita, sin_cita], [-sin_cita, cos_cita]])
pts_ctr = origin_gt_boxes[:, 0:2]
pts_ctr = pts_ctr - np.tile((im_width / 2, im_height / 2),
(gt_boxes.shape[0], 1))
pts_ctr = np.array(np.dot(pts_ctr, rotation_matrix), dtype=np.int16)
pts_ctr = np.squeeze(
pts_ctr, axis=-1) + np.tile((im_width / 2, im_height / 2),
(gt_boxes.shape[0], 1))
origin_gt_boxes[:, 0:2] = pts_ctr
len_of_gt = len(origin_gt_boxes)
# rectificate the angle in the range of [-45, 45]
for idx in range(len_of_gt):
ori_angle = origin_gt_boxes[idx, 4]
height = origin_gt_boxes[idx, 3]
width = origin_gt_boxes[idx, 2]
# step 1: normalize gt (-45,135)
if width < height:
ori_angle += 90
width, height = height, width
# step 2: rotate (-45,495)
rotated_angle = ori_angle + angle
# step 3: normalize rotated_angle (-45,135)
while rotated_angle > 135:
rotated_angle = rotated_angle - 180
rotated_gt_boxes[idx, 0] = origin_gt_boxes[idx, 0]
rotated_gt_boxes[idx, 1] = origin_gt_boxes[idx, 1]
rotated_gt_boxes[idx, 3] = height * gt_margin
rotated_gt_boxes[idx, 2] = width * gt_margin
rotated_gt_boxes[idx, 4] = rotated_angle
x_inbound = np.logical_and(rotated_gt_boxes[:, 0] >= 0,
rotated_gt_boxes[:, 0] < im_width)
y_inbound = np.logical_and(rotated_gt_boxes[:, 1] >= 0,
rotated_gt_boxes[:, 1] < im_height)
inbound = np.logical_and(x_inbound, y_inbound)
need_gt_boxes = rotated_gt_boxes[inbound]
need_gt_label = gt_label.copy()
need_gt_label = need_gt_label[inbound]
return image, need_gt_boxes, need_gt_label
def prep_im_for_blob(im, pixel_means, target_size, max_size):
"""Prepare an image for use as a network input blob. Specially:
- Subtract per-channel pixel mean
- Convert to float32
- Rescale to each of the specified target size (capped at max_size)
Returns a list of transformed images, one for each target size. Also returns
the scale factors that were used to compute each returned image.
"""
im = im.astype(np.float32, copy=False)
im -= pixel_means
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
im_scale = float(target_size) / float(im_size_min)
# Prevent the biggest axis from being more than max_size
if np.round(im_scale * im_size_max) > max_size:
im_scale = float(max_size) / float(im_size_max)
im = cv2.resize(
im,
None,
None,
fx=im_scale,
fy=im_scale,
interpolation=cv2.INTER_LINEAR)
im_height, im_width, channel = im.shape
channel_swap = (2, 0, 1) #(batch, channel, height, width)
im = im.transpose(channel_swap)
return im, im_scale
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
elif name in self:
return self[name]
else:
raise AttributeError(name)
def __setattr__(self, name, value):
if name in self.__dict__:
self.__dict__[name] = value
else:
self[name] = value
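# Illustrative usage sketch of AttrDict (not part of the original file): keys can be
# read and written either as attributes or as dict items, which is how config.py
# builds the global `cfg` object.
if __name__ == '__main__':
    demo = AttrDict()
    demo.TRAIN = AttrDict()
    demo.TRAIN.scales = [800]
    assert demo['TRAIN']['scales'] == demo.TRAIN.scales  # dict-style and attribute-style access agree
    demo.TRAIN.max_size = 1333                           # new keys can be created via attribute assignment
    print(demo.TRAIN)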
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import time
import numpy as np
import pickle
import paddle
import paddle.fluid as fluid
import reader
import models.model_builder as model_builder
import models.resnet as resnet
import checkpoint as checkpoint
from config import cfg
from utility import print_arguments, parse_args, check_gpu
from data_utils import DatasetPath
from eval_helper import *
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def eval():
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
image_shape = [3, cfg.TEST.max_size, cfg.TEST.max_size]
class_nums = cfg.class_num
model = model_builder.RRPN(
add_conv_body_func=resnet.ResNet(),
add_roi_box_head_func=resnet.ResNetC5(),
use_pyreader=False,
mode='val')
startup_prog = fluid.Program()
infer_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard():
model.build_model(image_shape)
pred_boxes = model.eval_bbox_out()
infer_prog = infer_prog.clone(True)
exe.run(startup_prog)
# yapf: disable
def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
if cfg.pretrained_model:
checkpoint.load_params(exe, infer_prog, cfg.pretrained_model)
# yapf: enable
test_reader = reader.test(1)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
fetch_list = [pred_boxes]
res_list = []
keys = [
'bbox', 'gt_box', 'gt_class', 'is_crowed', 'im_info', 'im_id',
'is_difficult'
]
for i, data in enumerate(test_reader()):
im_info = [data[0][1]]
result = exe.run(infer_prog,
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data),
return_numpy=False)
pred_boxes_v = result[0]
nmsed_out = pred_boxes_v
outs = np.array(nmsed_out)
res = get_key_dict(outs, data[0], keys)
res_list.append(res)
if i % 50 == 0:
logger.info('test_iter {}'.format(i))
icdar_eval(res_list)
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
check_gpu(args.use_gpu)
eval()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import numpy as np
import paddle.fluid as fluid
import math
from config import cfg
import six
import numpy as np
import cv2
import Polygon as plg
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from config import cfg
import logging
logger = logging.getLogger(__name__)
def get_key_dict(out, data, key):
res = {}
for i in range(len(key)):
if i == 0:
res[key[i]] = out
else:
res[key[i]] = data[i]
return res
def get_labels_maps():
default_labels_maps = {1: 'text'}
if cfg.dataset == 'icdar2015':
return default_labels_maps
labels_map = {}
with open(os.path.join(cfg.data_dir, 'label_list')) as f:
lines = f.readlines()
for idx, line in enumerate(lines):
labels_map[idx + 1] = line.strip()
return labels_map
def draw_bounding_box_on_image(image_path,
image_name,
nms_out,
im_scale,
draw_threshold=0.8):
#if image is None:
image = Image.open(os.path.join(image_path, image_name))
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
labels_map = get_labels_maps()
for dt in np.array(nms_out):
num_id, score = dt.tolist()[:2]
x1, y1, x2, y2, x3, y3, x4, y4 = dt.tolist()[2:] / im_scale
if score < draw_threshold:
continue
draw.line(
[(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
width=2,
fill='red')
if image.mode == 'RGB':
draw.text((x1, y1), labels_map[num_id], (255, 255, 0))
print("image with bbox drawed saved as {}".format(image_name))
image.save(image_name)
def polygon_from_points(points):
"""
Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
"""
res_boxes = np.empty([1, 8], dtype='int32')
res_boxes[0, 0] = int(points[0])
res_boxes[0, 4] = int(points[1])
res_boxes[0, 1] = int(points[2])
res_boxes[0, 5] = int(points[3])
res_boxes[0, 2] = int(points[4])
res_boxes[0, 6] = int(points[5])
res_boxes[0, 3] = int(points[6])
res_boxes[0, 7] = int(points[7])
point_mat = res_boxes[0].reshape([2, 4]).T
return plg.Polygon(point_mat)
def clip_box(bbox, im_info):
h = im_info[0]
w = im_info[1]
res = []
for b in bbox:
pts = b.reshape(4, 2)
pts[np.where(pts < 0)] = 1
pts[np.where(pts[:, 0] > w), 0] = w - 1
pts[np.where(pts[:, 1] > h), 1] = h - 1
pts = pts.reshape(-1)
pts /= im_info[2]
res.append(pts)
return np.array(res)
def get_union(det, gt):
area_det = det.area()
area_gt = gt.area()
return area_det + area_gt - get_intersection(det, gt)
def get_intersection_over_union(det, gt):
try:
return get_intersection(det, gt) / get_union(det, gt)
except:
return 0
def get_intersection(det, gt):
inter = det & gt
if len(inter) == 0:
return 0
return inter.area()
def parse_gt(result, im_id):
for res in result:
if res['im_id'] == im_id:
gt_boxes = list(res['gt_box'])
gt_class = res['gt_class']
is_difficult = res['is_difficult'].reshape(-1)
objects = []
for i in range(len(gt_boxes)):
object_struct = {}
object_struct['bbox'] = gt_boxes[i]
object_struct['class'] = gt_class[i]
if is_difficult[i] == 1:
object_struct['difficult'] = 1
else:
object_struct['difficult'] = 0
object_struct['im_id'] = im_id
objects.append(object_struct)
return objects
def calculate_ap(rec, prec):
# 11 point metric
ap = 0.
for t in np.arange(0., 1.1, 0.1):
if np.sum(rec >= t) == 0:
p = 0
else:
p = np.max(prec[rec >= t])
ap = ap + p / 11.
return ap
def icdar_map(result, class_name, ovthresh):
im_ids = []
for res in result:
im_ids.append(res['im_id'])
recs = {}
for i, im_id in enumerate(im_ids):
recs[str(im_id)] = parse_gt(result, im_id)
class_recs = {}
npos = 0
for k in im_ids:
res = [obj for obj in recs[str(k)] if obj['class'] == class_name]
bbox = np.array([x['bbox'] for x in res])
difficult = np.array([x['difficult'] for x in res]).astype(np.bool)
det = [False] * len(res)
npos = npos + sum(~difficult)
class_recs[k] = {'bbox': bbox, 'difficult': difficult, 'det': det}
image_ids = []
confidence = []
bbox = []
for res in result:
im_info = res['im_info']
pred_boxes = res['bbox']
for box in pred_boxes:
if box[0] == class_name:
image_ids.append(res['im_id'])
confidence.append(box[1])
clipd_box = clip_box(box[2:].reshape(-1, 8), im_info)
bbox.append(clipd_box[0])
confidence = np.array(confidence)
sorted_ind = np.argsort(-confidence)
sorted_scores = np.sort(-confidence)
bbox = np.array(bbox)
bbox = bbox[sorted_ind, :]
image_ids = [image_ids[x] for x in sorted_ind]
nd = len(image_ids)
tp = np.zeros(nd)
fp = np.zeros(nd)
for d in range(nd):
res = class_recs[image_ids[d]]
bb = bbox[d, :].astype(float)
ovmax = -np.inf
gt_bbox = res['bbox'].astype(float)
if gt_bbox.size > 0:
# compute overlaps
gt_bbox_xmin = np.min(gt_bbox[:, 0::2], axis=1)
gt_bbox_ymin = np.min(gt_bbox[:, 1::2], axis=1)
gt_bbox_xmax = np.max(gt_bbox[:, 0::2], axis=1)
gt_bbox_ymax = np.max(gt_bbox[:, 1::2], axis=1)
bb_xmin = np.min(bb[0::2])
bb_ymin = np.min(bb[1::2])
bb_xmax = np.max(bb[0::2])
bb_ymax = np.max(bb[1::2])
ixmin = np.maximum(gt_bbox_xmin, bb_xmin)
iymin = np.maximum(gt_bbox_ymin, bb_ymin)
ixmax = np.minimum(gt_bbox_xmax, bb_xmax)
iymax = np.minimum(gt_bbox_ymax, bb_ymax)
iw = np.maximum(ixmax - ixmin + 1., 0.)
ih = np.maximum(iymax - iymin + 1., 0.)
inters = iw * ih
# union
uni = ((bb_xmax - bb_xmin + 1.) * (bb_ymax - bb_ymin + 1.) +
(gt_bbox_xmax - gt_bbox_xmin + 1.) *
(gt_bbox_ymax - gt_bbox_ymin + 1.) - inters)
overlaps = inters / uni
gt_bbox_keep_mask = overlaps > 0
gt_bbox_keep = gt_bbox[gt_bbox_keep_mask, :]
gt_bbox_keep_index = np.where(overlaps > 0)[0]
def calcoverlaps(gt_bbox_keep, bb):
overlaps = []
for index, _ in enumerate(gt_bbox_keep):
p_g = polygon_from_points(gt_bbox_keep[index])
p_d = polygon_from_points(bb)
overlap = get_intersection_over_union(p_d, p_g)
overlaps.append(overlap)
return overlaps
if len(gt_bbox_keep) > 0:
overlaps = calcoverlaps(gt_bbox_keep, bb)
ovmax = np.max(overlaps)
jmax = np.argmax(overlaps)
jmax = gt_bbox_keep_index[jmax]
if ovmax > ovthresh:
if not res['difficult'][jmax]:
if not res['det'][jmax]:
tp[d] = 1.
res['det'][jmax] = 1
else:
fp[d] = 1.
else:
fp[d] = 1.
# compute precision recall
fp = np.cumsum(fp)
tp = np.cumsum(tp)
rec = tp / float(npos)
prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
ap = calculate_ap(rec, prec)
return rec, prec, ap
def icdar_map_eval(result, num_class):
map = 0
for i in range(num_class - 1):
rec, prec, ap = icdar_map(result, i + 1, ovthresh=0.5)
map = map + ap
map = map / (num_class - 1)
logger.info('mAP {}'.format(map))
def icdar_box_eval(result, thresh):
matched_sum = 0
num_global_care_gt = 0
num_global_care_det = 0
for res in result:
im_info = res['im_info']
h = im_info[1]
w = im_info[2]
gt_boxes = res['gt_box']
pred_boxes = res['bbox']
pred_boxes = pred_boxes[np.where(pred_boxes[:, 1] > thresh)]
pred_boxes = pred_boxes[:, 2:]
pred_boxes = clip_box(pred_boxes, im_info)
is_difficult = res['is_difficult']
det_matched = 0
iou_mat = np.empty([1, 1])
gt_pols = []
det_pols = []
gt_pol_points = []
det_pol_points = []
gt_dont_care_pols_num = []
det_dont_care_pols_num = []
det_matched_nums = []
points_list = list(gt_boxes)
dont_care = is_difficult.reshape(-1)
for i, points in enumerate(points_list):
gt_pol = polygon_from_points(list(points))
gt_pols.append(gt_pol)
gt_pol_points.append(list(points))
if dont_care[i] == 1:
gt_dont_care_pols_num.append(len(gt_pols) - 1)
for i, points in enumerate(pred_boxes):
points = list(points.reshape(8).astype(np.int32))
det_pol = polygon_from_points(points)
det_pols.append(det_pol)
det_pol_points.append(points)
if len(gt_dont_care_pols_num) > 0:
for dont_care_pol in gt_dont_care_pols_num:
dont_care_pol = gt_pols[dont_care_pol]
intersected_area = get_intersection(dont_care_pol, det_pol)
pd_dimensions = det_pol.area()
precision = 0 if pd_dimensions == 0 else intersected_area / pd_dimensions
if (precision > 0.5):
det_dont_care_pols_num.append(len(det_pols) - 1)
break
if len(gt_pols) > 0 and len(det_pols) > 0:
# Calculate IoU and precision matrices
output_shape = [len(gt_pols), len(det_pols)]
iou_mat = np.empty(output_shape)
gt_rect_mat = np.zeros(len(gt_pols), np.int8)
det_rect_mat = np.zeros(len(det_pols), np.int8)
for gt_num in range(len(gt_pols)):
for det_num in range(len(det_pols)):
p_d = gt_pols[gt_num]
p_g = det_pols[det_num]
iou_mat[gt_num, det_num] = get_intersection_over_union(p_d,
p_g)
for gt_num in range(len(gt_pols)):
for det_num in range(len(det_pols)):
if gt_rect_mat[gt_num] == 0 and det_rect_mat[
det_num] == 0 and gt_num not in gt_dont_care_pols_num and det_num not in det_dont_care_pols_num:
if iou_mat[gt_num, det_num] > 0.5:
gt_rect_mat[gt_num] = 1
det_rect_mat[det_num] = 1
det_matched += 1
det_matched_nums.append(det_num)
num_gt_care = (len(gt_pols) - len(gt_dont_care_pols_num))
num_det_care = (len(det_pols) - len(det_dont_care_pols_num))
matched_sum += det_matched
num_global_care_gt += num_gt_care
num_global_care_det += num_det_care
method_recall = 0 if num_global_care_gt == 0 else float(
matched_sum) / num_global_care_gt
method_precision = 0 if num_global_care_det == 0 else float(
matched_sum) / num_global_care_det
method_hmean = 0 if method_recall + method_precision == 0 else 2 * method_recall * method_precision / (
method_recall + method_precision)
logger.info('Recall {}'.format(method_recall))
logger.info('Precision {}'.format(method_precision))
logger.info('F1 {}'.format(method_hmean))
def icdar_eval(result):
if cfg.dataset == 'icdar2015':
icdar_box_eval(result, 0.8)
else:
icdar_map_eval(result, cfg.class_num)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import time
import numpy as np
import pickle
import paddle
import paddle.fluid as fluid
import reader
import models.model_builder as model_builder
import models.resnet as resnet
import checkpoint as checkpoint
from config import cfg
from data_utils import DatasetPath
from eval_helper import *
from utility import print_arguments, parse_args, check_gpu
def infer():
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
image_shape = [3, cfg.TEST.max_size, cfg.TEST.max_size]
class_nums = cfg.class_num
model = model_builder.RRPN(
add_conv_body_func=resnet.ResNet(),
add_roi_box_head_func=resnet.ResNetC5(),
use_pyreader=False,
mode='infer')
startup_prog = fluid.Program()
infer_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard():
model.build_model(image_shape)
pred_boxes = model.eval_bbox_out()
infer_prog = infer_prog.clone(True)
exe.run(startup_prog)
# yapf: disable
def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
if cfg.pretrained_model:
checkpoint.load_params(exe, infer_prog, cfg.pretrained_model)
# yapf: enable
infer_reader = reader.infer(cfg.image_path)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
fetch_list = [pred_boxes]
imgs = os.listdir(cfg.image_path)
imgs.sort()
for i, data in enumerate(infer_reader()):
result = exe.run(infer_prog,
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data),
return_numpy=False)
nmsed_out = result[0]
im_info = data[0][1]
im_scale = im_info[2]
outs = np.array(nmsed_out)
draw_bounding_box_on_image(cfg.image_path, imgs[i], outs, im_scale,
cfg.draw_threshold)
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
check_gpu(args.use_gpu)
infer()
# Building the Custom OPs
## Code structure
- src: C++/CUDA source code of the custom OPs
- rrpn_lib.py: Python wrappers
## Installing PaddlePaddle
Install PaddlePaddle in one of the following ways:
- Build and install from source on the [Paddle develop branch](https://github.com/PaddlePaddle/Paddle/tree/develop); build instructions:
1. [Ubuntu](https://www.paddlepaddle.org.cn/install/doc/source/ubuntu)
1. [CentOS](https://www.paddlepaddle.org.cn/install/doc/source/centos)
1. [MacOS](https://www.paddlepaddle.org.cn/install/doc/source/macos)
1. [Windows](https://www.paddlepaddle.org.cn/install/doc/source/windows)
**Note:** building inside docker is recommended.
- Install a Paddle develop [nightly whl package](https://www.paddlepaddle.org.cn/install/doc/tables#多版本whl包列表-dev-11)
**Note:** the gcc version used to build the custom OPs must match the gcc version used to build Paddle. The Paddle develop nightly packages are currently built with **gcc 4.8.2**; if you use a nightly package, please also build the custom OPs with **gcc 4.8.2**, otherwise compatibility problems may occur.
## Building the custom OPs
The C++/CUDA implementation of the custom OPs must be compiled into a shared library. make.sh compiles with g++/nvcc; you can of course write a Makefile or use CMake instead.
The build needs to include PaddlePaddle's header files and link against PaddlePaddle's libraries. The header and library paths can be obtained with the commands below:
```
# python
>>> import paddle
>>> print(paddle.sysconfig.get_include())
/paddle/pyenv/local/lib/python2.7/site-packages/paddle/include
>>> print(paddle.sysconfig.get_lib())
/paddle/pyenv/local/lib/python2.7/site-packages/paddle/libs
```
We provide the following script to build the shared library:
```
cd src
sh make.sh
```
The build produces `rrpn_lib.so`.
**Note:** if PaddlePaddle was built from source and `WITH_MKLDNN` was not set during `cmake`, building the custom OPs will fail with errors such as `mkldnn.h` not found; in that case remove the `-DPADDLE_WITH_MKLDNN` option from the build command in `make.sh`.
## Setting environment variables
Paddle's core libraries must be added to `LD_LIBRARY_PATH`. First run the following to get the library path:
```
import paddle
print(paddle.sysconfig.get_lib())
```
The library path can then be added as follows:
```
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c 'import paddle; print(paddle.sysconfig.get_lib())'`
```
For more on defining custom C++ OPs outside the framework, see the [official documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/index_cn.html).
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Based on
--------------------------------------------------------
@misc{ma2019rrpn,
author = {Jianqi Ma},
title = {{RRPN in pytorch}},
year = {2019},
howpublished = {\url{https://github.com/mjq11302010044/RRPN_pytorch}},
}
@article{Jianqi17RRPN,
Author = {Jianqi Ma and Weiyuan Shao and Hao Ye and Li Wang and Hong Wang
and Yingbin Zheng and Xiangyang Xue},
Title = {Arbitrary-Oriented Scene Text Detection via Rotation Proposals},
journal = {IEEE Transactions on Multimedia},
volume={20},
number={11},
pages={3111-3122},
year={2018}
}
--------------------------------------------------------
*/
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
#define PI 3.141592654
struct RangeInitFunctor {
int start;
int delta;
int* out;
HOSTDEVICE void operator()(size_t i) { out[i] = start + i * delta; }
};
// get triangle area; used after decomposing the intersection polygon into triangles
template <typename T>
inline T trangle_area(T* a, T* b, T* c) {
return ((a[0] - c[0]) * (b[1] - c[1]) - (a[1] - c[1]) * (b[0] - c[0])) / 2.0;
}
// get the area of the intersection polygon
template <typename T>
inline T get_area(T* int_pts, int num_of_inter) {
T area = 0.0;
for (int i = 0; i < num_of_inter - 2; i++) {
area += fabs(
trangle_area<T>(int_pts, int_pts + 2 * i + 2, int_pts + 2 * i + 4));
}
return area;
}
// sort points to decompose intersecting polygons into triangles
template <typename T>
inline void reorder_pts(T* int_pts, int num_of_inter) {
if (num_of_inter > 0) {
T center[2] = {0.0, 0.0};
for (int i = 0; i < num_of_inter; i++) {
center[0] += int_pts[2 * i];
center[1] += int_pts[2 * i + 1];
}
center[0] /= num_of_inter;
center[1] /= num_of_inter;
T vs[16];
T v[2];
T d;
for (int i = 0; i < num_of_inter; i++) {
v[0] = int_pts[2 * i] - center[0];
v[1] = int_pts[2 * i + 1] - center[1];
d = sqrt(v[0] * v[0] + v[1] * v[1]);
v[0] = v[0] / d;
v[1] = v[1] / d;
if (v[1] < 0) {
v[0] = -2 - v[0];
}
vs[i] = v[0];
}
float temp, tx, ty;
int j;
for (int i = 1; i < num_of_inter; ++i) {
if (vs[i - 1] > vs[i]) {
temp = vs[i];
tx = int_pts[2 * i];
ty = int_pts[2 * i + 1];
j = i;
while (j > 0 && vs[j - 1] > temp) {
vs[j] = vs[j - 1];
int_pts[j * 2] = int_pts[j * 2 - 2];
int_pts[j * 2 + 1] = int_pts[j * 2 - 1];
j--;
}
vs[j] = temp;
int_pts[j * 2] = tx;
int_pts[j * 2 + 1] = ty;
}
}
}
}
// determine whether two edges intersect and, if so, store the intersection point
template <typename T>
inline bool inter2line(T* pts1, T* pts2, int i, int j, T* temp_pts) {
T a[2] = {pts1[2 * i], pts1[2 * i + 1]};
T b[2] = {pts1[2 * ((i + 1) % 4)], pts1[2 * ((i + 1) % 4) + 1]};
T c[2] = {pts2[2 * j], pts2[2 * j + 1]};
T d[2] = {pts2[2 * ((j + 1) % 4)], pts2[2 * ((j + 1) % 4) + 1]};
T area_abc, area_abd, area_cda, area_cdb;
area_abc = trangle_area<T>(a, b, c);
area_abd = trangle_area<T>(a, b, d);
if (area_abc * area_abd >= -1e-5) {
return false;
}
area_cda = trangle_area<T>(c, d, a);
area_cdb = area_cda + area_abc - area_abd;
if (area_cda * area_cdb >= -1e-5) {
return false;
}
T t = area_cda / (area_abd - area_abc);
T dx = t * (b[0] - a[0]);
T dy = t * (b[1] - a[1]);
temp_pts[0] = a[0] + dx;
temp_pts[1] = a[1] + dy;
return true;
}
template <typename T>
inline bool inrect(T pt_x, T pt_y, T* pts) {
T ab[2] = {pts[2] - pts[0], pts[3] - pts[1]};
T ad[2] = {pts[6] - pts[0], pts[7] - pts[1]};
T ap[2] = {pt_x - pts[0], pt_y - pts[1]};
T abab = ab[0] * ab[0] + ab[1] * ab[1];
T abap = ab[0] * ap[0] + ab[1] * ap[1];
T adad = ad[0] * ad[0] + ad[1] * ad[1];
T adap = ad[0] * ap[0] + ad[1] * ap[1];
bool result = (abab - abap >= -1) and (abap >= -1) and (adad - adap >= -1) and
(adap >= -1);
return result;
}
// collect the intersection points of the two boxes and return their count
template <typename T>
inline int inter_pts(T* pts1, T* pts2, T* int_pts) {
int num_of_inter = 0;
for (int i = 0; i < 4; i++) {
if (inrect<T>(pts1[2 * i], pts1[2 * i + 1], pts2)) {
int_pts[num_of_inter * 2] = pts1[2 * i];
int_pts[num_of_inter * 2 + 1] = pts1[2 * i + 1];
num_of_inter++;
}
if (inrect<T>(pts2[2 * i], pts2[2 * i + 1], pts1)) {
int_pts[num_of_inter * 2] = pts2[2 * i];
int_pts[num_of_inter * 2 + 1] = pts2[2 * i + 1];
num_of_inter++;
}
}
T out_pts[2];
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
bool has_pts = inter2line<T>(pts1, pts2, i, j, out_pts);
if (has_pts) {
int_pts[num_of_inter * 2] = out_pts[0];
int_pts[num_of_inter * 2 + 1] = out_pts[1];
num_of_inter++;
}
}
}
return num_of_inter;
}
// convert x,y,w,h,angle to x1,y1,x2,y2,x3,y3,x4,y4
template <typename T>
inline void convert_region(T* pts,
const framework::Tensor& _region,
int index) {
auto region = framework::EigenTensor<T, 2>::From(_region);
T angle = region(index, 4);
T a_cos = cos(angle / 180.0 * PI);
T a_sin = -sin(angle / 180.0 * PI); // anti clock-wise
T ctr_x = region(index, 0);
T ctr_y = region(index, 1);
T h = region(index, 3);
T w = region(index, 2);
T pts_x[4] = {-w / 2, -w / 2, w / 2, w / 2};
T pts_y[4] = {-h / 2, h / 2, h / 2, -h / 2};
for (int i = 0; i < 4; i++) {
pts[2 * i] = a_cos * pts_x[i] - a_sin * pts_y[i] + ctr_x;
pts[2 * i + 1] = a_sin * pts_x[i] + a_cos * pts_y[i] + ctr_y;
}
}
// Calculate the area of intersection
template <typename T>
inline float inter(const framework::Tensor& _region1,
const framework::Tensor& _region2,
const int& r,
const int& c) {
T pts1[8];
T pts2[8];
T int_pts[16];
int num_of_inter;
convert_region<T>(pts1, _region1, r);
convert_region<T>(pts2, _region2, c);
num_of_inter = inter_pts<T>(pts1, pts2, int_pts);
reorder_pts<T>(int_pts, num_of_inter);
return get_area<T>(int_pts, num_of_inter);
}
template <typename T>
inline float devRotateIoU(const framework::Tensor& _region1,
const framework::Tensor& _region2,
const int r,
const int c) {
auto __region1 = framework::EigenTensor<T, 2>::From(_region1);
auto __region2 = framework::EigenTensor<T, 2>::From(_region2);
if ((fabs(__region1(r, 0) - __region2(c, 0)) < 1e-5) &&
(fabs(__region1(r, 1) - __region2(c, 1)) < 1e-5) &&
(fabs(__region1(r, 2) - __region2(c, 2)) < 1e-5) &&
(fabs(__region1(r, 3) - __region2(c, 3)) < 1e-5) &&
(fabs(__region1(r, 4) - __region2(c, 4)) < 1e-5)) {
return 1.0;
}
T area1, area2, area_inter;
area1 = __region1(r, 2) * __region1(r, 3);
area2 = __region2(c, 2) * __region2(c, 3);
area_inter = inter<T>(_region1, _region2, r, c);
auto result = area_inter / (area1 + area2 - area_inter);
if (result < 0) {
result = 0.0;
}
// may have bugs which cause overlap > 1
if (result > 1.00000001) {
result = 0.0;
}
return result;
}
template <typename T>
inline void BoxToDelta2(const int box_num,
const framework::Tensor& ex_boxes,
const framework::Tensor& gt_boxes,
const float* weights,
framework::Tensor* box_delta) {
auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
auto trg = framework::EigenTensor<T, 2>::From(*box_delta);
T ex_w, ex_h, ex_ctr_x, ex_ctr_y, ex_angle, gt_w, gt_h, gt_ctr_x, gt_ctr_y,
gt_angle;
for (int64_t i = 0; i < box_num; ++i) {
ex_w = ex_boxes_et(i, 2);
ex_h = ex_boxes_et(i, 3);
ex_ctr_x = ex_boxes_et(i, 0);
ex_ctr_y = ex_boxes_et(i, 1);
ex_angle = ex_boxes_et(i, 4);
gt_w = gt_boxes_et(i, 2);
gt_h = gt_boxes_et(i, 3);
gt_ctr_x = gt_boxes_et(i, 0);
gt_ctr_y = gt_boxes_et(i, 1);
gt_angle = gt_boxes_et(i, 4);
trg(i, 0) = (gt_ctr_x - ex_ctr_x) / ex_w;
trg(i, 1) = (gt_ctr_y - ex_ctr_y) / ex_h;
trg(i, 2) = std::log(gt_w / ex_w);
trg(i, 3) = std::log(gt_h / ex_h);
trg(i, 4) = gt_angle - ex_angle;
if (weights) {
trg(i, 0) = trg(i, 0) * weights[0];
trg(i, 1) = trg(i, 1) * weights[1];
trg(i, 2) = trg(i, 2) * weights[2];
trg(i, 3) = trg(i, 3) * weights[3];
trg(i, 4) = trg(i, 4) * weights[4];
}
if (gt_angle <= -30 && ex_angle >= 120) {
trg(i, 4) = trg(i, 4) + 180.0;
}
if (gt_angle >= 120 && ex_angle <= -30) {
trg(i, 4) = trg(i, 4) - 180.0;
}
trg(i, 4) = (PI / 180) * trg(i, 4);
}
}
template <typename T>
void Gather(
const T* in, const int in_stride, const int* index, const int num, T* out) {
const int stride_bytes = in_stride * sizeof(T);
for (int i = 0; i < num; ++i) {
int id = index[i];
memcpy(out + i * in_stride, in + id * in_stride, stride_bytes);
}
}
template <typename T>
void BboxOverlaps2(const framework::Tensor& r_boxes,
const framework::Tensor& c_boxes,
framework::Tensor* overlaps) {
auto overlaps_et = framework::EigenTensor<T, 2>::From(*overlaps);
int r_num = r_boxes.dims()[0];
int c_num = c_boxes.dims()[0];
for (int i = 0; i < r_num; ++i) {
for (int j = 0; j < c_num; ++j) {
overlaps_et(i, j) = devRotateIoU<T>(r_boxes, c_boxes, i, j);
}
}
}
} // namespace operators
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef PADDLE_WITH_LIBXSMM
#include <libxsmm.h>
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#endif
namespace paddle {
namespace operators {
namespace math {
/**
* Matrix Descriptor of a memory buffer.
*
* It is used for Blas::MatMul. MatMul operator can be batched.
* if Mat A is [BatchSize, H, W], Mat B is [BatchSize, H, W]. It will be a
* `batch_size` times of GEMM. The batched GEMM could be faster base on the
* implementation of the blas library. The batch size could be zero. If any
* matrix of `matmul` has a batch size, there will be a batched GEMM, too. e.g.,
* Mat A is [BatchSize, H1, W2], and Mat B [H2, W2], the result matrix will be
* [BatchSize, H1, W2]
*
* The boolean flag, `trans`, describe the memory is the transpose of matrix or
* not. If the trans is true, the last two dims of matrix are transposed. The
* memory layout of the matrix is [Width, Height] or [BatchSize, Width, Height].
*
* The MatDescriptor is not only the dimension or shape of a matrix, it also
* contains the layout, stride of matrix. It is clearer to have a structure than
* reuse `DDim`.
*/
struct MatDescriptor {
int64_t height_;
int64_t width_;
int64_t stride_{0};
int64_t batch_size_{0};
bool trans_;
};
/**
* Create Matrix Descriptor from a tensor dim, num_flatten_cols, and transpose
* flag
*
* @param tensor_dim: The dimension of the tensor. The rank of this dimension
* must be larger than 1.
*
* @param num_flatten_cols: Reshape a tensor to a matrix. The matrix's first
* dimension(column length) will be the product of tensor's first `num_col_dims`
* dimensions. If num_flatten_cols is zero, the first N-2 dimension will be the
* batch_size of descriptor.
*
* @param trans: True if the matrix is transposed.
*/
extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim,
int num_flatten_cols,
bool trans);
template <typename DeviceContext>
class Blas {
public:
explicit Blas(const DeviceContext& context) : context_(context) {}
template <typename T>
void GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int M,
int N,
int K,
T alpha,
const T* A,
const T* B,
T beta,
T* C) const;
template <typename T>
void GEMM(bool transA,
bool transB,
int M,
int N,
int K,
T alpha,
const T* A,
int lda,
const T* B,
int ldb,
T beta,
T* C,
int ldc) const;
template <typename T>
void GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int M,
int N,
int K,
T alpha,
const T* A,
int lda,
const T* B,
int ldb,
T beta,
T* C,
int ldc) const;
#ifdef PADDLE_WITH_MKLML
template <typename T>
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id,
const int M,
const int N,
const int K) const;
template <typename T>
void GEMM_PACK(const CBLAS_IDENTIFIER id,
const CBLAS_TRANSPOSE trans,
int M,
int N,
int K,
const T alpha,
const T* src,
const int ld,
T* dst) const;
template <typename T>
void GEMM_COMPUTE(int transA,
int transB,
int M,
int N,
int K,
const T* A,
const int lda,
const T* B,
const int ldb,
T beta,
T* C,
const int ldc) const;
template <typename T>
void GEMM_FREE(T* data) const;
template <typename T>
void CSRMM(const char* transa,
const int* m,
const int* n,
const int* k,
const T* alpha,
const char* matdescra,
const T* val,
const int* indx,
const int* pntrb,
const int* pntre,
const T* b,
const int* ldb,
const T* beta,
T* c,
const int* ldc) const;
#if !defined(PADDLE_WITH_CUDA)
template <typename T>
void MatMulWithHead(const framework::Tensor& mat_a,
const MatDescriptor& dim_a,
const framework::Tensor& mat_b,
const MatDescriptor& dim_b,
T alpha,
int head_number,
framework::Tensor* mat_out,
T beta,
bool mat_y_split_vertical) const;
#endif
#endif
template <typename T>
void MatMul(const int M,
const int N,
const int K,
const T* A,
const T* B,
T* C) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a,
bool trans_a,
const framework::Tensor& mat_b,
bool trans_b,
T alpha,
framework::Tensor* mat_out,
T beta) const;
template <typename T>
void MatMul(const framework::Tensor& mat_a,
bool trans_a,
const framework::Tensor& mat_b,
bool trans_b,
framework::Tensor* mat_out) const {
MatMul(mat_a,
trans_a,
mat_b,
trans_b,
static_cast<T>(1.0),
mat_out,
static_cast<T>(0.0));
}
template <typename T>
void MatMul(const framework::Tensor& mat_a,
const framework::Tensor& mat_b,
framework::Tensor* mat_out) const {
this->template MatMul<T>(mat_a, false, mat_b, false, mat_out);
}
template <typename T>
void AXPY(int n, T alpha, const T* x, T* y) const;
template <typename T>
void VADD(int n, const T* x, const T* y, T* z) const;
template <typename T>
void VSUB(int n, const T* x, const T* y, T* z) const;
template <typename T>
void VMUL(int n, const T* x, const T* y, T* z) const;
template <typename T>
void VDIV(int n, const T* x, const T* y, T* z) const;
template <typename T>
void VCOPY(int n, const T* x, T* y) const;
template <typename T>
void VEXP(int n, const T* x, T* y) const;
template <typename T>
void VSQUARE(int n, const T* x, T* y) const;
template <typename T>
void VPOW(int n, const T* x, T alpha, T* y) const;
template <typename T>
void GEMV(bool trans_a,
int M,
int N,
T alpha,
const T* A,
const T* B,
T beta,
T* C) const;
template <typename T>
T DOT(int n, const T* x, const T* y) const;
template <typename T>
void SCAL(int n, const T a, T* x) const;
template <typename T>
T ASUM(int n, T* x, int inc) const;
template <typename T>
void BatchedGEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int M,
int N,
int K,
T alpha,
const T* A,
const T* B,
T beta,
T* C,
int batchCount,
int64_t strideA,
int64_t strideB) const;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
template <typename T>
void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int W1,
int H1,
int W2,
int H2,
T alpha,
const T* A,
const T* B,
T beta,
T* C,
int batchCount,
int64_t strideA,
int64_t strideB,
int64_t head_number,
bool split_b_vertical) const;
#endif
template <typename T>
void MatMul(const framework::Tensor& mat_a,
const MatDescriptor& dim_a,
const framework::Tensor& mat_b,
const MatDescriptor& dim_b,
T alpha,
framework::Tensor* mat_out,
T beta) const;
template <typename T>
void VINV(int n, const T* a, T* y) const;
template <typename T>
void VMERF(int n, const T* a, T* y, int64_t mode) const;
private:
const DeviceContext& context_;
};
template <typename DeviceContext, typename T>
class BlasT : private Blas<DeviceContext> {
public:
using Blas<DeviceContext>::Blas;
template <typename... ARGS>
void GEMM(ARGS... args) const {
Base()->template GEMM<T>(args...);
}
#ifdef PADDLE_WITH_MKLML
template <typename... ARGS>
T* GEMM_ALLOC(ARGS... args) const {
return Base()->template GEMM_ALLOC<T>(args...);
}
template <typename... ARGS>
void GEMM_PACK(ARGS... args) const {
Base()->template GEMM_PACK<T>(args...);
}
template <typename... ARGS>
void GEMM_COMPUTE(ARGS... args) const {
Base()->template GEMM_COMPUTE<T>(args...);
}
template <typename... ARGS>
void GEMM_FREE(ARGS... args) const {
Base()->template GEMM_FREE<T>(args...);
}
template <typename... ARGS>
void CSRMM(ARGS... args) const {
Base()->template CSRMM<T>(args...);
}
#if !defined(PADDLE_WITH_CUDA)
template <typename... ARGS>
void MatMulWithHead(ARGS... args) const {
Base()->template MatMulWithHead<T>(args...);
}
#endif
#endif
template <typename... ARGS>
void MatMul(ARGS... args) const {
Base()->template MatMul<T>(args...);
}
template <typename... ARGS>
void AXPY(ARGS... args) const {
Base()->template AXPY<T>(args...);
}
template <typename... ARGS>
void VADD(ARGS... args) const {
Base()->template VADD<T>(args...);
}
template <typename... ARGS>
void VSUB(ARGS... args) const {
Base()->template VSUB<T>(args...);
}
template <typename... ARGS>
void VMUL(ARGS... args) const {
Base()->template VMUL<T>(args...);
}
template <typename... ARGS>
void VDIV(ARGS... args) const {
Base()->template VDIV<T>(args...);
}
template <typename... ARGS>
void VCOPY(ARGS... args) const {
Base()->template VCOPY<T>(args...);
}
template <typename... ARGS>
void VEXP(ARGS... args) const {
Base()->template VEXP<T>(args...);
}
template <typename... ARGS>
void VSQUARE(ARGS... args) const {
Base()->template VSQUARE<T>(args...);
}
template <typename... ARGS>
void VPOW(ARGS... args) const {
Base()->template VPOW<T>(args...);
}
template <typename... ARGS>
void GEMV(ARGS... args) const {
Base()->template GEMV<T>(args...);
}
template <typename... ARGS>
T DOT(ARGS... args) const {
return Base()->template DOT<T>(args...);
}
template <typename... ARGS>
void SCAL(ARGS... args) const {
Base()->template SCAL<T>(args...);
}
template <typename... ARGS>
T ASUM(ARGS... args) const {
return Base()->template ASUM<T>(args...);
}
template <typename... ARGS>
void BatchedGEMM(ARGS... args) const {
Base()->template BatchedGEMM<T>(args...);
}
template <typename... ARGS>
void VINV(ARGS... args) const {
Base()->template VINV<T>(args...);
}
template <typename... ARGS>
void VMERF(ARGS... args) const {
Base()->template VMERF<T>(args...);
}
private:
const Blas<DeviceContext>* Base() const {
return static_cast<const Blas<DeviceContext>*>(this);
}
};
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(
const framework::ExecutionContext& exe_ctx) {
return BlasT<DeviceContext, T>(
exe_ctx.template device_context<DeviceContext>());
}
template <typename DeviceContext, typename T>
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
return BlasT<DeviceContext, T>(dev_ctx);
}
} // namespace math
} // namespace operators
} // namespace paddle
#include "paddle/fluid/operators/math/blas_impl.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif
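// --- Usage sketch (illustrative only, not part of this header) ---
// A kernel typically obtains a typed BLAS wrapper through GetBlas and then
// dispatches to GEMM/MatMul. The helper name and the tensor shapes below are
// assumptions made for the sake of the example.
template <typename DeviceContext, typename T>
void ExampleMatMul(const paddle::framework::ExecutionContext& ctx,
                   const paddle::framework::Tensor& a,  // assumed shape [M, K]
                   const paddle::framework::Tensor& b,  // assumed shape [K, N]
                   paddle::framework::Tensor* out) {    // pre-allocated [M, N]
  auto blas = paddle::operators::math::GetBlas<DeviceContext, T>(ctx);
  // Equivalent to GEMM with trans_a = trans_b = false.
  blas.MatMul(a, b, out);
}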
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "concat_and_split.h"
#include <vector>
namespace paddle {
namespace operators {
namespace math {
/*
 * All input tensors must have the same rank, and the size of each dimension
 * must match, except along the concatenation axis.
*/
template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::Tensor>& input,
int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int num = input.size();
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
// computation
auto output_data = output->data<T>();
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
auto input_data = input[j].data<T>();
for (int k = 0; k < out_rows; ++k) {
memory::Copy(cpu_place,
output_data + k * out_cols + col_idx,
cpu_place,
input_data + k * col_len,
sizeof(T) * col_len);
}
col_idx += col_len;
}
}
};
#define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<platform::CPUDeviceContext, type>;
FOR_ALL_TYPES(DEFINE_FUNCTOR);
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace operators {
namespace math {
/*
* \brief Concatenate the input tensors along the dimension axis.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input[0] = [[1,2],[3,4]]
* Input[1] = [[5,6]]
* axis = 0
*
* Output = [[1,2],
* [3,4],
* [5,6]]
*/
template <typename DeviceContext, typename T>
class ConcatFunctor {
public:
void operator()(const DeviceContext& context,
const std::vector<framework::Tensor>& input,
int axis,
framework::Tensor* output);
};
} // namespace math
} // namespace operators
} // namespace paddle
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16)
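// --- Usage sketch (illustrative only, not part of this header) ---
// Concatenating two CPU tensors along axis 0 with ConcatFunctor, matching the
// example in the comment above. The function name, the shapes and the extra
// include are assumptions made for the sketch.
#include "paddle/fluid/platform/device_context.h"
void ExampleConcat() {
  paddle::platform::CPUPlace place;
  paddle::platform::CPUDeviceContext ctx;
  paddle::framework::Tensor a, b, out;
  a.mutable_data<float>(paddle::framework::make_ddim({2, 2}), place);    // e.g. [[1,2],[3,4]] after filling
  b.mutable_data<float>(paddle::framework::make_ddim({1, 2}), place);    // e.g. [[5,6]] after filling
  out.mutable_data<float>(paddle::framework::make_ddim({3, 2}), place);  // holds the concatenation
  paddle::operators::math::ConcatFunctor<paddle::platform::CPUDeviceContext, float> concat;
  concat(ctx, {a, b}, /*axis=*/0, &out);  // rows of a followed by rows of b
}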
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::DeviceContext;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T, typename IndexT = int>
__global__ void GatherCUDAKernel(const T* params,
const IndexT* indices,
T* output,
size_t index_size,
size_t slice_size) {
CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = indices[indices_i];
IndexT params_i = gather_i * slice_size + slice_i;
*(output + i) = *(params + params_i);
}
}
template <typename T, typename IndexT = int>
__global__ void GatherNdCUDAKernel(const T* input,
const int* input_dims,
const IndexT* indices,
T* output,
size_t remain_size,
size_t slice_size,
size_t end_size) {
CUDA_1D_KERNEL_LOOP(i, remain_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = 0;
int64_t temp = slice_size;
for (int64_t j = end_size - 1; j >= 0; --j) {
auto index_value = indices[indices_i * end_size + j];
assert(index_value >= 0 && index_value < input_dims[j]);
gather_i += (index_value * temp);
temp *= input_dims[j];
}
IndexT input_i = gather_i + slice_i;
*(output + i) = *(input + input_i);
}
}
/**
 * A thin wrapper for gathering on a GPU tensor.
 * Returns a new tensor gathered from the source tensor according to index.
* input[src]: type-T source Tensor
* input[index]: type-IndexT index Tensor (1-D)
* return: output tensor
*/
template <typename T, typename IndexT = int>
void GPUGather(const platform::DeviceContext& ctx,
const Tensor& src,
const Tensor& index,
Tensor* output) {
// check index of shape 1-D
if (index.dims().size() == 1) {
PADDLE_ENFORCE_GT(index.dims()[0],
0,
"The index of gather_op should not be empty when the "
"index's rank is 1.");
} else if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1],
1,
" If the index's rank of gather_op is 2, the second "
"dimension should be 1.");
}
int index_size = index.dims()[0];
auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
int block = 512;
int n = slice_size * index_size;
int grid = (n + block - 1) / block;
GatherCUDAKernel<T, IndexT><<<
grid,
block,
0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_src, p_index, p_output, index_size, slice_size);
}
} // namespace operators
} // namespace paddle
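// Usage sketch (illustrative only): inside a CUDA op kernel, GPUGather is
// typically called with the device context and GPU tensors; the output must
// be pre-sized to [index.dims()[0], ...] and allocated on the same device, e.g.
//   output->mutable_data<T>(output_dims, ctx.GetPlace());
//   GPUGather<T>(ctx.device_context(), src, index, output);
// where `ctx`, `src`, `index`, `output` and `output_dims` are assumed to be
// defined by the surrounding kernel.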
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory.h>
#include <cstring>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using framework::Tensor;
/**
 * A thin wrapper for gathering on a CPU tensor.
 * Returns a new tensor gathered from the source tensor according to index.
* input[src]: type-T source Tensor
* input[index]: type-IndexT index Tensor (1-D)
* return: output tensor
*/
template <typename T, typename IndexT = int>
void CPUGather(const platform::DeviceContext& ctx,
const Tensor& src,
const Tensor& index,
Tensor* output) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
// check index of shape 1-D
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1],
1,
"index.dims()[1] should be 1 when index.dims().size() == "
"2 in gather_op.");
} else {
PADDLE_ENFORCE_EQ(index.dims().size(),
1,
"index.dims().size() should be 1 or 2 in gather_op.");
}
int64_t index_size = index.dims()[0];
auto src_dims = src.dims();
const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
// slice size
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
const size_t slice_bytes = slice_size * sizeof(T);
for (int64_t i = 0; i < index_size; ++i) {
IndexT index_ = p_index[i];
memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
}
}
} // namespace operators
} // namespace paddle
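// --- Usage sketch (illustrative only, not part of this header) ---
// Gathering rows 2 and 0 of a [4, 3] CPU tensor with CPUGather. The function
// name, the shapes and the extra include are assumptions made for the sketch.
#include "paddle/fluid/platform/device_context.h"
void ExampleGather() {
  paddle::platform::CPUPlace place;
  paddle::platform::CPUDeviceContext ctx;
  paddle::framework::Tensor src, index, out;
  src.mutable_data<float>(paddle::framework::make_ddim({4, 3}), place);
  auto* idx = index.mutable_data<int>(paddle::framework::make_ddim({2}), place);
  idx[0] = 2;  // first output row comes from src row 2
  idx[1] = 0;  // second output row comes from src row 0
  out.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
  paddle::operators::CPUGather<float, int>(ctx, src, index, &out);
}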
# Query the PaddlePaddle include and library directories used to build the custom ops.
include_dir=$( python -c 'import paddle; print(paddle.sysconfig.get_include())' )
lib_dir=$( python -c 'import paddle; print(paddle.sysconfig.get_lib())' )
echo $include_dir
echo $lib_dir
# CUDA, cuDNN and NCCL install paths are passed on the command line.
CUDA=$1
CUDNN=$2
NCCL=$3
if [ ! -d "$CUDA" ]; then
echo "Usage: sh make.sh \$CUDA_PATH \$CUDNN_PATH \$NCCL_PATH"
exit
fi
if [ ! -d "$CUDNN" ]; then
echo "Usage: sh make.sh \${CUDA_PATH} \${CUDNN_PATH} \${NCCL_PATH}"
exit
fi
if [ ! -d "$NCCL" ]; then
echo "Usage: sh make.sh \${CUDA_PATH} \${CUDNN_PATH} \${NCCL_PATH}"
exit
fi
# cub is required by the CUDA kernels of the custom ops.
git clone https://github.com/NVlabs/cub.git
# Compile each custom CUDA kernel into a position-independent object file.
nvcc rrpn_generate_proposals_op.cu -c -o rrpn_generate_proposals_op.cu.o -ccbin cc -DPADDLE_WITH_MKLDNN -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O3 -DNVCC \
-I ${include_dir} \
-I ${include_dir}/third_party \
-I ${CUDA}/include \
-I ${CUDNN}/include \
-I ${NCCL}/include \
-L ${lib_dir} -lpaddle_framework \
-L ${CUDA}/lib64 -lcudart
nvcc rotated_anchor_generator_op.cu -c -o rotated_anchor_generator_op.cu.o -ccbin cc -DPADDLE_WITH_MKLDNN -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O3 -DNVCC \
-I ${include_dir} \
-I ${include_dir}/third_party \
-I ${CUDA}/include \
-I ${CUDNN}/include \
-I ${NCCL}/include \
-L ${lib_dir} -lpaddle_framework \
-L ${CUDA}/lib64 -lcudart
nvcc rrpn_box_coder_op.cu -c -o rrpn_box_coder_op.cu.o -ccbin cc -DPADDLE_WITH_MKLDNN -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O3 -DNVCC \
-I ${include_dir} \
-I ${include_dir}/third_party \
-I ${CUDA}/include \
-I ${CUDNN}/include \
-I ${NCCL}/include \
-L ${lib_dir} -lpaddle_framework \
-L ${CUDA}/lib64 -lcudart
nvcc rrpn_rotated_roi_align_op.cu -c -o rrpn_rotated_roi_align_op.cu.o -ccbin cc -DPADDLE_WITH_MKLDNN -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O3 -DNVCC \
-I ${include_dir} \
-I ${include_dir}/third_party \
-I ${CUDA}/include \
-I ${CUDNN}/include \
-I ${NCCL}/include \
-L ${lib_dir} -lpaddle_framework \
-L ${CUDA}/lib64 -lcudart
# Link the CPU sources and the compiled CUDA objects into the rrpn_lib.so shared library.
g++ rotated_anchor_generator_op.cc concat_and_split.cc rrpn_generate_proposal_labels_op.cc rrpn_generate_proposals_op.cc rrpn_target_assign_op.cc rrpn_box_coder_op.cc rrpn_rotated_roi_align_op.cc rrpn_rotated_roi_align_op.cu.o rrpn_box_coder_op.cu.o rotated_anchor_generator_op.cu.o rrpn_generate_proposals_op.cu.o -o rrpn_lib.so -shared -fPIC -std=c++11 -O3 -DPADDLE_WITH_MKLDNN -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO \
-I ${include_dir} \
-I ${include_dir}/third_party \
-I ${CUDA}/include \
-I ${CUDNN}/include \
-I ${NCCL}/include \
-L ${lib_dir} -lpaddle_framework \
-L ${CUDA}/lib64 -lcudart
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "math_function.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#endif
#include <vector>
#include "math_function_impl.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
namespace math {
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, \
platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;
DEFINE_CPU_TRANS(1);
DEFINE_CPU_TRANS(2);
DEFINE_CPU_TRANS(3);
DEFINE_CPU_TRANS(4);
DEFINE_CPU_TRANS(5);
DEFINE_CPU_TRANS(6);
template <typename DeviceContext, typename T, int Rank>
void Transpose<DeviceContext, T, Rank>::operator()(
const DeviceContext& context,
const framework::Tensor& in,
framework::Tensor* out,
const std::vector<int>& axis) {
Eigen::array<int, Rank> permute;
for (int i = 0; i < Rank; i++) {
permute[i] = axis[i];
}
auto eigen_in = framework::EigenTensor<T, Rank>::From(in);
auto eigen_out = framework::EigenTensor<T, Rank>::From(*out);
auto* dev = context.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
}
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T, int Rank>
struct Transpose {
void operator()(const DeviceContext& context,
const framework::Tensor& in,
framework::Tensor* out,
const std::vector<int>& axis);
};
void set_constant(const platform::DeviceContext& context,
framework::Tensor* tensor,
float value);
} // namespace math
} // namespace operators
} // namespace paddle
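// --- Usage sketch (illustrative only, not part of this header) ---
// Transposing a [2, 3] CPU tensor into [3, 2] with the Transpose functor
// declared above; the axis vector gives, for each output dimension, the input
// dimension it comes from. The function name and shapes are assumptions.
void ExampleTranspose() {
  paddle::platform::CPUPlace place;
  paddle::platform::CPUDeviceContext ctx;
  paddle::framework::Tensor in, out;
  in.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
  out.mutable_data<float>(paddle::framework::make_ddim({3, 2}), place);
  paddle::operators::math::Transpose<paddle::platform::CPUDeviceContext, float, 2> trans;
  trans(ctx, in, &out, {1, 0});  // swap the two dimensions
}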
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "rotated_anchor_generator_op.h"
namespace paddle {
namespace operators {
class RotatedAnchorGeneratorOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("Input"),
"Input(Input) of RotatedAnchorGeneratorOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Anchors"),
"Output(Anchors) of RotatedAnchorGeneratorOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Variances"),
"Output(Variances) of RotatedAnchorGeneratorOp should not be null.");
auto input_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
auto anchor_sizes = ctx->Attrs().Get<std::vector<float>>("anchor_sizes");
auto aspect_ratios = ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
auto angles = ctx->Attrs().Get<std::vector<float>>("angles");
auto stride = ctx->Attrs().Get<std::vector<float>>("stride");
auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
size_t num_anchors =
aspect_ratios.size() * anchor_sizes.size() * angles.size();
std::vector<int64_t> dim_vec(4);
dim_vec[0] = input_dims[2];
dim_vec[1] = input_dims[3];
dim_vec[2] = num_anchors;
dim_vec[3] = 5;
ctx->SetOutputDim("Anchors", framework::make_ddim(dim_vec));
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
}
};
class RotatedAnchorGeneratorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(Tensor, default Tensor<float>), "
"the input feature is a tensor with a rank of 4. "
"The layout is NCHW.");
AddOutput("Anchors",
"(Tensor, default Tensor<float>), the output is a "
"tensor with a rank of 4. The layout is [H, W, num_anchors, 5]. "
"H is the height of input, W is the width of input, num_anchors "
"is the box count of each position. "
"Each anchor is in (xctr, yctr, w, h, thelta) format");
AddOutput("Variances",
"(Tensor, default Tensor<float>), the expanded variances for "
"normalizing bbox regression targets. The layout is [H, W, "
"num_anchors, 5]. "
"H is the height of input, W is the width of input, num_anchors "
"is the box count of each position. "
"Each variance is in (xctr, yctr, w, h, thelta) format");
AddAttr<std::vector<float>>(
"anchor_sizes",
"(vector<float>) List of Rotated Region Proposal Network(RRPN) anchor "
"sizes "
" given in absolute pixels e.g. (64, 128, 256, 512)."
" For instance, the anchor size of 64 means the area of this anchor "
"equals to 64**2.")
.AddCustomChecker([](const std::vector<float>& anchor_sizes) {
PADDLE_ENFORCE_GT(anchor_sizes.size(),
0UL,
"Size of anchor_sizes must be at least 1.");
for (size_t i = 0; i < anchor_sizes.size(); ++i) {
PADDLE_ENFORCE_GT(
anchor_sizes[i], 0.0, "anchor_sizes[%d] must be positive.", i);
}
});
AddAttr<std::vector<float>>(
"aspect_ratios",
"(vector<float>) List of Rotated Region Proposal Network(RRPN) anchor "
"aspect "
"ratios, e.g. (0.5, 1, 2)."
"For instacne, the aspect ratio of 0.5 means the height / width of "
"this anchor equals 0.5.");
AddAttr<std::vector<float>>(
"angles",
"(vector<float>) List of Rotated Region Proposal Network(RRPN) anchor "
"angles, "
"e.g. (-30.0, 0.0, 30.0, 60.0, 90.0, 120.0)."
"For instacne, the aspect ratio of 0.5 means the height / width of "
"this anchor equals 0.5.");
AddAttr<std::vector<float>>("variances",
"(vector<float>) List of variances to be used "
"in box regression deltas")
.AddCustomChecker([](const std::vector<float>& variances) {
PADDLE_ENFORCE_EQ(
              variances.size(), 5UL, "Exactly 5 variances must be provided.");
for (size_t i = 0; i < variances.size(); ++i) {
PADDLE_ENFORCE_GT(
variances[i], 0.0, "variance[%d] must be greater than 0.", i);
}
});
AddAttr<std::vector<float>>("stride",
"Anchors stride across width and height, "
"with a default of (16, 16)")
.SetDefault(std::vector<float>(2, 16.0))
.AddCustomChecker([](const std::vector<float>& stride) {
PADDLE_ENFORCE_EQ(
stride.size(),
2UL,
"Must and only provide 2 stride for width and height.");
for (size_t i = 0; i < stride.size(); ++i) {
PADDLE_ENFORCE_GT(
stride[i], 0.0, "stride[%d] should be larger than 0.", i);
}
});
AddAttr<float>("offset",
"(float) "
"Anchor center offset, with a default of 0.5")
.SetDefault(0.5);
AddComment(R"DOC(
RotatedAnchorGenerator operator
Generates anchors for the RRPN algorithm.
Each position of the input produces N anchors, N =
size(anchor_sizes) * size(aspect_ratios) * size(angles).
Please get more information from the following papers:
https://arxiv.org/abs/1703.01086.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
rotated_anchor_generator,
ops::RotatedAnchorGeneratorOp,
ops::RotatedAnchorGeneratorOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(rotated_anchor_generator,
ops::RotatedAnchorGeneratorOpKernel<float>,
ops::RotatedAnchorGeneratorOpKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "rotated_anchor_generator_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void GenRAnchors(T* out,
const T* aspect_ratios,
const int ar_num,
const T* anchor_sizes,
const int as_num,
const T* angles,
const int aa_num,
const T* stride,
const int sd_num,
const int height,
const int width,
const T offset) {
int num_anchors = as_num * ar_num * aa_num;
int box_num = height * width * num_anchors;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num;
i += blockDim.x * gridDim.x) {
int h_idx = i / (num_anchors * width);
int w_idx = (i / num_anchors) % width;
T stride_width = stride[0];
T stride_height = stride[1];
T x_ctr = (w_idx * stride_width) + offset * stride_width - 1;
T y_ctr = (h_idx * stride_height) + offset * stride_height - 1;
T area, area_ratios;
T base_w, base_h;
T scale_w, scale_h;
T anchor_width, anchor_height;
int anch_idx = i % num_anchors;
int ar_idx = anch_idx / (as_num * aa_num);
int as_idx = anch_idx / aa_num % as_num;
int aa_idx = anch_idx % aa_num;
T aspect_ratio = aspect_ratios[ar_idx];
T anchor_size = anchor_sizes[as_idx];
T angle = angles[aa_idx];
area = stride_width * stride_height;
area_ratios = area / aspect_ratio;
base_w = round(sqrt(area_ratios));
base_h = round(base_w * aspect_ratio);
scale_w = anchor_size / stride_width;
scale_h = anchor_size / stride_height;
anchor_width = scale_w * base_w;
anchor_height = scale_h * base_h;
out[i * 5] = x_ctr;
out[i * 5 + 1] = y_ctr;
out[i * 5 + 2] = anchor_width;
out[i * 5 + 3] = anchor_height;
out[i * 5 + 4] = angle;
}
}
template <typename T>
__global__ void SetVariance(T* out,
const T* var,
const int vnum,
const int num) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
out[i] = var[i % vnum];
}
}
template <typename T>
class RotatedAnchorGeneratorOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
auto* anchors = ctx.Output<paddle::framework::Tensor>("Anchors");
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
auto anchor_sizes = ctx.Attr<std::vector<float>>("anchor_sizes");
auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios");
auto angles = ctx.Attr<std::vector<float>>("angles");
auto stride = ctx.Attr<std::vector<float>>("stride");
auto variances = ctx.Attr<std::vector<float>>("variances");
T offset = static_cast<T>(ctx.Attr<float>("offset"));
auto width = input->dims()[3];
auto height = input->dims()[2];
int num_anchors =
aspect_ratios.size() * anchor_sizes.size() * angles.size();
int box_num = width * height * num_anchors;
int block = 512;
int grid = (box_num + block - 1) / block;
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
anchors->mutable_data<T>(ctx.GetPlace());
vars->mutable_data<T>(ctx.GetPlace());
framework::Tensor ar;
framework::TensorFromVector(aspect_ratios, ctx.device_context(), &ar);
framework::Tensor as;
framework::TensorFromVector(anchor_sizes, ctx.device_context(), &as);
framework::Tensor aa;
framework::TensorFromVector(angles, ctx.device_context(), &aa);
framework::Tensor sd;
framework::TensorFromVector(stride, ctx.device_context(), &sd);
GenRAnchors<T><<<grid, block, 0, stream>>>(anchors->data<T>(),
ar.data<T>(),
aspect_ratios.size(),
as.data<T>(),
anchor_sizes.size(),
aa.data<T>(),
angles.size(),
sd.data<T>(),
stride.size(),
height,
width,
offset);
framework::Tensor v;
framework::TensorFromVector(variances, ctx.device_context(), &v);
grid = (box_num * 5 + block - 1) / block;
SetVariance<T><<<grid, block, 0, stream>>>(
vars->data<T>(), v.data<T>(), variances.size(), box_num * 5);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(rotated_anchor_generator,
ops::RotatedAnchorGeneratorOpCUDAKernel<float>,
ops::RotatedAnchorGeneratorOpCUDAKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
//#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
template <typename T>
class RotatedAnchorGeneratorOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
auto* anchors = ctx.Output<paddle::framework::Tensor>("Anchors");
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
auto anchor_sizes = ctx.Attr<std::vector<float>>("anchor_sizes");
auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios");
auto angles = ctx.Attr<std::vector<float>>("angles");
auto stride = ctx.Attr<std::vector<float>>("stride");
auto variances = ctx.Attr<std::vector<float>>("variances");
T offset = static_cast<T>(ctx.Attr<float>("offset"));
auto feature_width = input->dims()[3];
auto feature_height = input->dims()[2];
T stride_width, stride_height;
stride_width = stride[0];
stride_height = stride[1];
int num_anchors =
aspect_ratios.size() * anchor_sizes.size() * angles.size();
anchors->mutable_data<T>(ctx.GetPlace());
vars->mutable_data<T>(ctx.GetPlace());
auto e_anchors = framework::EigenTensor<T, 4>::From(*anchors);
for (int h_idx = 0; h_idx < feature_height; ++h_idx) {
for (int w_idx = 0; w_idx < feature_width; ++w_idx) {
T x_ctr = (w_idx * stride_width) + offset * stride_width - 1;
T y_ctr = (h_idx * stride_height) + offset * stride_height - 1;
T area, area_ratios;
T base_w, base_h;
T scale_w, scale_h;
T anchor_width, anchor_height;
int idx = 0;
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
auto ar = aspect_ratios[r];
for (size_t s = 0; s < anchor_sizes.size(); ++s) {
auto anchor_size = anchor_sizes[s];
area = stride_width * stride_height;
area_ratios = area / ar;
base_w = round(sqrt(area_ratios));
base_h = round(base_w * ar);
scale_w = anchor_size / stride_width;
scale_h = anchor_size / stride_height;
anchor_width = scale_w * base_w;
anchor_height = scale_h * base_h;
for (size_t a = 0; a < angles.size(); ++a) {
auto angle = angles[a];
e_anchors(h_idx, w_idx, idx, 0) = x_ctr;
e_anchors(h_idx, w_idx, idx, 1) = y_ctr;
e_anchors(h_idx, w_idx, idx, 2) = anchor_width;
e_anchors(h_idx, w_idx, idx, 3) = anchor_height;
e_anchors(h_idx, w_idx, idx, 4) = angle;
idx++;
}
}
}
}
}
framework::Tensor var_t;
var_t.mutable_data<T>(
framework::make_ddim({1, static_cast<int>(variances.size())}),
ctx.GetPlace());
auto var_et = framework::EigenTensor<T, 2>::From(var_t);
for (size_t i = 0; i < variances.size(); ++i) {
var_et(0, i) = variances[i];
}
int anchor_num = feature_height * feature_width * num_anchors;
auto var_dim = vars->dims();
vars->Resize({anchor_num, static_cast<int>(variances.size())});
auto e_vars = framework::EigenMatrix<T, Eigen::RowMajor>::From(*vars);
e_vars = var_et.broadcast(Eigen::DSizes<int, 2>(anchor_num, 1));
vars->Resize(var_dim);
}
};
} // namespace operators
} // namespace paddle
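// Worked example (illustrative only), mirroring the arithmetic above: with
// stride = (16, 16), anchor_size = 128, aspect_ratio = 0.5 and angle = 30,
//   area    = 16 * 16 = 256
//   base_w  = round(sqrt(256 / 0.5)) = round(22.63) = 23
//   base_h  = round(23 * 0.5) = 12
//   scale_w = scale_h = 128 / 16 = 8
// so the emitted anchor is (x_ctr, y_ctr, 184, 96, 30), and each feature-map
// position emits size(anchor_sizes) * size(aspect_ratios) * size(angles)
// such 5-tuples.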
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
//#include "rrpn_box_coder_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class RRPNBoxCoderOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
"Input(PriorBox) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("TargetBox"),
"Input(TargetBox) of BoxCoderOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutputBox"),
"Output(OutputBox) of BoxCoderOp should not be null.");
auto prior_box_dims = ctx->GetInputDim("PriorBox");
// auto target_box_dims = ctx->GetInputDim("TargetBox");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
prior_box_dims.size(), 2, "The rank of Input PriorBox must be 2");
PADDLE_ENFORCE_EQ(
prior_box_dims[1], 5, "The shape of PriorBox is [N, 5]");
if (ctx->HasInput("PriorBoxVar")) {
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
PADDLE_ENFORCE(prior_box_var_dims.size() == 2,
"Input(PriorBoxVar) of BoxCoderOp should be 2.");
PADDLE_ENFORCE_EQ(
prior_box_dims,
prior_box_var_dims,
"The dimension of Input(PriorBoxVar) should be equal to"
"the dimension of Input(PriorBox) when the rank is 2.");
}
}
}
};
class RRPNBoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"PriorBox",
"(Tensor, default Tensor<float>) "
"Box list PriorBox is a 2-D Tensor with shape [M, 5] holds M boxes, "
"each box is represented as [x, y, w, h, angle], "
"[x, y] is the center coordinate of the anchor box, "
"if the input is image feature map, they are close to the origin "
"of the coordinate system. [w, h] is the width and height "
"of the anchor box, angle is angle of rotation.");
AddInput("PriorBoxVar",
"(Tensor, default Tensor<float>, optional) "
"PriorBoxVar is a 2-D Tensor with shape [M, 5] holds M group "
"of variance. PriorBoxVar will set all elements to 1 by "
"default.")
.AsDispensable();
AddInput(
"TargetBox",
"(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape "
"[N, 5], each box is represented as [x, y, w, h, angle],"
"[x, y] is the center coordinate of the box, [w, h] is width and "
"height of the box,"
"angle is angle of rotation around the center of box.");
AddAttr<std::vector<float>>(
"variance",
"(vector<float>, default {}),"
"variance of prior box with shape [5]. PriorBoxVar and variance can"
"not be provided at the same time.")
.SetDefault(std::vector<float>{});
AddOutput("OutputBox",
"(Tensor) "
"2-D Tensor with shape [M, 5] which M represents the number of "
"deocded boxes"
"and 5 represents [x, y, w, h, angle]");
AddComment(R"DOC(
Rotated Bounding Box Coder.
Decode the target bounding box with the priorbox information.
The decoding schema is described below:
ox = pw * tx / pxv + cx
oy = ph * ty / pyv + cy
ow = exp(tw / pwv) * pw
oh = exp(th / phv) * ph
oa = ta / pav * 1.0 / 3.141592653 * 180 + pa
where `tx`, `ty`, `tw`, `th`, `ta` denote the target box's center coordinates, width,
height and angle respectively. Similarly, `px`, `py`, `pw`, `ph`, `pa` denote the
priorbox's (anchor) center coordinates, width, height and angle. `pxv`, `pyv`, `pwv`,
`phv`, `pav` denote the variance of the priorbox and `ox`, `oy`, `ow`, `oh`, `oa`
denote the decoded center coordinates, width, height and angle.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
rrpn_box_coder,
ops::RRPNBoxCoderOp,
ops::RRPNBoxCoderOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <string>
#include <vector>
#include "paddle/fluid/memory/memory.h"
//#include "rrpn_box_coder_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
#define PI 3.141592654
template <typename T>
__global__ void DecodeCenterSizeKernel(const T* prior_box_data,
const T* prior_box_var_data,
const T* target_box_data,
const int row,
const int len,
const T prior_box_var_size,
const float* variance,
const int var_size,
T* output) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
int prior_box_offset = 0;
if (idx < row) {
const int row_idx = idx;
prior_box_offset = row_idx * len;
T prior_box_width = prior_box_data[prior_box_offset + 2];
T prior_box_height = prior_box_data[prior_box_offset + 3];
T prior_box_center_x = prior_box_data[prior_box_offset];
T prior_box_center_y = prior_box_data[prior_box_offset + 1];
T prior_box_angle = prior_box_data[prior_box_offset + 4];
T target_box_width, target_box_height, target_box_angle;
T target_box_center_x, target_box_center_y;
T box_var_x = T(1), box_var_y = T(1);
T box_var_w = T(1), box_var_h = T(1), box_var_angle = T(1);
if (prior_box_var_data) {
int prior_var_offset = row_idx * len;
box_var_x = prior_box_var_data[prior_var_offset];
box_var_y = prior_box_var_data[prior_var_offset + 1];
box_var_w = prior_box_var_data[prior_var_offset + 2];
box_var_h = prior_box_var_data[prior_var_offset + 3];
box_var_angle = prior_box_var_data[prior_var_offset + 4];
} else if (var_size == 5) {
box_var_x = static_cast<T>(variance[0]);
box_var_y = static_cast<T>(variance[1]);
box_var_w = static_cast<T>(variance[2]);
box_var_h = static_cast<T>(variance[3]);
box_var_angle = static_cast<T>(variance[4]);
}
target_box_width =
exp(target_box_data[idx * len + 2] / box_var_w) * prior_box_width / 1.4;
target_box_height = exp(target_box_data[idx * len + 3] / box_var_h) *
prior_box_height / 1.4;
target_box_center_x =
target_box_data[idx * len] / box_var_x * prior_box_width +
prior_box_center_x;
target_box_center_y =
target_box_data[idx * len + 1] / box_var_y * prior_box_height +
prior_box_center_y;
target_box_angle =
(target_box_data[idx * len + 4] / box_var_angle) * 1.0 / PI * 180 +
prior_box_angle;
T a_cos = cos(PI / 180 * target_box_angle);
T a_sin = -sin(PI / 180 * target_box_angle);
T rotation_matrix[3][3];
rotation_matrix[0][0] = a_cos;
rotation_matrix[0][1] = a_sin;
rotation_matrix[0][2] = 0;
rotation_matrix[1][0] = -a_sin;
rotation_matrix[1][1] = a_cos;
rotation_matrix[1][2] = 0;
rotation_matrix[2][0] = -target_box_center_x * a_cos +
target_box_center_y * a_sin + target_box_center_x;
rotation_matrix[2][1] = -target_box_center_x * a_sin -
target_box_center_y * a_cos + target_box_center_y;
rotation_matrix[2][2] = 1;
T pt_x0 = target_box_center_x - target_box_width / 2;
T pt_x1 = target_box_center_x + target_box_width / 2;
T pt_x2 = target_box_center_x + target_box_width / 2;
T pt_x3 = target_box_center_x - target_box_width / 2;
T pt_y0 = target_box_center_y - target_box_height / 2;
T pt_y1 = target_box_center_y - target_box_height / 2;
T pt_y2 = target_box_center_y + target_box_height / 2;
T pt_y3 = target_box_center_y + target_box_height / 2;
output[idx * 8] = pt_x0 * rotation_matrix[0][0] +
pt_y0 * rotation_matrix[1][0] + rotation_matrix[2][0];
output[idx * 8 + 1] = pt_x0 * rotation_matrix[0][1] +
pt_y0 * rotation_matrix[1][1] + rotation_matrix[2][1];
output[idx * 8 + 2] = pt_x1 * rotation_matrix[0][0] +
pt_y1 * rotation_matrix[1][0] + rotation_matrix[2][0];
output[idx * 8 + 3] = pt_x1 * rotation_matrix[0][1] +
pt_y1 * rotation_matrix[1][1] + rotation_matrix[2][1];
output[idx * 8 + 4] = pt_x2 * rotation_matrix[0][0] +
pt_y2 * rotation_matrix[1][0] + rotation_matrix[2][0];
output[idx * 8 + 5] = pt_x2 * rotation_matrix[0][1] +
pt_y2 * rotation_matrix[1][1] + rotation_matrix[2][1];
output[idx * 8 + 6] = pt_x3 * rotation_matrix[0][0] +
pt_y3 * rotation_matrix[1][0] + rotation_matrix[2][0];
output[idx * 8 + 7] = pt_x3 * rotation_matrix[0][1] +
pt_y3 * rotation_matrix[1][1] + rotation_matrix[2][1];
}
}
template <typename DeviceContext, typename T>
class RRPNBoxCoderCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
auto* prior_box = context.Input<framework::Tensor>("PriorBox");
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
auto* output_box = context.Output<framework::Tensor>("OutputBox");
std::vector<float> variance = context.Attr<std::vector<float>>("variance");
const T* prior_box_data = prior_box->data<T>();
const T* target_box_data = target_box->data<T>();
const T* prior_box_var_data = nullptr;
auto prior_box_var_size = 0;
if (prior_box_var) {
PADDLE_ENFORCE(variance.empty(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time.");
prior_box_var_data = prior_box_var->data<T>();
prior_box_var_size = prior_box_var->dims().size();
}
if (!(variance.empty())) {
PADDLE_ENFORCE(static_cast<int>(variance.size()) == 5,
"Size of attribute 'variance' should be 4");
}
if (target_box->lod().size()) {
PADDLE_ENFORCE_EQ(
target_box->lod().size(), 1, "Only support 1 level of LoD.");
}
const int var_size = static_cast<int>(variance.size());
auto row = target_box->dims()[0];
auto len = 5;
int block = 512;
int grid = (row + block - 1) / block;
auto& device_ctx = context.cuda_device_context();
int bytes = var_size * sizeof(float);
auto dev_var = memory::Alloc(device_ctx, bytes);
float* dev_var_data = reinterpret_cast<float*>(dev_var->ptr());
auto cplace = platform::CPUPlace();
const auto gplace = boost::get<platform::CUDAPlace>(context.GetPlace());
memory::Copy(
gplace, dev_var_data, cplace, &variance[0], bytes, device_ctx.stream());
output_box->mutable_data<T>({row, 8}, context.GetPlace());
T* output = output_box->data<T>();
DecodeCenterSizeKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
prior_box_data,
prior_box_var_data,
target_box_data,
row,
len,
prior_box_var_size,
dev_var_data,
var_size,
output);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
rrpn_box_coder,
ops::RRPNBoxCoderCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::RRPNBoxCoderCUDAKernel<paddle::platform::CUDADeviceContext, double>);
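// --- Reference sketch (illustrative only, not part of the op) ---
// Host-side decoding of a single (prior, target) pair, following the kernel
// above and assuming unit variances. It returns the (x, y) coordinates of the
// four rotated corners, matching the [M, 8] output layout. The function name
// is an assumption.
#include <array>
#include <cmath>
std::array<float, 8> DecodeOneRotatedBox(const float prior[5], const float target[5]) {
  const float kPI = 3.141592654f;
  float w = std::exp(target[2]) * prior[2] / 1.4f;   // decoded width
  float h = std::exp(target[3]) * prior[3] / 1.4f;   // decoded height
  float cx = target[0] * prior[2] + prior[0];        // decoded center x
  float cy = target[1] * prior[3] + prior[1];        // decoded center y
  float angle = target[4] / kPI * 180.f + prior[4];  // decoded angle in degrees
  float c = std::cos(kPI / 180.f * angle);
  float s = -std::sin(kPI / 180.f * angle);
  const float xs[4] = {cx - w / 2, cx + w / 2, cx + w / 2, cx - w / 2};
  const float ys[4] = {cy - h / 2, cy - h / 2, cy + h / 2, cy + h / 2};
  std::array<float, 8> out;
  for (int k = 0; k < 4; ++k) {
    // Rotate each axis-aligned corner around the decoded center.
    out[2 * k] = (xs[k] - cx) * c - (ys[k] - cy) * s + cx;
    out[2 * k + 1] = (xs[k] - cx) * s + (ys[k] - cy) * c + cy;
  }
  return out;
}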
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace detail {
/**
* Get Reference From Pointer with check. The error message is printf format,
* and passed by `args`
*/
template <typename T, typename... ARGS>
inline T& Ref(T* ptr, ARGS&&... args) {
PADDLE_ENFORCE_NOT_NULL(ptr, ::paddle::string::Sprintf(args...));
return *ptr;
}
} // namespace detail
} // namespace operators
} // namespace paddle
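// Usage sketch (illustrative only): Ref turns a possibly-null pointer into a
// checked reference with a printf-style message, e.g.
//   auto& out = detail::Ref(ctx.Output<framework::Tensor>("Out"),
//                           "Output(Out) of %s must not be null", "some_op");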
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class NameAdapter(object):
"""Fix the backbones variable names for pretrained weight"""
def __init__(self, model):
super(NameAdapter, self).__init__()
self.model = model
@property
def model_type(self):
return getattr(self.model, '_model_type', '')
@property
def variant(self):
return getattr(self.model, 'variant', '')
def fix_conv_norm_name(self, name):
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
# the naming rule is same as pretrained weight
if self.model_type == 'SEResNeXt':
bn_name = name + "_bn"
return bn_name
def fix_shortcut_name(self, name):
if self.model_type == 'SEResNeXt':
name = 'conv' + name + '_prj'
return name
def fix_bottleneck_name(self, name):
if self.model_type == 'SEResNeXt':
conv_name1 = 'conv' + name + '_x1'
conv_name2 = 'conv' + name + '_x2'
conv_name3 = 'conv' + name + '_x3'
shortcut_name = name
else:
conv_name1 = name + "_branch2a"
conv_name2 = name + "_branch2b"
conv_name3 = name + "_branch2c"
shortcut_name = name + "_branch1"
return conv_name1, conv_name2, conv_name3, shortcut_name
def fix_layer_warp_name(self, stage_num, count, i):
name = 'res' + str(stage_num)
if count > 10 and stage_num == 4:
if i == 0:
conv_name = name + "a"
else:
conv_name = name + "b" + str(i)
else:
conv_name = name + chr(ord("a") + i)
return conv_name
def fix_c1_stage_name(self):
return "conv1"
# Download the data.
echo "Downloading..."
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar --no-check-certificate
echo "Extracting..."
tar -xf ResNet50_cos_pretrained.tar