diff --git a/paddle/fluid/operators/prior_box_op.cc b/paddle/fluid/operators/prior_box_op.cc index c22a55bce263423d5c17fffdb06b7ece02ae26da..82e54139c8c1f42b1d8f74811a6793ec5c66473e 100644 --- a/paddle/fluid/operators/prior_box_op.cc +++ b/paddle/fluid/operators/prior_box_op.cc @@ -73,7 +73,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), - platform::CPUPlace()); + ctx.device_context()); } }; @@ -171,6 +171,5 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - prior_box, ops::PriorBoxOpKernel, - ops::PriorBoxOpKernel); +REGISTER_OP_CPU_KERNEL(prior_box, ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel); diff --git a/paddle/fluid/operators/prior_box_op.cu b/paddle/fluid/operators/prior_box_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..76bf2b3b7de7a24c80e927c16199f89c5b7fb794 --- /dev/null +++ b/paddle/fluid/operators/prior_box_op.cu @@ -0,0 +1,167 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/prior_box_op.h" + +namespace paddle { +namespace operators { + +template +__device__ inline T clip(T in) { + return min(max(in, 0.), 1.); +} + +template +__global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height, + const int width, const int im_height, + const int im_width, const int as_num, + const T offset, const T step_width, + const T step_height, const T* min_sizes, + const T* max_sizes, const int min_num, + bool is_clip) { + int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num; + int box_num = height * width * num_priors; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num; + i += blockDim.x * gridDim.x) { + int h = i / (num_priors * width); + int w = (i / num_priors) % width; + int p = i % num_priors; + int m = max_sizes ? p / (as_num + 1) : p / as_num; + T cx = (w + offset) * step_width; + T cy = (h + offset) * step_height; + T bw, bh; + T min_size = min_sizes[m]; + if (max_sizes) { + int s = p % (as_num + 1); + if (s < as_num) { + T ar = aspect_ratios[s]; + bw = min_size * sqrt(ar) / 2.; + bh = min_size / sqrt(ar) / 2.; + } else { + T max_size = max_sizes[m]; + bw = sqrt(min_size * max_size) / 2.; + bh = bw; + } + } else { + int s = p % as_num; + T ar = aspect_ratios[s]; + bw = min_size * sqrt(ar) / 2.; + bh = min_size / sqrt(ar) / 2.; + } + T xmin = (cx - bw) / im_width; + T ymin = (cy - bh) / im_height; + T xmax = (cx + bw) / im_width; + T ymax = (cy + bh) / im_height; + out[i * 4] = is_clip ? clip(xmin) : xmin; + out[i * 4 + 1] = is_clip ? clip(ymin) : ymin; + out[i * 4 + 2] = is_clip ? clip(xmax) : xmax; + out[i * 4 + 3] = is_clip ? clip(ymax) : ymax; + } +} + +template +__global__ void SetVariance(T* out, const T* var, const int vnum, + const int num) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; + i += blockDim.x * gridDim.x) { + out[i] = var[i % vnum]; + } +} + +template +class PriorBoxOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* image = ctx.Input("Image"); + auto* boxes = ctx.Output("Boxes"); + auto* vars = ctx.Output("Variances"); + + auto min_sizes = ctx.Attr>("min_sizes"); + auto max_sizes = ctx.Attr>("max_sizes"); + auto input_aspect_ratio = ctx.Attr>("aspect_ratios"); + auto variances = ctx.Attr>("variances"); + auto flip = ctx.Attr("flip"); + auto clip = ctx.Attr("clip"); + + std::vector aspect_ratios; + ExpandAspectRatios(input_aspect_ratio, flip, aspect_ratios); + + T step_w = static_cast(ctx.Attr("step_w")); + T step_h = static_cast(ctx.Attr("step_h")); + T offset = static_cast(ctx.Attr("offset")); + + auto im_width = image->dims()[3]; + auto im_height = image->dims()[2]; + + auto width = input->dims()[3]; + auto height = input->dims()[2]; + + T step_width, step_height; + if (step_w == 0 || step_h == 0) { + step_width = static_cast(im_width) / width; + step_height = static_cast(im_height) / height; + } else { + step_width = step_w; + step_height = step_h; + } + + int num_priors = aspect_ratios.size() * min_sizes.size(); + if (max_sizes.size() > 0) { + num_priors += max_sizes.size(); + } + int min_num = static_cast(min_sizes.size()); + int box_num = width * height * num_priors; + + int block = 512; + int grid = (box_num + block - 1) / block; + + auto stream = + ctx.template device_context().stream(); + + boxes->mutable_data(ctx.GetPlace()); + vars->mutable_data(ctx.GetPlace()); + + framework::Tensor r; + framework::TensorFromVector(aspect_ratios, ctx.device_context(), &r); + + framework::Tensor min; + framework::TensorFromVector(min_sizes, ctx.device_context(), &min); + + T* max_data = nullptr; + framework::Tensor max; + if (max_sizes.size() > 0) { + framework::TensorFromVector(max_sizes, ctx.device_context(), &max); + max_data = max.data(); + } + + GenPriorBox<<>>( + boxes->data(), r.data(), height, width, im_height, im_width, + aspect_ratios.size(), offset, step_width, step_height, min.data(), + max_data, min_num, clip); + + framework::Tensor v; + framework::TensorFromVector(variances, ctx.device_context(), &v); + grid = (box_num * 4 + block - 1) / block; + SetVariance<<>>(vars->data(), v.data(), + variances.size(), box_num * 4); + } +}; // namespace operators + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(prior_box, ops::PriorBoxOpCUDAKernel, + ops::PriorBoxOpCUDAKernel); diff --git a/paddle/fluid/operators/prior_box_op.h b/paddle/fluid/operators/prior_box_op.h index 18bb2deb6b5acf626dfb2883a5771d9d195d45c0..1e4a12aac1c5f1c3b7e2e1bc83170de9ad590fc3 100644 --- a/paddle/fluid/operators/prior_box_op.h +++ b/paddle/fluid/operators/prior_box_op.h @@ -51,7 +51,7 @@ struct ClipFunctor { } }; -template +template class PriorBoxOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -106,49 +106,24 @@ class PriorBoxOpKernel : public framework::OpKernel { int idx = 0; for (size_t s = 0; s < min_sizes.size(); ++s) { auto min_size = min_sizes[s]; - // first prior: aspect_ratio = 1, size = min_size - box_width = box_height = min_size / 2.; - // xmin - e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width; - // ymin - e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height; - // xmax - e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width; - // ymax - e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height; - - idx++; - if (max_sizes.size() > 0) { - auto max_size = max_sizes[s]; - // second prior: aspect_ratio = 1, - // size = sqrt(min_size * max_size) - box_width = box_height = sqrt(min_size * max_size) / 2.; - // xmin + // priors with different aspect ratios + for (size_t r = 0; r < aspect_ratios.size(); ++r) { + float ar = aspect_ratios[r]; + box_width = min_size * sqrt(ar) / 2.; + box_height = min_size / sqrt(ar) / 2.; e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width; - // ymin e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height; - // xmax e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width; - // ymax e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height; idx++; } - - // rest of priors - for (size_t r = 0; r < aspect_ratios.size(); ++r) { - float ar = aspect_ratios[r]; - if (fabs(ar - 1.) < 1e-6) { - continue; - } - box_width = min_size * sqrt(ar) / 2.; - box_height = min_size / sqrt(ar) / 2.; - // xmin + if (max_sizes.size() > 0) { + auto max_size = max_sizes[s]; + // square prior with size sqrt(minSize * maxSize) + box_width = box_height = sqrt(min_size * max_size) / 2.; e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width; - // ymin e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height; - // xmax e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width; - // ymax e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height; idx++; } diff --git a/python/paddle/fluid/tests/unittests/test_prior_box_op.py b/python/paddle/fluid/tests/unittests/test_prior_box_op.py index c21138c13e6753f9dfcbd7d439269f7cf9a04f23..bcbc02a2baa46b9ab583ecf3006bd3262e6038fd 100644 --- a/python/paddle/fluid/tests/unittests/test_prior_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_prior_box_op.py @@ -28,7 +28,6 @@ class TestPriorBoxOp(OpTest): self.attrs = { 'min_sizes': self.min_sizes, - 'max_sizes': self.max_sizes, 'aspect_ratios': self.aspect_ratios, 'variances': self.variances, 'flip': self.flip, @@ -37,25 +36,28 @@ class TestPriorBoxOp(OpTest): 'step_h': self.step_h, 'offset': self.offset } + if len(self.max_sizes) > 0: + self.attrs['max_sizes'] = self.max_sizes self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var} def test_check_output(self): self.check_output() - def test_check_grad(self): - return - def setUp(self): self.op_type = "prior_box" self.set_data() + def set_max_sizes(self): + max_sizes = [5, 10] + self.max_sizes = np.array(max_sizes).astype('float32').tolist() + def init_test_params(self): - self.layer_w = 4 - self.layer_h = 4 + self.layer_w = 32 + self.layer_h = 32 - self.image_w = 20 - self.image_h = 20 + self.image_w = 40 + self.image_h = 40 self.step_w = float(self.image_w) / float(self.layer_w) self.step_h = float(self.image_h) / float(self.layer_h) @@ -66,8 +68,7 @@ class TestPriorBoxOp(OpTest): self.min_sizes = [2, 4] self.min_sizes = np.array(self.min_sizes).astype('float32').tolist() - self.max_sizes = [5, 10] - self.max_sizes = np.array(self.max_sizes).astype('float32').tolist() + self.set_max_sizes() self.aspect_ratios = [2.0, 3.0] self.flip = True self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] @@ -79,7 +80,7 @@ class TestPriorBoxOp(OpTest): self.clip = True self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) - if len(self.max_sizes) > 1: + if len(self.max_sizes) > 0: self.num_priors += len(self.max_sizes) self.offset = 0.5 @@ -105,35 +106,27 @@ class TestPriorBoxOp(OpTest): idx = 0 for s in range(len(self.min_sizes)): min_size = self.min_sizes[s] - c_w = c_h = min_size / 2. - out_boxes[h, w, idx, :] = [ - (c_x - c_w) / self.image_w, (c_y - c_h) / self.image_h, - (c_x + c_w) / self.image_w, (c_y + c_h) / self.image_h - ] - idx += 1 - - if len(self.max_sizes) > 0: - max_size = self.max_sizes[s] - # second prior: aspect_ratio = 1, - c_w = c_h = math.sqrt(min_size * max_size) / 2 + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + c_w = min_size * math.sqrt(ar) / 2 + c_h = (min_size / math.sqrt(ar)) / 2 out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w, (c_y - c_h) / self.image_h, (c_x + c_w) / self.image_w, (c_y + c_h) / self.image_h] idx += 1 - # rest of priors - for r in range(len(self.real_aspect_ratios)): - ar = self.real_aspect_ratios[r] - if math.fabs(ar - 1.) < 1e-6: - continue - c_w = min_size * math.sqrt(ar) / 2 - c_h = (min_size / math.sqrt(ar)) / 2 + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + c_w = c_h = math.sqrt(min_size * max_size) / 2 out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w, (c_y - c_h) / self.image_h, (c_x + c_w) / self.image_w, (c_y + c_h) / self.image_h] idx += 1 + # clip the prior's coordidate such that it is within[0, 1] if self.clip: out_boxes = np.clip(out_boxes, 0.0, 1.0) @@ -144,5 +137,10 @@ class TestPriorBoxOp(OpTest): self.out_var = out_var.astype('float32') +class TestPriorBoxOpWithMaxSize(TestPriorBoxOp): + def set_max_sizes(self): + self.max_sizes = [] + + if __name__ == '__main__': unittest.main()