diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe1ccceb0665eb6407ac0203ebf69f7c7a776702 --- /dev/null +++ b/paddle/operators/prior_box_op.cc @@ -0,0 +1,167 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/prior_box_op.h" + +namespace paddle { +namespace operators { + +class PriorBoxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(X) of SequenceSliceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Image"), + "Input(Offset) of SequenceSliceOp should not be null."); + + auto image_dims = ctx->GetInputDim("Image"); + auto input_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE(image_dims.size() == 4, + "The format of input tensor is NCHW."); + + auto min_sizes = ctx->Attrs().Get>("min_sizes"); + auto max_sizes = ctx->Attrs().Get>("max_sizes"); + auto variances = ctx->Attrs().Get>("variances"); + auto input_aspect_ratio = + ctx->Attrs().Get>("aspect_ratios"); + bool flip = ctx->Attrs().Get("flip"); + + PADDLE_ENFORCE_GT(min_sizes.size(), 0, "must provide min_size."); + for (size_t i = 0; i < min_sizes.size(); ++i) { + PADDLE_ENFORCE_GT(min_sizes[i], 0, "min_sizes[%d] must be positive.", i); + } + + std::vector aspect_ratios; + expand_aspect_ratios(input_aspect_ratio, flip, aspect_ratios); + + int num_priors = aspect_ratios.size() * min_sizes.size(); + if (max_sizes.size() > 0) { + PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(), + "The length of min_size and max_size must be equal."); + for (size_t i = 0; i < min_sizes.size(); ++i) { + PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i], + "max_size[%d] must be greater than min_size[%d].", i, + i); + num_priors += 1; + } + } + + if (variances.size() > 1) { + PADDLE_ENFORCE_EQ(variances.size(), 4, + "Must and only provide 4 variance."); + for (size_t i = 0; i < variances.size(); ++i) { + PADDLE_ENFORCE_GT(variances[i], 0.0, + "variance[%d] must be greater than 0.", i); + } + } else if (variances.size() == 1) { + PADDLE_ENFORCE_GT(variances[0], 0.0, + "variance[0] must be greater than 0."); + } + + const int img_h = ctx->Attrs().Get("img_h"); + PADDLE_ENFORCE_GT(img_h, 0, "img_h should be larger than 0."); + const int img_w = ctx->Attrs().Get("img_w"); + PADDLE_ENFORCE_GT(img_w, 0, "img_w should be larger than 0."); + + const float step_h = ctx->Attrs().Get("step_h"); + PADDLE_ENFORCE_GT(step_h, 0.0, "step_h should be larger than 0."); + const float step_w = ctx->Attrs().Get("step_w"); + PADDLE_ENFORCE_GT(step_w, 0.0, "step_w should be larger than 0."); + + const int layer_height = input_dims[3]; + const int layer_width = input_dims[2]; + + std::vector dim_vec(3); + // Since all images in a batch has same height and width, we only need to + // generate one set of priors which can be shared across all images. + dim_vec[0] = 1; + // 2 channels. First channel stores the mean of each prior coordinate. + // Second channel stores the variance of each prior coordinate. + dim_vec[1] = 2; + dim_vec[2] = layer_width * layer_height * num_priors * 4; + PADDLE_ENFORCE_GT(dim_vec[2], 0, + "output_dim[2] must larger than 0." + "check your data dims"); + auto output_dim = framework::make_ddim(dim_vec); + ctx->SetOutputDim("Out", output_dim); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Image")->type()), + ctx.device_context()); + } +}; + +class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PriorBoxOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(Tensor), " + "the input feature data of PriorBoxOp."); + AddInput("Image", + "(Tensor), " + "the input image data of PriorBoxOp."); + AddOutput("Out", "(Tensor), the output prior boxes of PriorBoxOp."); + AddAttr>("min_sizes", "(vector) ", + "List of min sizes of generated prior boxes."); + AddAttr>("max_sizes", "(vector) ", + "List of max sizes of generated prior boxes."); + AddAttr>( + "aspect_ratios", "(vector) ", + "List of aspect ratios of generated prior boxes.") + .SetDefault({}); + AddAttr>( + "variances", "(vector) ", + "List of variances to be encoded in prior boxes.") + .SetDefault({0.1}); + AddAttr("flip", "(bool) ", "Whether to flip aspect ratios.") + .SetDefault(true); + AddAttr("clip", "(bool) ", "Whether to clip out-of-boundary boxes.") + .SetDefault(true); + AddAttr("img_w", "").SetDefault(0); + AddAttr("img_h", "").SetDefault(0); + AddAttr("step_w", + "Prior boxes step across width, 0 for auto calculation.") + .SetDefault(0.0); + AddAttr("step_h", + "Prior boxes step across height, 0 for auto calculation.") + .SetDefault(0.0); + AddAttr("offset", + "(float) " + "Prior boxes center offset.") + .SetDefault(0.5); + AddComment(R"DOC( +Prior box operator +Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. +Please get more information from the following papers: +https://arxiv.org/abs/1512.02325. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker); +REGISTER_OP_CPU_KERNEL( + prior_box, ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel); diff --git a/paddle/operators/prior_box_op.cu b/paddle/operators/prior_box_op.cu new file mode 100755 index 0000000000000000000000000000000000000000..d1928462a2d75baa98a85833647fc695cea2a019 --- /dev/null +++ b/paddle/operators/prior_box_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/prior_box_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + prior_box, ops::PriorBoxOpKernel, + ops::PriorBoxOpKernel); diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h new file mode 100644 index 0000000000000000000000000000000000000000..6dabba526504a2680d93da4d70035372b6667f68 --- /dev/null +++ b/paddle/operators/prior_box_op.h @@ -0,0 +1,199 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +// #include "paddle/operators/strided_memcpy.h" + +namespace paddle { +namespace operators { + +inline void expand_aspect_ratios(const std::vector input_aspect_ratior, + bool flip, + std::vector& output_aspect_ratior) { + constexpr float eps = 1e-6; + output_aspect_ratior.clear(); + output_aspect_ratior.push_back(1.); + for (size_t i = 0; i < input_aspect_ratior.size(); ++i) { + float ar = input_aspect_ratior[i]; + bool already_exist = false; + for (size_t j = 0; j < output_aspect_ratior.size(); ++j) { + if (fabs(ar - output_aspect_ratior[j]) < eps) { + already_exist = true; + break; + } + } + if (!already_exist) { + output_aspect_ratior.push_back(ar); + if (flip) { + output_aspect_ratior.push_back(1. / ar); + } + } + } +} + +template +class PriorBoxOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* image = ctx.Input("Image"); + auto* out = ctx.Output("Out"); + + auto min_sizes = ctx.Attr>("min_sizes"); + auto max_sizes = ctx.Attr>("max_sizes"); + auto input_aspect_ratio = ctx.Attr>("aspect_ratios"); + auto variances = ctx.Attr>("variances"); + auto flip = ctx.Attr("flip"); + auto clip = ctx.Attr("clip"); + + std::vector aspect_ratios; + expand_aspect_ratios(input_aspect_ratio, flip, aspect_ratios); + + auto img_w = ctx.Attr("img_w"); + auto img_h = ctx.Attr("img_h"); + auto step_w = ctx.Attr("step_w"); + auto step_h = ctx.Attr("step_h"); + auto offset = ctx.Attr("offset"); + + int img_width, img_height; + if (img_h == 0 || img_w == 0) { + img_width = image->dims()[2]; + img_height = image->dims()[3]; + } else { + img_width = img_w; + img_height = img_h; + } + + const int layer_width = input->dims()[2]; + const int layer_height = input->dims()[3]; + + float step_width, step_height; + if (step_w == 0 || step_h == 0) { + step_width = static_cast(img_width) / layer_width; + step_height = static_cast(img_height) / layer_height; + } else { + step_width = step_w; + step_height = step_h; + } + + int num_priors = aspect_ratios.size() * min_sizes.size(); + if (max_sizes.size() > 0) { + num_priors += max_sizes.size(); + } + + int dim = layer_height * layer_width * num_priors * 4; + + T* output_data = nullptr; + framework::Tensor output_cpu; + out->mutable_data(ctx.GetPlace()); + if (platform::is_gpu_place(ctx.GetPlace())) { + output_data = + output_cpu.mutable_data(out->dims(), platform::CPUPlace()); + } else { + output_data = out->mutable_data(ctx.GetPlace()); + } + + int idx = 0; + for (int h = 0; h < layer_height; ++h) { + for (int w = 0; w < layer_width; ++w) { + float center_x = (w + offset) * step_width; + float center_y = (h + offset) * step_height; + float box_width, box_height; + for (size_t s = 0; s < min_sizes.size(); ++s) { + int min_size = min_sizes[s]; + // first prior: aspect_ratio = 1, size = min_size + box_width = box_height = min_size; + // xmin + output_data[idx++] = (center_x - box_width / 2.) / img_width; + // ymin + output_data[idx++] = (center_y - box_height / 2.) / img_height; + // xmax + output_data[idx++] = (center_x + box_width / 2.) / img_width; + // ymax + output_data[idx++] = (center_y + box_height / 2.) / img_height; + + if (max_sizes.size() > 0) { + int max_size = max_sizes[s]; + // second prior: aspect_ratio = 1, + // size = sqrt(min_size * max_size) + box_width = box_height = sqrt(min_size * max_size); + // xmin + output_data[idx++] = (center_x - box_width / 2.) / img_width; + // ymin + output_data[idx++] = (center_y - box_height / 2.) / img_height; + // xmax + output_data[idx++] = (center_x + box_width / 2.) / img_width; + // ymax + output_data[idx++] = (center_y + box_height / 2.) / img_height; + } + + // rest of priors + for (size_t r = 0; r < aspect_ratios.size(); ++r) { + float ar = aspect_ratios[r]; + if (fabs(ar - 1.) < 1e-6) { + continue; + } + box_width = min_size * sqrt(ar); + box_height = min_size / sqrt(ar); + // xmin + output_data[idx++] = (center_x - box_width / 2.) / img_width; + // ymin + output_data[idx++] = (center_y - box_height / 2.) / img_height; + // xmax + output_data[idx++] = (center_x + box_width / 2.) / img_width; + // ymax + output_data[idx++] = (center_y + box_height / 2.) / img_height; + } + } + } + } + + // clip the prior's coordidate such that it is within [0, 1] + if (clip) { + for (int d = 0; d < dim; ++d) { + output_data[d] = std::min(std::max(output_data[d], 0.), 1.); + } + } + + // set the variance. + auto output_stride = framework::stride(out->dims()); + output_data += output_stride[1]; + if (variances.size() == 1) { + for (int i = 0; i < dim; ++i) { + output_data[i] = variances[0]; + } + } else { + int count = 0; + for (int h = 0; h < layer_height; ++h) { + for (int w = 0; w < layer_width; ++w) { + for (int i = 0; i < num_priors; ++i) { + for (int j = 0; j < 4; ++j) { + output_data[count] = variances[j]; + ++count; + } + } + } + } + } + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::CopyFrom(output_cpu, platform::CPUPlace(), + ctx.device_context(), out); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_prior_box_op.py b/python/paddle/v2/fluid/tests/test_prior_box_op.py new file mode 100644 index 0000000000000000000000000000000000000000..2f8218895299e2a0dceeb7cb2ad72a65d6629680 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_prior_box_op.py @@ -0,0 +1,179 @@ +import unittest +import numpy as np +import sys +import math +from op_test import OpTest + + +class TestPriorBoxOp(OpTest): + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = {'Input': self.input, 'Image': self.image} + + self.attrs = { + 'min_sizes': self.min_sizes, + 'max_sizes': self.max_sizes, + 'aspect_ratios': self.aspect_ratios, + 'variances': self.variances, + 'flip': self.flip, + 'clip': self.clip, + 'step_w': self.step_w, + 'step_h': self.step_h, + 'img_w': self.image_w, + 'img_h': self.image_h, + 'offset': self.offset + } + + self.outputs = {'Out': self.output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + return + + def setUp(self): + self.op_type = "prior_box" + self.set_data() + + def init_test_params(self): + self.layer_w = 4 + self.layer_h = 4 + + self.image_w = 20 + self.image_h = 20 + + self.step_w = float(self.image_w) / float(self.layer_w) + self.step_h = float(self.image_h) / float(self.layer_h) + + self.input_channels = 2 + self.image_channels = 3 + self.batch_size = 10 + + self.min_sizes = [2, 4] + self.min_sizes = np.array(self.min_sizes).astype('int64') + self.max_sizes = [5, 10] + self.max_sizes = np.array(self.max_sizes).astype('int64') + self.aspect_ratios = [2.0, 3.0] + self.flip = True + self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] + self.aspect_ratios = np.array( + self.aspect_ratios, dtype=np.float).flatten() + self.variances = [0.1, 0.1, 0.2, 0.2] + self.variances = np.array(self.variances, dtype=np.float).flatten() + + self.clip = True + + self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) + if len(self.max_sizes) > 1: + self.num_priors += len(self.max_sizes) + self.offset = 0.5 + + def init_test_input(self): + self.image = np.random.random( + (self.batch_size, self.image_channels, self.image_w, + self.image_h)).astype('float32') + + self.input = np.random.random( + (self.batch_size, self.input_channels, self.layer_w, + self.layer_h)).astype('float32') + + def init_test_output(self): + dim = self.layer_w * self.layer_h * self.num_priors * 4 + out_dim = (1, 2, dim) + output = np.zeros(out_dim).astype('float32') + + idx = 0 + for h in range(self.layer_h): + for w in range(self.layer_w): + center_x = (w + self.offset) * self.step_w + center_y = (h + self.offset) * self.step_h + for s in range(len(self.min_sizes)): + min_size = self.min_sizes[s] + # first prior: aspect_ratio = 1, size = min_size + box_width = box_height = min_size + # xmin + output[0, 0, idx] = ( + center_x - box_width / 2.) / self.image_w + idx += 1 + # ymin + output[0, 0, idx] = ( + center_y - box_height / 2.) / self.image_h + idx += 1 + # xmax + output[0, 0, idx] = ( + center_x + box_width / 2.) / self.image_w + idx += 1 + # ymax + output[0, 0, idx] = ( + center_y + box_height / 2.) / self.image_h + idx += 1 + + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + # size = sqrt(min_size * max_size) + box_width = box_height = math.sqrt(min_size * max_size) + # xmin + output[0, 0, idx] = ( + center_x - box_width / 2.) / self.image_w + idx += 1 + # ymin + output[0, 0, idx] = ( + center_y - box_height / 2.) / self.image_h + idx += 1 + # xmax + output[0, 0, idx] = ( + center_x + box_width / 2.) / self.image_w + idx += 1 + # ymax + output[0, 0, idx] = ( + center_y + box_height / 2.) / self.image_h + idx += 1 + + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + if math.fabs(ar - 1.) < 1e-6: + continue + box_width = min_size * math.sqrt(ar) + box_height = min_size / math.sqrt(ar) + # xmin + output[0, 0, idx] = ( + center_x - box_width / 2.) / self.image_w + idx += 1 + # ymin + output[0, 0, idx] = ( + center_y - box_height / 2.) / self.image_h + idx += 1 + # xmax + output[0, 0, idx] = ( + center_x + box_width / 2.) / self.image_w + idx += 1 + # ymax + output[0, 0, idx] = ( + center_y + box_height / 2.) / self.image_h + idx += 1 + # clip the prior's coordidate such that it is within[0, 1] + if self.clip: + for d in range(dim): + output[0, 0, d] = min(max(output[0, 0, d], 0), 1) + # set the variance. + if len(self.variances) == 1: + for i in range(dim): + output[0, 1, i] = self.variances[0] + else: + count = 0 + for h in range(self.layer_h): + for w in range(self.layer_w): + for i in range(self.num_priors): + for j in range(4): + output[0, 1, count] = self.variances[j] + count += 1 + self.output = output.astype('float32') + + +if __name__ == '__main__': + unittest.main()