未验证 提交 a84a580e 编写于 作者: Q qingqing01 提交者: GitHub

Add CUDA kernel for prior_box_op. (#9553)

上级 d139f2ca
......@@ -73,7 +73,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
platform::CPUPlace());
ctx.device_context());
}
};
......@@ -171,6 +171,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
prior_box, ops::PriorBoxOpKernel<paddle::platform::CPUPlace, float>,
ops::PriorBoxOpKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(prior_box, ops::PriorBoxOpKernel<float>,
ops::PriorBoxOpKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/prior_box_op.h"
namespace paddle {
namespace operators {
template <typename T>
__device__ inline T clip(T in) {
return min(max(in, 0.), 1.);
}
template <typename T>
__global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
const int width, const int im_height,
const int im_width, const int as_num,
const T offset, const T step_width,
const T step_height, const T* min_sizes,
const T* max_sizes, const int min_num,
bool is_clip) {
int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num;
int box_num = height * width * num_priors;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num;
i += blockDim.x * gridDim.x) {
int h = i / (num_priors * width);
int w = (i / num_priors) % width;
int p = i % num_priors;
int m = max_sizes ? p / (as_num + 1) : p / as_num;
T cx = (w + offset) * step_width;
T cy = (h + offset) * step_height;
T bw, bh;
T min_size = min_sizes[m];
if (max_sizes) {
int s = p % (as_num + 1);
if (s < as_num) {
T ar = aspect_ratios[s];
bw = min_size * sqrt(ar) / 2.;
bh = min_size / sqrt(ar) / 2.;
} else {
T max_size = max_sizes[m];
bw = sqrt(min_size * max_size) / 2.;
bh = bw;
}
} else {
int s = p % as_num;
T ar = aspect_ratios[s];
bw = min_size * sqrt(ar) / 2.;
bh = min_size / sqrt(ar) / 2.;
}
T xmin = (cx - bw) / im_width;
T ymin = (cy - bh) / im_height;
T xmax = (cx + bw) / im_width;
T ymax = (cy + bh) / im_height;
out[i * 4] = is_clip ? clip<T>(xmin) : xmin;
out[i * 4 + 1] = is_clip ? clip<T>(ymin) : ymin;
out[i * 4 + 2] = is_clip ? clip<T>(xmax) : xmax;
out[i * 4 + 3] = is_clip ? clip<T>(ymax) : ymax;
}
}
template <typename T>
__global__ void SetVariance(T* out, const T* var, const int vnum,
const int num) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
out[i] = var[i % vnum];
}
}
template <typename T>
class PriorBoxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
auto* image = ctx.Input<paddle::framework::Tensor>("Image");
auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
auto min_sizes = ctx.Attr<std::vector<float>>("min_sizes");
auto max_sizes = ctx.Attr<std::vector<float>>("max_sizes");
auto input_aspect_ratio = ctx.Attr<std::vector<float>>("aspect_ratios");
auto variances = ctx.Attr<std::vector<float>>("variances");
auto flip = ctx.Attr<bool>("flip");
auto clip = ctx.Attr<bool>("clip");
std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, aspect_ratios);
T step_w = static_cast<T>(ctx.Attr<float>("step_w"));
T step_h = static_cast<T>(ctx.Attr<float>("step_h"));
T offset = static_cast<T>(ctx.Attr<float>("offset"));
auto im_width = image->dims()[3];
auto im_height = image->dims()[2];
auto width = input->dims()[3];
auto height = input->dims()[2];
T step_width, step_height;
if (step_w == 0 || step_h == 0) {
step_width = static_cast<T>(im_width) / width;
step_height = static_cast<T>(im_height) / height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = aspect_ratios.size() * min_sizes.size();
if (max_sizes.size() > 0) {
num_priors += max_sizes.size();
}
int min_num = static_cast<int>(min_sizes.size());
int box_num = width * height * num_priors;
int block = 512;
int grid = (box_num + block - 1) / block;
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
boxes->mutable_data<T>(ctx.GetPlace());
vars->mutable_data<T>(ctx.GetPlace());
framework::Tensor r;
framework::TensorFromVector(aspect_ratios, ctx.device_context(), &r);
framework::Tensor min;
framework::TensorFromVector(min_sizes, ctx.device_context(), &min);
T* max_data = nullptr;
framework::Tensor max;
if (max_sizes.size() > 0) {
framework::TensorFromVector(max_sizes, ctx.device_context(), &max);
max_data = max.data<T>();
}
GenPriorBox<T><<<grid, block, 0, stream>>>(
boxes->data<T>(), r.data<T>(), height, width, im_height, im_width,
aspect_ratios.size(), offset, step_width, step_height, min.data<T>(),
max_data, min_num, clip);
framework::Tensor v;
framework::TensorFromVector(variances, ctx.device_context(), &v);
grid = (box_num * 4 + block - 1) / block;
SetVariance<T><<<grid, block, 0, stream>>>(vars->data<T>(), v.data<T>(),
variances.size(), box_num * 4);
}
}; // namespace operators
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(prior_box, ops::PriorBoxOpCUDAKernel<float>,
ops::PriorBoxOpCUDAKernel<double>);
......@@ -51,7 +51,7 @@ struct ClipFunctor {
}
};
template <typename Place, typename T>
template <typename T>
class PriorBoxOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -106,49 +106,24 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s];
// first prior: aspect_ratio = 1, size = min_size
box_width = box_height = min_size / 2.;
// xmin
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
// ymin
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
// xmax
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
// ymax
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
if (max_sizes.size() > 0) {
auto max_size = max_sizes[s];
// second prior: aspect_ratio = 1,
// size = sqrt(min_size * max_size)
box_width = box_height = sqrt(min_size * max_size) / 2.;
// xmin
// priors with different aspect ratios
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
float ar = aspect_ratios[r];
box_width = min_size * sqrt(ar) / 2.;
box_height = min_size / sqrt(ar) / 2.;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
// ymin
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
// xmax
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
// ymax
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
}
// rest of priors
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
float ar = aspect_ratios[r];
if (fabs(ar - 1.) < 1e-6) {
continue;
}
box_width = min_size * sqrt(ar) / 2.;
box_height = min_size / sqrt(ar) / 2.;
// xmin
if (max_sizes.size() > 0) {
auto max_size = max_sizes[s];
// square prior with size sqrt(minSize * maxSize)
box_width = box_height = sqrt(min_size * max_size) / 2.;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
// ymin
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
// xmax
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
// ymax
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
}
......
......@@ -28,7 +28,6 @@ class TestPriorBoxOp(OpTest):
self.attrs = {
'min_sizes': self.min_sizes,
'max_sizes': self.max_sizes,
'aspect_ratios': self.aspect_ratios,
'variances': self.variances,
'flip': self.flip,
......@@ -37,25 +36,28 @@ class TestPriorBoxOp(OpTest):
'step_h': self.step_h,
'offset': self.offset
}
if len(self.max_sizes) > 0:
self.attrs['max_sizes'] = self.max_sizes
self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
return
def setUp(self):
self.op_type = "prior_box"
self.set_data()
def set_max_sizes(self):
max_sizes = [5, 10]
self.max_sizes = np.array(max_sizes).astype('float32').tolist()
def init_test_params(self):
self.layer_w = 4
self.layer_h = 4
self.layer_w = 32
self.layer_h = 32
self.image_w = 20
self.image_h = 20
self.image_w = 40
self.image_h = 40
self.step_w = float(self.image_w) / float(self.layer_w)
self.step_h = float(self.image_h) / float(self.layer_h)
......@@ -66,8 +68,7 @@ class TestPriorBoxOp(OpTest):
self.min_sizes = [2, 4]
self.min_sizes = np.array(self.min_sizes).astype('float32').tolist()
self.max_sizes = [5, 10]
self.max_sizes = np.array(self.max_sizes).astype('float32').tolist()
self.set_max_sizes()
self.aspect_ratios = [2.0, 3.0]
self.flip = True
self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
......@@ -79,7 +80,7 @@ class TestPriorBoxOp(OpTest):
self.clip = True
self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
if len(self.max_sizes) > 1:
if len(self.max_sizes) > 0:
self.num_priors += len(self.max_sizes)
self.offset = 0.5
......@@ -105,35 +106,27 @@ class TestPriorBoxOp(OpTest):
idx = 0
for s in range(len(self.min_sizes)):
min_size = self.min_sizes[s]
c_w = c_h = min_size / 2.
out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w, (c_y + c_h) / self.image_h
]
idx += 1
if len(self.max_sizes) > 0:
max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2
# rest of priors
for r in range(len(self.real_aspect_ratios)):
ar = self.real_aspect_ratios[r]
c_w = min_size * math.sqrt(ar) / 2
c_h = (min_size / math.sqrt(ar)) / 2
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h]
idx += 1
# rest of priors
for r in range(len(self.real_aspect_ratios)):
ar = self.real_aspect_ratios[r]
if math.fabs(ar - 1.) < 1e-6:
continue
c_w = min_size * math.sqrt(ar) / 2
c_h = (min_size / math.sqrt(ar)) / 2
if len(self.max_sizes) > 0:
max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h]
idx += 1
# clip the prior's coordidate such that it is within[0, 1]
if self.clip:
out_boxes = np.clip(out_boxes, 0.0, 1.0)
......@@ -144,5 +137,10 @@ class TestPriorBoxOp(OpTest):
self.out_var = out_var.astype('float32')
class TestPriorBoxOpWithMaxSize(TestPriorBoxOp):
def set_max_sizes(self):
self.max_sizes = []
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册