From 4a55fb5f5b8d177a61133afe7210561f796a7e32 Mon Sep 17 00:00:00 2001 From: ruri Date: Tue, 13 Nov 2018 11:42:26 +0800 Subject: [PATCH] Add density_prior_box_op (#14226) Density prior box operator for image detection model. --- paddle/fluid/API.spec | 1 + .../fluid/operators/detection/CMakeLists.txt | 1 + .../detection/density_prior_box_op.cc | 175 ++++++++++++++++++ .../detection/density_prior_box_op.h | 146 +++++++++++++++ python/paddle/fluid/layers/detection.py | 130 +++++++++++++ python/paddle/fluid/tests/test_detection.py | 18 ++ .../unittests/test_density_prior_box_op.py | 142 ++++++++++++++ 7 files changed, 613 insertions(+) create mode 100644 paddle/fluid/operators/detection/density_prior_box_op.cc create mode 100644 paddle/fluid/operators/detection/density_prior_box_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_density_prior_box_op.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index de32a5d5a..1bd4376f9 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -274,6 +274,7 @@ paddle.fluid.layers.hard_shrink ArgSpec(args=['x', 'threshold'], varargs=None, k paddle.fluid.layers.cumsum ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.thresholded_relu ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prior_box ArgSpec(args=['input', 'image', 'min_sizes', 'max_sizes', 'aspect_ratios', 'variance', 'flip', 'clip', 'steps', 'offset', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, [1.0], [0.1, 0.1, 0.2, 0.2], False, False, [0.0, 0.0], 0.5, None, False)) +paddle.fluid.layers.density_prior_box ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, None)) paddle.fluid.layers.multi_box_head ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False)) paddle.fluid.layers.bipartite_match ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index d5eec148f..e5c3f0eeb 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -22,6 +22,7 @@ iou_similarity_op.cu) detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc) detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc poly_util.cc gpc.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) +detection_library(density_prior_box_op SRCS density_prior_box_op.cc) detection_library(anchor_generator_op SRCS anchor_generator_op.cc anchor_generator_op.cu) detection_library(target_assign_op SRCS target_assign_op.cc diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cc b/paddle/fluid/operators/detection/density_prior_box_op.cc new file mode 100644 index 000000000..99df15c32 --- /dev/null +++ b/paddle/fluid/operators/detection/density_prior_box_op.cc @@ -0,0 +1,175 @@ +/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/density_prior_box_op.h" + +namespace paddle { +namespace operators { + +class DensityPriorBoxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of DensityPriorBoxOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Image"), + "Input(Image) of DensityPriorBoxOp should not be null."); + + auto image_dims = ctx->GetInputDim("Image"); + auto input_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW."); + PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + + PADDLE_ENFORCE_LT(input_dims[2], image_dims[2], + "The height of input must smaller than image."); + + PADDLE_ENFORCE_LT(input_dims[3], image_dims[3], + "The width of input must smaller than image."); + auto variances = ctx->Attrs().Get>("variances"); + + auto fixed_sizes = ctx->Attrs().Get>("fixed_sizes"); + auto fixed_ratios = ctx->Attrs().Get>("fixed_ratios"); + auto densities = ctx->Attrs().Get>("densities"); + + PADDLE_ENFORCE_EQ(fixed_sizes.size(), densities.size(), + "The number of fixed_sizes and densities must be equal."); + size_t num_priors = 0; + if ((fixed_sizes.size() > 0) && (densities.size() > 0)) { + for (size_t i = 0; i < densities.size(); ++i) { + if (fixed_ratios.size() > 0) { + num_priors += (fixed_ratios.size()) * (pow(densities[i], 2)); + } + } + } + std::vector dim_vec(4); + dim_vec[0] = input_dims[2]; + dim_vec[1] = input_dims[3]; + dim_vec[2] = num_priors; + dim_vec[3] = 4; + ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec)); + ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + platform::CPUPlace()); + } +}; + +class DensityPriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "Input", + "(Tensor, default Tensor), " + "the input feature data of DensityPriorBoxOp, the layout is NCHW."); + AddInput("Image", + "(Tensor, default Tensor), " + "the input image data of DensityPriorBoxOp, the layout is NCHW."); + AddOutput("Boxes", + "(Tensor, default Tensor), the output prior boxes of " + "DensityPriorBoxOp. The layout is [H, W, num_priors, 4]. " + "H is the height of input, W is the width of input, num_priors " + "is the box count of each position."); + AddOutput("Variances", + "(Tensor, default Tensor), the expanded variances of " + "DensityPriorBoxOp. The layout is [H, W, num_priors, 4]. " + "H is the height of input, W is the width of input, num_priors " + "is the box count of each position."); + AddAttr>("variances", + "(vector) List of variances to be " + "encoded in density prior boxes.") + .AddCustomChecker([](const std::vector& variances) { + PADDLE_ENFORCE_EQ(variances.size(), 4, + "Must and only provide 4 variance."); + for (size_t i = 0; i < variances.size(); ++i) { + PADDLE_ENFORCE_GT(variances[i], 0.0, + "variance[%d] must be greater than 0.", i); + } + }); + AddAttr("clip", "(bool) Whether to clip out-of-boundary boxes.") + .SetDefault(true); + + AddAttr( + "step_w", + "Density prior boxes step across width, 0.0 for auto calculation.") + .SetDefault(0.0) + .AddCustomChecker([](const float& step_w) { + PADDLE_ENFORCE_GE(step_w, 0.0, "step_w should be larger than 0."); + }); + AddAttr( + "step_h", + "Density prior boxes step across height, 0.0 for auto calculation.") + .SetDefault(0.0) + .AddCustomChecker([](const float& step_h) { + PADDLE_ENFORCE_GE(step_h, 0.0, "step_h should be larger than 0."); + }); + + AddAttr("offset", + "(float) " + "Density prior boxes center offset.") + .SetDefault(0.5); + AddAttr>("fixed_sizes", + "(vector) List of fixed sizes " + "of generated density prior boxes.") + .SetDefault(std::vector{}) + .AddCustomChecker([](const std::vector& fixed_sizes) { + for (size_t i = 0; i < fixed_sizes.size(); ++i) { + PADDLE_ENFORCE_GT(fixed_sizes[i], 0.0, + "fixed_sizes[%d] should be larger than 0.", i); + } + }); + + AddAttr>("fixed_ratios", + "(vector) List of fixed ratios " + "of generated density prior boxes.") + .SetDefault(std::vector{}) + .AddCustomChecker([](const std::vector& fixed_ratios) { + for (size_t i = 0; i < fixed_ratios.size(); ++i) { + PADDLE_ENFORCE_GT(fixed_ratios[i], 0.0, + "fixed_ratios[%d] should be larger than 0.", i); + } + }); + + AddAttr>("densities", + "(vector) List of densities " + "of generated density prior boxes.") + .SetDefault(std::vector{}) + .AddCustomChecker([](const std::vector& densities) { + for (size_t i = 0; i < densities.size(); ++i) { + PADDLE_ENFORCE_GT(densities[i], 0, + "densities[%d] should be larger than 0.", i); + } + }); + AddComment(R"DOC( + Density Prior box operator + Each position of the input produce N density prior boxes, N is determined by + the count of fixed_ratios, densities, the calculation of N is as follows: + for density in densities: + N += size(fixed_ratios)*density^2 + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(density_prior_box, ops::DensityPriorBoxOp, + ops::DensityPriorBoxOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(density_prior_box, ops::DensityPriorBoxOpKernel, + ops::DensityPriorBoxOpKernel); diff --git a/paddle/fluid/operators/detection/density_prior_box_op.h b/paddle/fluid/operators/detection/density_prior_box_op.h new file mode 100644 index 000000000..9a52077e9 --- /dev/null +++ b/paddle/fluid/operators/detection/density_prior_box_op.h @@ -0,0 +1,146 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/operators/detection/prior_box_op.h" + +namespace paddle { +namespace operators { + +template +class DensityPriorBoxOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* image = ctx.Input("Image"); + auto* boxes = ctx.Output("Boxes"); + auto* vars = ctx.Output("Variances"); + + auto variances = ctx.Attr>("variances"); + auto clip = ctx.Attr("clip"); + + auto fixed_sizes = ctx.Attr>("fixed_sizes"); + auto fixed_ratios = ctx.Attr>("fixed_ratios"); + auto densities = ctx.Attr>("densities"); + + T step_w = static_cast(ctx.Attr("step_w")); + T step_h = static_cast(ctx.Attr("step_h")); + T offset = static_cast(ctx.Attr("offset")); + + auto img_width = image->dims()[3]; + auto img_height = image->dims()[2]; + + auto feature_width = input->dims()[3]; + auto feature_height = input->dims()[2]; + + T step_width, step_height; + if (step_w == 0 || step_h == 0) { + step_width = static_cast(img_width) / feature_width; + step_height = static_cast(img_height) / feature_height; + } else { + step_width = step_w; + step_height = step_h; + } + int num_priors = 0; + if (fixed_sizes.size() > 0 && densities.size() > 0) { + for (size_t i = 0; i < densities.size(); ++i) { + if (fixed_ratios.size() > 0) { + num_priors += (fixed_ratios.size()) * (pow(densities[i], 2)); + } + } + } + + boxes->mutable_data(ctx.GetPlace()); + vars->mutable_data(ctx.GetPlace()); + auto e_boxes = framework::EigenTensor::From(*boxes).setConstant(0.0); + + int step_average = static_cast((step_width + step_height) * 0.5); + + for (int h = 0; h < feature_height; ++h) { + for (int w = 0; w < feature_width; ++w) { + T center_x = (w + offset) * step_width; + T center_y = (h + offset) * step_height; + int idx = 0; + // Generate density prior boxes with fixed sizes. + for (size_t s = 0; s < fixed_sizes.size(); ++s) { + auto fixed_size = fixed_sizes[s]; + int density = densities[s]; + // Generate density prior boxes with fixed ratios. + if (fixed_ratios.size() > 0) { + for (size_t r = 0; r < fixed_ratios.size(); ++r) { + float ar = fixed_ratios[r]; + int shift = step_average / density; + float box_width_ratio = fixed_size * sqrt(ar); + float box_height_ratio = fixed_size / sqrt(ar); + for (int di = 0; di < density; ++di) { + for (int dj = 0; dj < density; ++dj) { + float center_x_temp = + center_x - step_average / 2. + shift / 2. + dj * shift; + float center_y_temp = + center_y - step_average / 2. + shift / 2. + di * shift; + e_boxes(h, w, idx, 0) = + (center_x_temp - box_width_ratio / 2.) / img_width >= 0 + ? (center_x_temp - box_width_ratio / 2.) / img_width + : 0; + e_boxes(h, w, idx, 1) = + (center_y_temp - box_height_ratio / 2.) / img_height >= 0 + ? (center_y_temp - box_height_ratio / 2.) / img_height + : 0; + e_boxes(h, w, idx, 2) = + (center_x_temp + box_width_ratio / 2.) / img_width <= 1 + ? (center_x_temp + box_width_ratio / 2.) / img_width + : 1; + e_boxes(h, w, idx, 3) = + (center_y_temp + box_height_ratio / 2.) / img_height <= 1 + ? (center_y_temp + box_height_ratio / 2.) / img_height + : 1; + idx++; + } + } + } + } + } + } + } + if (clip) { + platform::Transform trans; + ClipFunctor clip_func; + trans(ctx.template device_context(), + boxes->data(), boxes->data() + boxes->numel(), + boxes->data(), clip_func); + } + framework::Tensor var_t; + var_t.mutable_data( + framework::make_ddim({1, static_cast(variances.size())}), + ctx.GetPlace()); + + auto var_et = framework::EigenTensor::From(var_t); + + for (size_t i = 0; i < variances.size(); ++i) { + var_et(0, i) = variances[i]; + } + + int box_num = feature_height * feature_width * num_priors; + auto var_dim = vars->dims(); + vars->Resize({box_num, static_cast(variances.size())}); + + auto e_vars = framework::EigenMatrix::From(*vars); + + e_vars = var_et.broadcast(Eigen::DSizes(box_num, 1)); + + vars->Resize(var_dim); + } +}; // namespace operators + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 4ac94981a..96b6705e2 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -31,6 +31,7 @@ from functools import reduce __all__ = [ 'prior_box', + 'density_prior_box', 'multi_box_head', 'bipartite_match', 'target_assign', @@ -1023,6 +1024,135 @@ def prior_box(input, return box, var +def density_prior_box(input, + image, + densities=None, + fixed_sizes=None, + fixed_ratios=None, + variance=[0.1, 0.1, 0.2, 0.2], + clip=False, + steps=[0.0, 0.0], + offset=0.5, + name=None): + """ + **Density Prior Box Operator** + + Generate density prior boxes for SSD(Single Shot MultiBox Detector) + algorithm. Each position of the input produce N prior boxes, N is + determined by the count of densities, fixed_sizes and fixed_ratios. + Boxes center at grid points around each input position is generated by + this operator, and the grid points is determined by densities and + the count of density prior box is determined by fixed_sizes and fixed_ratios. + Obviously, the number of fixed_sizes is equal to the number of densities. + For densities_i in densities: + N_density_prior_box =sum(N_fixed_ratios * densities_i^2), + + Args: + input(Variable): The Input Variables, the format is NCHW. + image(Variable): The input image data of PriorBoxOp, + the layout is NCHW. + densities(list|tuple|None): the densities of generated density prior + boxes, this attribute should be a list or tuple of integers. + Default: None. + fixed_sizes(list|tuple|None): the fixed sizes of generated density + prior boxes, this attribute should a list or tuple of same + length with :attr:`densities`. Default: None. + fixed_ratios(list|tuple|None): the fixed ratios of generated density + prior boxes, if this attribute is not set and :attr:`densities` + and :attr:`fix_sizes` is set, :attr:`aspect_ratios` will be used + to generate density prior boxes. + variance(list|tuple): the variances to be encoded in density prior boxes. + Default:[0.1, 0.1, 0.2, 0.2]. + clip(bool): Whether to clip out-of-boundary boxes. Default: False. + step(list|turple): Prior boxes step across width and height, If + step[0] == 0.0/step[1] == 0.0, the density prior boxes step across + height/weight of the input will be automatically calculated. + Default: [0., 0.] + offset(float): Prior boxes center offset. Default: 0.5 + name(str): Name of the density prior box op. Default: None. + + Returns: + tuple: A tuple with two Variable (boxes, variances) + + boxes: the output density prior boxes of PriorBox. + The layout is [H, W, num_priors, 4]. + H is the height of input, W is the width of input, + num_priors is the total + box count of each position of input. + + variances: the expanded variances of PriorBox. + The layout is [H, W, num_priors, 4]. + H is the height of input, W is the width of input + num_priors is the total + box count of each position of input + + + Examples: + .. code-block:: python + + box, var = fluid.layers.density_prior_box( + input=conv1, + image=images, + min_sizes=[100.], + max_sizes=[200.], + aspect_ratios=[1.0, 1.0 / 2.0, 2.0], + densities=[3, 4], + fixed_sizes=[50., 60.], + fixed_ratios=[1.0, 3.0, 1.0 / 3.0], + flip=True, + clip=True) + """ + helper = LayerHelper("density_prior_box", **locals()) + dtype = helper.input_dtype() + + def _is_list_or_tuple_(data): + return (isinstance(data, list) or isinstance(data, tuple)) + + if not _is_list_or_tuple_(densities): + raise TypeError('densities should be a list or a tuple or None.') + if not _is_list_or_tuple_(fixed_sizes): + raise TypeError('fixed_sizes should be a list or a tuple or None.') + if not _is_list_or_tuple_(fixed_ratios): + raise TypeError('fixed_ratios should be a list or a tuple or None.') + if len(densities) != len(fixed_sizes): + raise ValueError('densities and fixed_sizes length should be euqal.') + if not (_is_list_or_tuple_(steps) and len(steps) == 2): + raise ValueError('steps should be a list or tuple ', + 'with length 2, (step_width, step_height).') + + densities = list(map(int, densities)) + fixed_sizes = list(map(float, fixed_sizes)) + fixed_ratios = list(map(float, fixed_ratios)) + steps = list(map(float, steps)) + + attrs = { + 'variances': variance, + 'clip': clip, + 'step_w': steps[0], + 'step_h': steps[1], + 'offset': offset, + } + if densities is not None and len(densities) > 0: + attrs['densities'] = densities + if fixed_sizes is not None and len(fixed_sizes) > 0: + attrs['fixed_sizes'] = fixed_sizes + if fixed_ratios is not None and len(fixed_ratios) > 0: + attrs['fixed_ratios'] = fixed_ratios + + box = helper.create_variable_for_type_inference(dtype) + var = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="density_prior_box", + inputs={"Input": input, + "Image": image}, + outputs={"Boxes": box, + "Variances": var}, + attrs=attrs, ) + box.stop_gradient = True + var.stop_gradient = True + return box, var + + def multi_box_head(inputs, image, base_size, diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 28dc75195..982d29180 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -128,6 +128,24 @@ class TestPriorBox(unittest.TestCase): assert box.shape[3] == 4 +class TestDensityPriorBox(unittest.TestCase): + def test_density_prior_box(self): + data_shape = [3, 224, 224] + images = fluid.layers.data( + name='pixel', shape=data_shape, dtype='float32') + conv1 = fluid.layers.conv2d(images, 3, 3, 2) + box, var = layers.density_prior_box( + input=conv1, + image=images, + densities=[3, 4], + fixed_sizes=[50., 60.], + fixed_ratios=[1.0], + clip=True) + assert len(box.shape) == 4 + assert box.shape == var.shape + assert box.shape[3] == 4 + + class TestAnchorGenerator(unittest.TestCase): def test_anchor_generator(self): data_shape = [3, 224, 224] diff --git a/python/paddle/fluid/tests/unittests/test_density_prior_box_op.py b/python/paddle/fluid/tests/unittests/test_density_prior_box_op.py new file mode 100644 index 000000000..79d1fd3d7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_density_prior_box_op.py @@ -0,0 +1,142 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +import math +from op_test import OpTest + + +class TestDensityPriorBoxOp(OpTest): + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = {'Input': self.input, 'Image': self.image} + + self.attrs = { + 'variances': self.variances, + 'clip': self.clip, + 'step_w': self.step_w, + 'step_h': self.step_h, + 'offset': self.offset, + 'densities': self.densities, + 'fixed_sizes': self.fixed_sizes, + 'fixed_ratios': self.fixed_ratios + } + self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var} + + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "density_prior_box" + self.set_data() + + def set_density(self): + self.densities = [] + self.fixed_sizes = [] + self.fixed_ratios = [] + + def init_test_params(self): + self.layer_w = 32 + self.layer_h = 32 + + self.image_w = 40 + self.image_h = 40 + + self.step_w = float(self.image_w) / float(self.layer_w) + self.step_h = float(self.image_h) / float(self.layer_h) + + self.input_channels = 2 + self.image_channels = 3 + self.batch_size = 10 + + self.variances = [0.1, 0.1, 0.2, 0.2] + self.variances = np.array(self.variances, dtype=np.float).flatten() + + self.set_density() + + self.clip = True + self.num_priors = 0 + if len(self.fixed_sizes) > 0 and len(self.densities) > 0: + for density in self.densities: + if len(self.fixed_ratios) > 0: + self.num_priors += len(self.fixed_ratios) * (pow(density, + 2)) + self.offset = 0.5 + + def init_test_input(self): + self.image = np.random.random( + (self.batch_size, self.image_channels, self.image_w, + self.image_h)).astype('float32') + + self.input = np.random.random( + (self.batch_size, self.input_channels, self.layer_w, + self.layer_h)).astype('float32') + + def init_test_output(self): + out_dim = (self.layer_h, self.layer_w, self.num_priors, 4) + out_boxes = np.zeros(out_dim).astype('float32') + out_var = np.zeros(out_dim).astype('float32') + + step_average = int((self.step_w + self.step_h) * 0.5) + for h in range(self.layer_h): + for w in range(self.layer_w): + idx = 0 + c_x = (w + self.offset) * self.step_w + c_y = (h + self.offset) * self.step_h + # Generate density prior boxes with fixed size + for density, fixed_size in zip(self.densities, + self.fixed_sizes): + if (len(self.fixed_ratios) > 0): + for ar in self.fixed_ratios: + shift = int(step_average / density) + box_width_ratio = fixed_size * math.sqrt(ar) + box_height_ratio = fixed_size / math.sqrt(ar) + for di in range(density): + for dj in range(density): + c_x_temp = c_x - step_average / 2.0 + shift / 2.0 + dj * shift + c_y_temp = c_y - step_average / 2.0 + shift / 2.0 + di * shift + out_boxes[h, w, idx, :] = [ + max((c_x_temp - box_width_ratio / 2.0) / + self.image_w, 0), + max((c_y_temp - box_height_ratio / 2.0) + / self.image_h, 0), + min((c_x_temp + box_width_ratio / 2.0) / + self.image_w, 1), + min((c_y_temp + box_height_ratio / 2.0) + / self.image_h, 1) + ] + idx += 1 + if self.clip: + out_boxes = np.clip(out_boxes, 0.0, 1.0) + out_var = np.tile(self.variances, + (self.layer_h, self.layer_w, self.num_priors, 1)) + self.out_boxes = out_boxes.astype('float32') + self.out_var = out_var.astype('float32') + + +class TestDensityPriorBox(TestDensityPriorBoxOp): + def set_density(self): + self.densities = [3, 4] + self.fixed_sizes = [1.0, 2.0] + self.fixed_ratios = [1.0] + + +if __name__ == '__main__': + unittest.main() -- GitLab