提交 5056d3ec 编写于 作者: X Xingyuan Bu 提交者: qingqing01

FasterRCNN Anchor Generator Op (#11218)

* Add anchor generator operator for Faster-RCNN.
* Add unittest testing.
* Add Python API.
上级 5f79c7fb
......@@ -22,6 +22,8 @@ iou_similarity_op.cu)
detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
detection_library(anchor_generator_op SRCS anchor_generator_op.cc
anchor_generator_op.cu)
detection_library(target_assign_op SRCS target_assign_op.cc
target_assign_op.cu)
detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/anchor_generator_op.h"
namespace paddle {
namespace operators {
// Operator definition for `anchor_generator`.
//
// InferShape validates that the input and both outputs are present and that
// the input is rank 4, then sets both outputs to the shape
// [H, W, num_anchors, 4], where H and W are dimensions 2 and 3 of the input
// and num_anchors = size(aspect_ratios) * size(anchor_sizes).
class AnchorGeneratorOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(Input) of AnchorGeneratorOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Anchors"),
                   "Output(Anchors) of AnchorGeneratorOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("Variances"),
        "Output(Variances) of AnchorGeneratorOp should not be null.");

    const auto input_dims = ctx->GetInputDim("Input");
    PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");

    // Only the element counts are needed for shape inference, so the
    // attribute vectors are queried for their size without keeping named
    // copies; `stride` and `variances` do not influence the output shape.
    const size_t num_anchors =
        ctx->Attrs().Get<std::vector<float>>("aspect_ratios").size() *
        ctx->Attrs().Get<std::vector<float>>("anchor_sizes").size();

    std::vector<int64_t> dim_vec(4);
    dim_vec[0] = input_dims[2];
    dim_vec[1] = input_dims[3];
    dim_vec[2] = num_anchors;
    dim_vec[3] = 4;
    ctx->SetOutputDim("Anchors", framework::make_ddim(dim_vec));
    ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
  }

 protected:
  // The kernel data type follows the data type of the "Input" tensor and the
  // kernel runs on the device context of the execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
        ctx.device_context());
  }
};
// Declares the inputs, outputs and attributes of the `anchor_generator`
// operator together with the validity checks applied to each attribute.
class AnchorGeneratorOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input",
             "(Tensor, default Tensor<float>), "
             "the input feature is a tensor with a rank of 4. "
             "The layout is NCHW.");
    AddOutput("Anchors",
              "(Tensor, default Tensor<float>), the output is a "
              "tensor with a rank of 4. The layout is [H, W, num_anchors, 4]. "
              "H is the height of input, W is the width of input, num_anchors "
              "is the box count of each position. "
              "Each anchor is in (xmin, ymin, xmax, ymax) format");
    AddOutput("Variances",
              "(Tensor, default Tensor<float>), the expanded variances for "
              "normalizing bbox regression targets. The layout is [H, W, "
              "num_anchors, 4]. "
              "H is the height of input, W is the width of input, num_anchors "
              "is the box count of each position. "
              "Each variance is in (xcenter, ycenter, w, h) format");

    AddAttr<std::vector<float>>(
        "anchor_sizes",
        "(vector<float>) List of Region Proposal Network(RPN) anchor sizes "
        " given in absolute pixels e.g. (64, 128, 256, 512)."
        " For instance, the anchor size of 64 means the area of this anchor "
        "equals to 64**2.")
        .AddCustomChecker([](const std::vector<float>& anchor_sizes) {
          PADDLE_ENFORCE_GT(anchor_sizes.size(), 0,
                            "Size of anchor_sizes must be at least 1.");
          for (size_t i = 0; i < anchor_sizes.size(); ++i) {
            PADDLE_ENFORCE_GT(anchor_sizes[i], 0.0,
                              "anchor_sizes[%d] must be positive.", i);
          }
        });
    AddAttr<std::vector<float>>(
        "aspect_ratios",
        "(vector<float>) List of Region Proposal Network(RPN) anchor aspect "
        "ratios, e.g. (0.5, 1, 2)."
        "For instacne, the aspect ratio of 0.5 means the height / width of "
        "this anchor equals 0.5.")
        // The kernels divide by each aspect ratio and take the square root of
        // the quotient, so empty lists and non-positive ratios are rejected
        // here, mirroring the checks on anchor_sizes.
        .AddCustomChecker([](const std::vector<float>& aspect_ratios) {
          PADDLE_ENFORCE_GT(aspect_ratios.size(), 0,
                            "Size of aspect_ratios must be at least 1.");
          for (size_t i = 0; i < aspect_ratios.size(); ++i) {
            PADDLE_ENFORCE_GT(aspect_ratios[i], 0.0,
                              "aspect_ratios[%d] must be positive.", i);
          }
        });
    AddAttr<std::vector<float>>("variances",
                                "(vector<float>) List of variances to be used "
                                "in box regression deltas")
        .AddCustomChecker([](const std::vector<float>& variances) {
          PADDLE_ENFORCE_EQ(variances.size(), 4,
                            "Must and only provide 4 variance.");
          for (size_t i = 0; i < variances.size(); ++i) {
            PADDLE_ENFORCE_GT(variances[i], 0.0,
                              "variance[%d] must be greater than 0.", i);
          }
        });
    AddAttr<std::vector<float>>("stride",
                                "Anchors stride across width and height, "
                                "with a default of (16, 16)")
        .SetDefault(std::vector<float>(2, 16.0))
        .AddCustomChecker([](const std::vector<float>& stride) {
          PADDLE_ENFORCE_EQ(
              stride.size(), 2,
              "Must and only provide 2 stride for width and height.");
          for (size_t i = 0; i < stride.size(); ++i) {
            PADDLE_ENFORCE_GT(stride[i], 0.0,
                              "stride[%d] should be larger than 0.", i);
          }
        });
    AddAttr<float>("offset",
                   "(float) "
                   "Anchor center offset, with a default of 0.5")
        .SetDefault(0.5);
    AddComment(R"DOC(
AnchorGenerator operator
Generates anchors for Faster RCNN, FPN etc. algorithm.
Each position of the input produce N anchors, N =
size(anchor_sizes) * size(aspect_ratios).
Please get more information from the following papers:
https://arxiv.org/abs/1506.01497.
)DOC");
  }
};
} // namespace operators
} // namespace paddle
// Short alias used by the registration macros below.
namespace ops = paddle::operators;
// Register the `anchor_generator` operator with its definition class and its
// proto/checker maker. EmptyGradOpMaker is passed as the gradient maker
// (presumably meaning no gradient op is generated - confirm against the
// framework).
REGISTER_OPERATOR(anchor_generator, ops::AnchorGeneratorOp,
ops::AnchorGeneratorOpMaker,
paddle::framework::EmptyGradOpMaker);
// Register the CPU kernel for float and double element types.
REGISTER_OP_CPU_KERNEL(anchor_generator, ops::AnchorGeneratorOpKernel<float>,
ops::AnchorGeneratorOpKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/anchor_generator_op.h"
namespace paddle {
namespace operators {
// Fills `out` with height * width * (as_num * ar_num) anchor boxes, four
// values per box in (xmin, ymin, xmax, ymax) order.
//
// Box index i decomposes as ((h * width + w) * num_anchors + a), with
// a = ar_idx * as_num + as_idx, i.e. aspect ratios vary slowest and anchor
// sizes fastest within one feature-map position. stride[0] is the stride
// along the width and stride[1] the stride along the height. `sd_num` is
// accepted but not read.
template <typename T>
__global__ void GenAnchors(T* out, const T* aspect_ratios, const int ar_num,
                           const T* anchor_sizes, const int as_num,
                           const T* stride, const int sd_num, const int height,
                           const int width, const T offset) {
  const int anchors_per_cell = as_num * ar_num;
  const int total_boxes = height * width * anchors_per_cell;

  // Grid-stride loop: each thread handles boxes tid, tid + step, ...
  int box = blockIdx.x * blockDim.x + threadIdx.x;
  while (box < total_boxes) {
    const int row = box / (anchors_per_cell * width);
    const int col = (box / anchors_per_cell) % width;
    const int local = box % anchors_per_cell;

    const T stride_w = stride[0];
    const T stride_h = stride[1];
    const T ratio = aspect_ratios[local / as_num];
    const T size = anchor_sizes[local % as_num];

    // Anchor center for this feature-map position.
    const T center_x = (col * stride_w) + offset * (stride_w - 1);
    const T center_y = (row * stride_h) + offset * (stride_h - 1);

    // Base box: one stride cell reshaped to the requested aspect ratio,
    // rounded to whole units, then scaled up to the anchor size.
    const T cell_area = stride_w * stride_h;
    const T ratio_area = cell_area / ratio;
    const T base_w = round(sqrt(ratio_area));
    const T base_h = round(base_w * ratio);
    const T scale_w = size / stride_w;
    const T scale_h = size / stride_h;
    const T box_w = scale_w * base_w;
    const T box_h = scale_h * base_h;

    // Kept in the promoted type of the expression so the final conversion to
    // T happens only once, on assignment.
    const auto half_w = 0.5 * (box_w - 1);
    const auto half_h = 0.5 * (box_h - 1);
    const T xmin = (center_x - half_w);
    const T ymin = (center_y - half_h);
    const T xmax = (center_x + half_w);
    const T ymax = (center_y + half_h);

    out[box * 4] = xmin;
    out[box * 4 + 1] = ymin;
    out[box * 4 + 2] = xmax;
    out[box * 4 + 3] = ymax;

    box += blockDim.x * gridDim.x;
  }
}
// Fills out[0 .. num) by repeating the `vnum` values of `var` cyclically:
// out[i] = var[i % vnum].
template <typename T>
__global__ void SetVariance(T* out, const T* var, const int vnum,
                            const int num) {
  int pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < num) {
    out[pos] = var[pos % vnum];
    pos += blockDim.x * gridDim.x;
  }
}
// CUDA implementation of the `anchor_generator` op.
//
// Reads the feature-map height/width from dimensions 2 and 3 of "Input",
// copies the attribute lists into tensors on the op's device context and
// launches GenAnchors to fill "Anchors" and SetVariance to fill "Variances".
template <typename T>
class AnchorGeneratorOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<paddle::framework::Tensor>("Input");
    auto* anchors = ctx.Output<paddle::framework::Tensor>("Anchors");
    auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");

    const auto& anchor_sizes_attr =
        ctx.Attr<std::vector<float>>("anchor_sizes");
    const auto& aspect_ratios_attr =
        ctx.Attr<std::vector<float>>("aspect_ratios");
    const auto& stride_attr = ctx.Attr<std::vector<float>>("stride");
    const auto& variances_attr = ctx.Attr<std::vector<float>>("variances");

    // The attributes are always stored as float, while the kernels below read
    // the staging tensors through data<T>(). Convert to T on the host first so
    // the tensors' element type matches T for the double instantiation too.
    const std::vector<T> anchor_sizes(anchor_sizes_attr.begin(),
                                      anchor_sizes_attr.end());
    const std::vector<T> aspect_ratios(aspect_ratios_attr.begin(),
                                       aspect_ratios_attr.end());
    const std::vector<T> stride(stride_attr.begin(), stride_attr.end());
    const std::vector<T> variances(variances_attr.begin(),
                                   variances_attr.end());

    T offset = static_cast<T>(ctx.Attr<float>("offset"));

    auto width = input->dims()[3];
    auto height = input->dims()[2];

    int num_anchors = aspect_ratios.size() * anchor_sizes.size();
    int box_num = width * height * num_anchors;

    // One thread per box, rounded up to whole blocks.
    int block = 512;
    int grid = (box_num + block - 1) / block;

    auto stream =
        ctx.template device_context<platform::CUDADeviceContext>().stream();

    anchors->mutable_data<T>(ctx.GetPlace());
    vars->mutable_data<T>(ctx.GetPlace());

    framework::Tensor ar;
    framework::TensorFromVector(aspect_ratios, ctx.device_context(), &ar);

    framework::Tensor as;
    framework::TensorFromVector(anchor_sizes, ctx.device_context(), &as);

    framework::Tensor sd;
    framework::TensorFromVector(stride, ctx.device_context(), &sd);

    GenAnchors<T><<<grid, block, 0, stream>>>(
        anchors->data<T>(), ar.data<T>(), aspect_ratios.size(), as.data<T>(),
        anchor_sizes.size(), sd.data<T>(), stride.size(), height, width,
        offset);

    framework::Tensor v;
    framework::TensorFromVector(variances, ctx.device_context(), &v);
    // The variance output has four values per box, one thread per value.
    grid = (box_num * 4 + block - 1) / block;
    SetVariance<T><<<grid, block, 0, stream>>>(vars->data<T>(), v.data<T>(),
                                               variances.size(), box_num * 4);
  }
};
} // namespace operators
} // namespace paddle
// Short alias used by the registration macro below.
namespace ops = paddle::operators;
// Register the CUDA kernel of `anchor_generator` for float and double
// element types.
REGISTER_OP_CUDA_KERNEL(anchor_generator,
ops::AnchorGeneratorOpCUDAKernel<float>,
ops::AnchorGeneratorOpCUDAKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
// CPU implementation of the `anchor_generator` op.
//
// For every feature-map position (h, w) and every (aspect_ratio, anchor_size)
// pair - ratios in the outer loop, sizes in the inner loop - one box is
// written to "Anchors" as (xmin, ymin, xmax, ymax). "Variances" is filled by
// repeating the `variances` attribute once per generated box.
template <typename T>
class AnchorGeneratorOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<paddle::framework::Tensor>("Input");
    auto* anchors = ctx.Output<paddle::framework::Tensor>("Anchors");
    auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");

    auto anchor_sizes = ctx.Attr<std::vector<float>>("anchor_sizes");
    auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios");
    auto stride = ctx.Attr<std::vector<float>>("stride");
    auto variances = ctx.Attr<std::vector<float>>("variances");
    const T offset = static_cast<T>(ctx.Attr<float>("offset"));

    // stride[0] is the step along the width, stride[1] along the height.
    const T stride_width = stride[0];
    const T stride_height = stride[1];

    const auto feature_width = input->dims()[3];
    const auto feature_height = input->dims()[2];
    const int num_anchors = aspect_ratios.size() * anchor_sizes.size();

    anchors->mutable_data<T>(ctx.GetPlace());
    vars->mutable_data<T>(ctx.GetPlace());

    auto e_anchors = framework::EigenTensor<T, 4>::From(*anchors);
    for (int row = 0; row < feature_height; ++row) {
      const T y_ctr = (row * stride_height) + offset * (stride_height - 1);
      for (int col = 0; col < feature_width; ++col) {
        const T x_ctr = (col * stride_width) + offset * (stride_width - 1);

        int slot = 0;  // index of the anchor within this position
        for (const float ratio : aspect_ratios) {
          for (const float size : anchor_sizes) {
            // Base box: one stride cell reshaped to the aspect ratio and
            // rounded, then scaled up to the requested anchor size.
            const T area = stride_width * stride_height;
            const T area_ratios = area / ratio;
            const T base_w = round(sqrt(area_ratios));
            const T base_h = round(base_w * ratio);
            const T scale_w = size / stride_width;
            const T scale_h = size / stride_height;
            const T anchor_width = scale_w * base_w;
            const T anchor_height = scale_h * base_h;

            // Kept in the promoted type of the expression; conversion to T
            // happens on assignment into the tensor.
            const auto half_w = 0.5 * (anchor_width - 1);
            const auto half_h = 0.5 * (anchor_height - 1);
            e_anchors(row, col, slot, 0) = (x_ctr - half_w);
            e_anchors(row, col, slot, 1) = (y_ctr - half_h);
            e_anchors(row, col, slot, 2) = (x_ctr + half_w);
            e_anchors(row, col, slot, 3) = (y_ctr + half_h);
            ++slot;
          }
        }
      }
    }

    // Stage the variances in a [1, n] tensor so they can be broadcast.
    const int var_count = static_cast<int>(variances.size());
    framework::Tensor var_t;
    var_t.mutable_data<T>(framework::make_ddim({1, var_count}),
                          ctx.GetPlace());
    auto var_et = framework::EigenTensor<T, 2>::From(var_t);
    for (size_t i = 0; i < variances.size(); ++i) {
      var_et(0, i) = variances[i];
    }

    // View the output as [anchor_num, n] for the broadcast, then restore the
    // original dimensions.
    const int anchor_num = feature_height * feature_width * num_anchors;
    const auto original_dims = vars->dims();
    vars->Resize({anchor_num, var_count});
    auto e_vars = framework::EigenMatrix<T, Eigen::RowMajor>::From(*vars);
    e_vars = var_et.broadcast(Eigen::DSizes<int, 2>(anchor_num, 1));
    vars->Resize(original_dims);
  }
};
} // namespace operators
} // namespace paddle
......@@ -30,6 +30,7 @@ __all__ = [
'detection_output',
'ssd_loss',
'detection_map',
'anchor_generator',
]
__auto__ = [
......@@ -998,3 +999,95 @@ def multi_box_head(inputs,
box.stop_gradient = True
var.stop_gradient = True
return mbox_locs_concat, mbox_confs_concat, box, var
def anchor_generator(input,
                     anchor_sizes=None,
                     aspect_ratios=None,
                     variance=[0.1, 0.1, 0.2, 0.2],
                     stride=None,
                     offset=0.5,
                     name=None):
    """
    **Anchor generator operator**

    Generate anchors for Faster RCNN algorithm.
    Each position of the input produces N anchors, N =
    size(anchor_sizes) * size(aspect_ratios). The order of generated anchors
    is firstly aspect_ratios loop then anchor_sizes loop.

    Args:
       input(Variable): The input feature map, the format is NCHW.
       anchor_sizes(list|tuple|float): The anchor sizes of generated anchors,
          given in absolute pixels e.g. [64., 128., 256., 512.].
          For instance, the anchor size of 64 means the area of this anchor
          equals to 64**2. A single number is treated as a one-element list.
       aspect_ratios(list|tuple|float): The height / width ratios of generated
          anchors, e.g. [0.5, 1.0, 2.0]. A single number is treated as a
          one-element list.
       variance(list|tuple): The variances to be used in box regression deltas.
          Default:[0.1, 0.1, 0.2, 0.2].
       stride(list|tuple): The anchors stride across width and height,
          e.g. [16.0, 16.0]. Must have exactly two elements,
          (stride_width, stride_height).
       offset(float): Anchor center offset. Default: 0.5
       name(str): Name of the anchor generator op. Default: None.

    Returns:
        Anchors(Variable):  The output anchors with a layout of
                            [H, W, num_anchors, 4].
                            H is the height of input, W is the width of input,
                            num_anchors is the box count of each position.
                            Each anchor is in (xmin, ymin, xmax, ymax) format
                            and unnormalized.
        Variances(Variable): The expanded variances of anchors
                             with a layout of [H, W, num_anchors, 4].
                             H is the height of input, W is the width of input,
                             num_anchors is the box count of each position.
                             Each variance is in (xcenter, ycenter, w, h)
                             format.

    Raises:
        ValueError: If `stride` is not a list or tuple of length 2.

    Examples:

        .. code-block:: python

            anchor, var = anchor_generator(
                input=conv1,
                anchor_sizes=[64, 128, 256, 512],
                aspect_ratios=[0.5, 1.0, 2.0],
                variance=[0.1, 0.1, 0.2, 0.2],
                stride=[16.0, 16.0],
                offset=0.5)
    """
    # Must stay the first statement: locals() is expected to contain only the
    # call arguments at this point.
    helper = LayerHelper("anchor_generator", **locals())
    dtype = helper.input_dtype()

    def _is_list_or_tuple_(data):
        return isinstance(data, (list, tuple))

    if not _is_list_or_tuple_(anchor_sizes):
        anchor_sizes = [anchor_sizes]
    if not _is_list_or_tuple_(aspect_ratios):
        aspect_ratios = [aspect_ratios]
    if not (_is_list_or_tuple_(stride) and len(stride) == 2):
        # Adjacent literals form ONE message; passing them as two arguments
        # would turn the exception text into a tuple.
        raise ValueError('stride should be a list or tuple '
                         'with length 2, (stride_width, stride_height).')

    # The op declares all of these attributes as lists of floats, so integer
    # inputs (e.g. anchor_sizes=[64, 128]) are normalized here.
    anchor_sizes = list(map(float, anchor_sizes))
    aspect_ratios = list(map(float, aspect_ratios))
    stride = list(map(float, stride))
    variance = list(map(float, variance))

    attrs = {
        'anchor_sizes': anchor_sizes,
        'aspect_ratios': aspect_ratios,
        'variances': variance,
        'stride': stride,
        'offset': offset
    }

    anchor = helper.create_tmp_variable(dtype)
    var = helper.create_tmp_variable(dtype)
    helper.append_op(
        type="anchor_generator",
        inputs={"Input": input},
        outputs={"Anchors": anchor,
                 "Variances": var},
        attrs=attrs, )
    # The outputs are derived from shapes and attributes only, so they are
    # excluded from back-propagation.
    anchor.stop_gradient = True
    var.stop_gradient = True
    return anchor, var
......@@ -127,6 +127,24 @@ class TestPriorBox(unittest.TestCase):
assert box.shape[3] == 4
class TestAnchorGenerator(unittest.TestCase):
    """Shape checks for the `fluid.layers.anchor_generator` layer."""

    def test_anchor_generator(self):
        # Build a small graph: image input -> conv2d -> anchor generator.
        pixel = fluid.layers.data(
            name='pixel', shape=[3, 224, 224], dtype='float32')
        feature_map = fluid.layers.conv2d(pixel, 3, 3, 2)
        anchors, variances = fluid.layers.anchor_generator(
            input=feature_map,
            anchor_sizes=[64, 128, 256, 512],
            aspect_ratios=[0.5, 1.0, 2.0],
            variance=[0.1, 0.1, 0.2, 0.2],
            stride=[16.0, 16.0],
            offset=0.5)

        # Both outputs are rank 4 with identical shapes and 4 values per box.
        assert len(anchors.shape) == 4
        assert anchors.shape == variances.shape
        assert anchors.shape[3] == 4
class TestMultiBoxHead(unittest.TestCase):
def test_multi_box_head(self):
data_shape = [3, 224, 224]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
import math
from op_test import OpTest
def anchor_generator_in_python(input_feat, anchor_sizes, aspect_ratios,
                               variances, stride, offset):
    """Reference (NumPy) implementation of the anchor_generator op.

    Args:
        input_feat: array whose dimensions 2 and 3 give the feature-map
            height and width.
        anchor_sizes: sequence of anchor sizes.
        aspect_ratios: sequence of height / width ratios.
        variances: sequence of variances, tiled once per anchor.
        stride: (stride_width, stride_height).
        offset: anchor center offset.

    Returns:
        (anchors, variances): two float32 arrays of shape
        [H, W, len(aspect_ratios) * len(anchor_sizes), 4]. Each anchor is
        (xmin, ymin, xmax, ymax); aspect ratios vary slowest and anchor sizes
        fastest along the third axis.
    """
    rows = input_feat.shape[2]
    cols = input_feat.shape[3]
    per_cell = len(aspect_ratios) * len(anchor_sizes)

    anchors = np.zeros((rows, cols, per_cell, 4), dtype='float32')
    for row in range(rows):
        for col in range(cols):
            center_x = (col * stride[0]) + offset * (stride[0] - 1)
            center_y = (row * stride[1]) + offset * (stride[1] - 1)

            # Ratios in the outer position, sizes in the inner one.
            pairs = ((ratio, size)
                     for ratio in aspect_ratios for size in anchor_sizes)
            for slot, (ratio, size) in enumerate(pairs):
                # Base box: one stride cell reshaped to the aspect ratio and
                # rounded, then scaled up to the anchor size.
                cell_area = stride[0] * stride[1]
                ratio_area = cell_area / ratio
                base_w = np.round(np.sqrt(ratio_area))
                base_h = np.round(base_w * ratio)
                box_w = (size / stride[0]) * base_w
                box_h = (size / stride[1]) * base_h

                half_w = 0.5 * (box_w - 1)
                half_h = 0.5 * (box_h - 1)
                anchors[row, col, slot, :] = [
                    (center_x - half_w), (center_y - half_h),
                    (center_x + half_w), (center_y + half_h)
                ]

    # Repeat the variance vector for every generated anchor.
    tiled_var = np.tile(variances, (rows, cols, per_cell, 1))
    return anchors.astype('float32'), tiled_var.astype('float32')
class TestAnchorGeneratorOp(OpTest):
    """Checks the anchor_generator op against the NumPy reference."""

    def setUp(self):
        self.op_type = "anchor_generator"
        self.set_data()

    def test_check_output(self):
        self.check_output()

    def set_data(self):
        """Prepare parameters, input and expected output, then wire them up."""
        self.init_test_params()
        self.init_test_input()
        self.init_test_output()

        self.inputs = {'Input': self.input}
        self.attrs = {
            'anchor_sizes': self.anchor_sizes,
            'aspect_ratios': self.aspect_ratios,
            'stride': self.stride,
            'offset': self.offset,
            'variances': self.variances,
        }
        self.outputs = {'Anchors': self.out_anchors, 'Variances': self.out_var}

    def init_test_params(self):
        """Input shape (N, C, H, W) = (1, 2, 2, 2) and the op attributes."""
        self.batch_size = 1
        self.input_channels = 2
        self.layer_h = 2
        self.layer_w = 2

        self.anchor_sizes = [64., 128., 256., 512.]
        self.aspect_ratios = [0.5, 1., 2.]
        self.stride = [16., 16.]

        self.offset = 0.5

        self.variances = [0.1, 0.1, 0.2, 0.2]

    def init_test_input(self):
        """Random float32 feature map with the configured shape."""
        input_shape = (self.batch_size, self.input_channels, self.layer_h,
                       self.layer_w)
        self.input = np.random.random(input_shape).astype('float32')

    def init_test_output(self):
        """Expected outputs computed by the NumPy reference implementation."""
        expected = anchor_generator_in_python(
            self.input, self.anchor_sizes, self.aspect_ratios, self.variances,
            self.stride, self.offset)
        self.out_anchors, self.out_var = expected
# Run the test cases of this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册