diff --git a/paddle/operators/prior_box_op.cc b/paddle/operators/prior_box_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..105ff4ac3e3ba889aad880f4204af15829c6da47
--- /dev/null
+++ b/paddle/operators/prior_box_op.cc
@@ -0,0 +1,166 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/prior_box_op.h"
+
+namespace paddle {
+namespace operators {
+
+// Operator definition for prior_box: validates the inputs and attributes and
+// infers the shape of the Boxes and Variances outputs.
+class PriorBoxOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  // Checks Input/Image and every attribute, then sets both outputs to
+  // [H, W, num_priors, 4], where H and W are taken from Input.
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"),
+                   "Input(Input) of PriorBoxOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Image"),
+                   "Input(Image) of PriorBoxOp should not be null.");
+
+    auto image_dims = ctx->GetInputDim("Image");
+    auto input_dims = ctx->GetInputDim("Input");
+    PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW.");
+    PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
+
+    PADDLE_ENFORCE_LT(input_dims[2], image_dims[2],
+                      "The height of input must smaller than image.");
+
+    PADDLE_ENFORCE_LT(input_dims[3], image_dims[3],
+                      "The width of input must smaller than image.");
+
+    auto min_sizes = ctx->Attrs().Get<std::vector<int>>("min_sizes");
+    auto max_sizes = ctx->Attrs().Get<std::vector<int>>("max_sizes");
+    auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
+    auto aspect_ratios = ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
+    bool flip = ctx->Attrs().Get<bool>("flip");
+
+    PADDLE_ENFORCE_GT(min_sizes.size(), 0,
+                      "Size of min_sizes must be at least 1.");
+    for (size_t i = 0; i < min_sizes.size(); ++i) {
+      PADDLE_ENFORCE_GT(min_sizes[i], 0, "min_sizes[%d] must be positive.", i);
+    }
+
+    // Expand exactly as the kernel does so num_priors matches its output.
+    std::vector<float> aspect_ratios_vec;
+    ExpandAspectRatios(aspect_ratios, flip, aspect_ratios_vec);
+
+    int num_priors = aspect_ratios_vec.size() * min_sizes.size();
+    // Every (min_size, max_size) pair contributes one extra square prior.
+    if (max_sizes.size() > 0) {
+      PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(),
+                        "The number of min_size and max_size must be equal.");
+      for (size_t i = 0; i < min_sizes.size(); ++i) {
+        PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i],
+                          "max_size[%d] must be greater than min_size[%d].", i,
+                          i);
+        num_priors += 1;
+      }
+    }
+
+    PADDLE_ENFORCE_EQ(variances.size(), 4, "Must and only provide 4 variance.");
+    for (size_t i = 0; i < variances.size(); ++i) {
+      PADDLE_ENFORCE_GT(variances[i], 0.0,
+                        "variance[%d] must be greater than 0.", i);
+    }
+
+    // 0 is the declared default ("0 for auto calculation"), so it must pass
+    // validation; only negative steps are rejected.
+    const float step_h = ctx->Attrs().Get<float>("step_h");
+    PADDLE_ENFORCE_GE(step_h, 0.0, "step_h should not be negative.");
+    const float step_w = ctx->Attrs().Get<float>("step_w");
+    PADDLE_ENFORCE_GE(step_w, 0.0, "step_w should not be negative.");
+
+    std::vector<int64_t> dim_vec(4);
+    dim_vec[0] = input_dims[2];
+    dim_vec[1] = input_dims[3];
+    dim_vec[2] = num_priors;
+    dim_vec[3] = 4;
+    ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
+    ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
+  }
+};
+
+// Declares the inputs, outputs and attributes of the prior_box operator.
+class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  PriorBoxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Input",
+             "(Tensor, default Tensor<float>), "
+             "the input feature data of PriorBoxOp, The layout is NCHW.");
+    AddInput("Image",
+             "(Tensor, default Tensor<float>), "
+             "the input image data of PriorBoxOp, The layout is NCHW.");
+    AddOutput("Boxes",
+              "(Tensor, default Tensor<float>), the output prior boxes of "
+              "PriorBoxOp. The layout is [H, W, num_priors, 4]. "
+              "H is the height of input, W is the width of input, num_priors "
+              "is the box count of each position.");
+    AddOutput("Variances",
+              "(Tensor, default Tensor<float>), the expanded variances of "
+              "PriorBoxOp. The layout is [H, W, num_priors, 4]. "
+              "H is the height of input, W is the width of input, num_priors "
+              "is the box count of each position.");
+    // Each attribute comment is one string of the form "(type) description".
+    AddAttr<std::vector<int>>(
+        "min_sizes",
+        "(vector<int>) List of min sizes of generated prior boxes.");
+    AddAttr<std::vector<int>>(
+        "max_sizes",
+        "(vector<int>) List of max sizes of generated prior boxes.");
+    AddAttr<std::vector<float>>(
+        "aspect_ratios",
+        "(vector<float>) List of aspect ratios of generated prior boxes.");
+    AddAttr<std::vector<float>>(
+        "variances",
+        "(vector<float>) List of variances to be encoded in prior boxes.");
+    AddAttr<bool>("flip", "(bool) Whether to flip aspect ratios.")
+        .SetDefault(true);
+    AddAttr<bool>("clip", "(bool) Whether to clip out-of-boundary boxes.")
+        .SetDefault(true);
+    AddAttr<float>("step_w",
+                   "Prior boxes step across width, 0 for auto calculation.")
+        .SetDefault(0.0);
+    AddAttr<float>("step_h",
+                   "Prior boxes step across height, 0 for auto calculation.")
+        .SetDefault(0.0);
+    AddAttr<float>("offset",
+                   "(float) "
+                   "Prior boxes center offset.")
+        .SetDefault(0.5);
+    AddComment(R"DOC(
+Prior box operator
+Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
+Each position of the input produce N prior boxes, N is determined by
+ the count of min_sizes, max_sizes and aspect_ratios, The size of the
+ box is in range(min_size, max_size) interval, which is generated in
+ sequence according to the aspect_ratios.
+
+Please get more information from the following papers:
+https://arxiv.org/abs/1512.02325.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    prior_box, ops::PriorBoxOpKernel<paddle::platform::CPUPlace, float>,
+    ops::PriorBoxOpKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/operators/prior_box_op.h b/paddle/operators/prior_box_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..e0a663ace8f38c2d08fd4714c1247d3313ffae3e
--- /dev/null
+++ b/paddle/operators/prior_box_op.h
@@ -0,0 +1,203 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/platform/transform.h"
+
+namespace paddle {
+namespace operators {
+
+// Builds the de-duplicated aspect-ratio list shared by InferShape and the
+// kernel. The output always starts with 1; every input ratio not already
+// present (within 1e-6) is appended, followed by its reciprocal when `flip`
+// is set. Any previous content of the output vector is discarded.
+inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
+                               bool flip,
+                               std::vector<float>& output_aspect_ratior) {
+  constexpr float epsilon = 1e-6;
+  output_aspect_ratior.clear();
+  output_aspect_ratior.push_back(1.);
+  for (size_t i = 0; i < input_aspect_ratior.size(); ++i) {
+    float ar = input_aspect_ratior[i];
+    bool already_exist = false;
+    for (size_t j = 0; j < output_aspect_ratior.size(); ++j) {
+      if (fabs(ar - output_aspect_ratior[j]) < epsilon) {
+        already_exist = true;
+        break;
+      }
+    }
+    if (!already_exist) {
+      output_aspect_ratior.push_back(ar);
+      if (flip) {
+        output_aspect_ratior.push_back(1. / ar);
+      }
+    }
+  }
+}
+
+// Clamps a value to the closed interval [0, 1].
+template <typename T>
+struct ClipFunctor {
+  HOSTDEVICE T operator()(T in) const {
+    return std::min<T>(std::max<T>(in, 0.), 1.);
+  }
+};
+
+// CPU kernel of prior_box. For every cell of the Input feature map it writes
+// num_priors boxes (xmin, ymin, xmax, ymax), normalized by the Image size, to
+// Boxes, and fills Variances with the `variances` attribute for every box.
+template <typename Place, typename T>
+class PriorBoxOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<paddle::framework::Tensor>("Input");
+    auto* image = ctx.Input<paddle::framework::Tensor>("Image");
+    auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
+    auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
+
+    auto min_sizes = ctx.Attr<std::vector<int>>("min_sizes");
+    auto max_sizes = ctx.Attr<std::vector<int>>("max_sizes");
+    auto input_aspect_ratio = ctx.Attr<std::vector<float>>("aspect_ratios");
+    auto variances = ctx.Attr<std::vector<float>>("variances");
+    auto flip = ctx.Attr<bool>("flip");
+    auto clip = ctx.Attr<bool>("clip");
+
+    std::vector<float> aspect_ratios;
+    ExpandAspectRatios(input_aspect_ratio, flip, aspect_ratios);
+
+    T step_w = static_cast<T>(ctx.Attr<float>("step_w"));
+    T step_h = static_cast<T>(ctx.Attr<float>("step_h"));
+    T offset = static_cast<T>(ctx.Attr<float>("offset"));
+
+    auto img_width = image->dims()[3];
+    auto img_height = image->dims()[2];
+
+    auto feature_width = input->dims()[3];
+    auto feature_height = input->dims()[2];
+
+    // If either step is 0, derive both from the image / feature-map size ratio.
+    T step_width, step_height;
+    if (step_w == 0 || step_h == 0) {
+      step_width = static_cast<T>(img_width) / feature_width;
+      step_height = static_cast<T>(img_height) / feature_height;
+    } else {
+      step_width = step_w;
+      step_height = step_h;
+    }
+
+    int num_priors = aspect_ratios.size() * min_sizes.size();
+    if (max_sizes.size() > 0) {
+      num_priors += max_sizes.size();
+    }
+
+    boxes->mutable_data<T>(ctx.GetPlace());
+    vars->mutable_data<T>(ctx.GetPlace());
+
+    auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes);
+    for (int h = 0; h < feature_height; ++h) {
+      for (int w = 0; w < feature_width; ++w) {
+        T center_x = (w + offset) * step_width;
+        T center_y = (h + offset) * step_height;
+        T box_width, box_height;
+        int idx = 0;
+        for (size_t s = 0; s < min_sizes.size(); ++s) {
+          int min_size = min_sizes[s];
+          // first prior: aspect_ratio = 1, size = min_size
+          box_width = box_height = min_size;
+          // xmin
+          e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
+          // ymin
+          e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
+          // xmax
+          e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
+          // ymax
+          e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
+
+          idx++;
+          if (max_sizes.size() > 0) {
+            int max_size = max_sizes[s];
+            // second prior: aspect_ratio = 1,
+            // size = sqrt(min_size * max_size)
+            box_width = box_height = sqrt(min_size * max_size);
+            // xmin
+            e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
+            // ymin
+            e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
+            // xmax
+            e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
+            // ymax
+            e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
+            idx++;
+          }
+
+          // rest of priors
+          for (size_t r = 0; r < aspect_ratios.size(); ++r) {
+            float ar = aspect_ratios[r];
+            if (fabs(ar - 1.) < 1e-6) {
+              continue;
+            }
+            box_width = min_size * sqrt(ar);
+            box_height = min_size / sqrt(ar);
+            // xmin
+            e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
+            // ymin
+            e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
+            // xmax
+            e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
+            // ymax
+            e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
+            idx++;
+          }
+        }
+      }
+    }
+
+    if (clip) {
+      platform::Transform<platform::CPUDeviceContext> trans;
+      ClipFunctor<T> clip_func;
+      trans(ctx.template device_context<platform::CPUDeviceContext>(),
+            boxes->data<T>(), boxes->data<T>() + boxes->numel(),
+            boxes->data<T>(), clip_func);
+    }
+
+    // Broadcast the variance vector so every prior box gets its own copy: view
+    // Variances as [box_num, variances.size()], fill it, then restore its shape.
+    framework::Tensor var_t;
+    var_t.mutable_data<T>(
+        framework::make_ddim({1, static_cast<int>(variances.size())}),
+        ctx.GetPlace());
+    auto var_et = framework::EigenTensor<T, 2>::From(var_t);
+    for (size_t i = 0; i < variances.size(); ++i) {
+      var_et(0, i) = variances[i];
+    }
+
+    int box_num = feature_height * feature_width * num_priors;
+    auto var_dim = vars->dims();
+    vars->Resize({box_num, static_cast<int>(variances.size())});
+
+    auto e_vars = framework::EigenMatrix<T, Eigen::RowMajor>::From(*vars);
+    e_vars = var_et.broadcast(Eigen::DSizes<int, 2>(box_num, 1));
+
+    vars->Resize(var_dim);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/fluid/tests/test_prior_box_op.py b/python/paddle/v2/fluid/tests/test_prior_box_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca8d2bca74ce2d4be8160c8851e393489691ae56
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_prior_box_op.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import sys
+import math
+from op_test import OpTest
+
+
+class TestPriorBoxOp(OpTest):
+    """Checks the prior_box op against a NumPy reference implementation."""
+
+    def set_data(self):
+        """Builds the inputs, attributes and expected outputs of the op."""
+        self.init_test_params()
+        self.init_test_input()
+        self.init_test_output()
+        self.inputs = {'Input': self.input, 'Image': self.image}
+
+        self.attrs = {
+            'min_sizes': self.min_sizes,
+            'max_sizes': self.max_sizes,
+            'aspect_ratios': self.aspect_ratios,
+            'variances': self.variances,
+            'flip': self.flip,
+            'clip': self.clip,
+            'step_w': self.step_w,
+            'step_h': self.step_h,
+            'offset': self.offset
+        }
+
+        self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        """Intentionally a no-op: no gradient check is run for prior_box."""
+
+    def setUp(self):
+        self.op_type = "prior_box"
+        self.set_data()
+
+    def init_test_params(self):
+        """Sets the feature-map/image sizes and the prior_box attributes."""
+        self.layer_w = 4
+        self.layer_h = 4
+
+        self.image_w = 20
+        self.image_h = 20
+
+        self.step_w = float(self.image_w) / float(self.layer_w)
+        self.step_h = float(self.image_h) / float(self.layer_h)
+
+        self.input_channels = 2
+        self.image_channels = 3
+        self.batch_size = 10
+
+        self.min_sizes = np.array([2, 4]).astype('int64')
+        self.max_sizes = np.array([5, 10]).astype('int64')
+        self.aspect_ratios = np.array([2.0, 3.0], dtype=np.float64)
+        self.flip = True
+        self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
+        self.variances = np.array([0.1, 0.1, 0.2, 0.2], dtype=np.float64)
+
+        self.clip = True
+
+        self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
+        if len(self.max_sizes) > 0:
+            self.num_priors += len(self.max_sizes)
+        self.offset = 0.5
+
+    def init_test_input(self):
+        """Creates random NCHW image and feature-map tensors."""
+        self.image = np.random.random(
+            (self.batch_size, self.image_channels, self.image_h,
+             self.image_w)).astype('float32')
+
+        self.input = np.random.random(
+            (self.batch_size, self.input_channels, self.layer_h,
+             self.layer_w)).astype('float32')
+
+    def init_test_output(self):
+        """Computes the expected Boxes and Variances with NumPy."""
+        out_dim = (self.layer_h, self.layer_w, self.num_priors, 4)
+        out_boxes = np.zeros(out_dim).astype('float32')
+
+        for h in range(self.layer_h):
+            for w in range(self.layer_w):
+                c_x = (w + self.offset) * self.step_w
+                c_y = (h + self.offset) * self.step_h
+                idx = 0
+                for s in range(len(self.min_sizes)):
+                    min_size = self.min_sizes[s]
+                    # first prior: aspect_ratio = 1, size = min_size
+                    c_w = c_h = min_size / 2.
+                    out_boxes[h, w, idx, :] = [
+                        (c_x - c_w) / self.image_w, (c_y - c_h) / self.image_h,
+                        (c_x + c_w) / self.image_w, (c_y + c_h) / self.image_h
+                    ]
+                    idx += 1
+
+                    if len(self.max_sizes) > 0:
+                        max_size = self.max_sizes[s]
+                        # second prior: aspect_ratio = 1,
+                        c_w = c_h = math.sqrt(min_size * max_size) / 2
+                        out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
+                                                   (c_y - c_h) / self.image_h,
+                                                   (c_x + c_w) / self.image_w,
+                                                   (c_y + c_h) / self.image_h]
+                        idx += 1
+
+                    # rest of priors
+                    for r in range(len(self.real_aspect_ratios)):
+                        ar = self.real_aspect_ratios[r]
+                        if math.fabs(ar - 1.) < 1e-6:
+                            continue
+                        c_w = min_size * math.sqrt(ar) / 2
+                        c_h = (min_size / math.sqrt(ar)) / 2
+                        out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
+                                                   (c_y - c_h) / self.image_h,
+                                                   (c_x + c_w) / self.image_w,
+                                                   (c_y + c_h) / self.image_h]
+                        idx += 1
+        # clip the prior's coordinate such that it is within [0, 1]
+        if self.clip:
+            out_boxes = np.clip(out_boxes, 0.0, 1.0)
+        # set the variance.
+        out_var = np.tile(self.variances, (self.layer_h, self.layer_w,
+                                           self.num_priors, 1))
+        self.out_boxes = out_boxes.astype('float32')
+        self.out_var = out_var.astype('float32')
+
+
+if __name__ == '__main__':
+    unittest.main()