From 5056d3ec56c2bf391ecd5a1110b900eb78e2d4c8 Mon Sep 17 00:00:00 2001 From: Xingyuan Bu Date: Mon, 2 Jul 2018 09:57:27 +0800 Subject: [PATCH] FasterRCNN Anchor Generator Op (#11218) * Add anchor generator operator for Faster-RCNN. * Add unittest testing. * Add Python API. --- .../fluid/operators/detection/CMakeLists.txt | 2 + .../detection/anchor_generator_op.cc | 154 ++++++++++++++++++ .../detection/anchor_generator_op.cu | 132 +++++++++++++++ .../operators/detection/anchor_generator_op.h | 109 +++++++++++++ python/paddle/fluid/layers/detection.py | 93 +++++++++++ python/paddle/fluid/tests/test_detection.py | 18 ++ .../unittests/test_anchor_generator_op.py | 110 +++++++++++++ 7 files changed, 618 insertions(+) create mode 100644 paddle/fluid/operators/detection/anchor_generator_op.cc create mode 100644 paddle/fluid/operators/detection/anchor_generator_op.cu create mode 100644 paddle/fluid/operators/detection/anchor_generator_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_anchor_generator_op.py diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 20d960f9fe..6d296ff7bf 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -22,6 +22,8 @@ iou_similarity_op.cu) detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc) detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) +detection_library(anchor_generator_op SRCS anchor_generator_op.cc +anchor_generator_op.cu) detection_library(target_assign_op SRCS target_assign_op.cc target_assign_op.cu) detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cc b/paddle/fluid/operators/detection/anchor_generator_op.cc new file mode 100644 index 0000000000..0c0155a0a9 --- /dev/null +++ b/paddle/fluid/operators/detection/anchor_generator_op.cc @@ -0,0 +1,154 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/anchor_generator_op.h" + +namespace paddle { +namespace operators { + +class AnchorGeneratorOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of AnchorGeneratorOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Anchors"), + "Output(Anchors) of AnchorGeneratorOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Variances"), + "Output(Variances) of AnchorGeneratorOp should not be null."); + + auto input_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + + auto anchor_sizes = ctx->Attrs().Get>("anchor_sizes"); + auto aspect_ratios = ctx->Attrs().Get>("aspect_ratios"); + auto stride = ctx->Attrs().Get>("stride"); + auto variances = ctx->Attrs().Get>("variances"); + + size_t num_anchors = aspect_ratios.size() * anchor_sizes.size(); + + std::vector dim_vec(4); + dim_vec[0] = input_dims[2]; + dim_vec[1] = input_dims[3]; + dim_vec[2] = num_anchors; + dim_vec[3] = 4; + ctx->SetOutputDim("Anchors", framework::make_ddim(dim_vec)); + ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + ctx.device_context()); + } +}; + +class AnchorGeneratorOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", + "(Tensor, default Tensor), " + "the input feature is a tensor with a rank of 4. " + "The layout is NCHW."); + AddOutput("Anchors", + "(Tensor, default Tensor), the output is a " + "tensor with a rank of 4. The layout is [H, W, num_anchors, 4]. " + "H is the height of input, W is the width of input, num_anchors " + "is the box count of each position. " + "Each anchor is in (xmin, ymin, xmax, ymax) format"); + AddOutput("Variances", + "(Tensor, default Tensor), the expanded variances for " + "normalizing bbox regression targets. The layout is [H, W, " + "num_anchors, 4]. " + "H is the height of input, W is the width of input, num_anchors " + "is the box count of each position. " + "Each variance is in (xcenter, ycenter, w, h) format"); + + AddAttr>( + "anchor_sizes", + "(vector) List of Region Proposal Network(RPN) anchor sizes " + " given in absolute pixels e.g. (64, 128, 256, 512)." + " For instance, the anchor size of 64 means the area of this anchor " + "equals to 64**2.") + .AddCustomChecker([](const std::vector& anchor_sizes) { + PADDLE_ENFORCE_GT(anchor_sizes.size(), 0, + "Size of anchor_sizes must be at least 1."); + for (size_t i = 0; i < anchor_sizes.size(); ++i) { + PADDLE_ENFORCE_GT(anchor_sizes[i], 0.0, + "anchor_sizes[%d] must be positive.", i); + } + }); + AddAttr>( + "aspect_ratios", + "(vector) List of Region Proposal Network(RPN) anchor aspect " + "ratios, e.g. (0.5, 1, 2)." + "For instacne, the aspect ratio of 0.5 means the height / width of " + "this anchor equals 0.5."); + + AddAttr>("variances", + "(vector) List of variances to be used " + "in box regression deltas") + .AddCustomChecker([](const std::vector& variances) { + PADDLE_ENFORCE_EQ(variances.size(), 4, + "Must and only provide 4 variance."); + for (size_t i = 0; i < variances.size(); ++i) { + PADDLE_ENFORCE_GT(variances[i], 0.0, + "variance[%d] must be greater than 0.", i); + } + }); + + AddAttr>("stride", + "Anchors stride across width and height, " + "with a default of (16, 16)") + .SetDefault(std::vector(2, 16.0)) + .AddCustomChecker([](const std::vector& stride) { + PADDLE_ENFORCE_EQ( + stride.size(), 2, + "Must and only provide 2 stride for width and height."); + for (size_t i = 0; i < stride.size(); ++i) { + PADDLE_ENFORCE_GT(stride[i], 0.0, + "stride[%d] should be larger than 0.", i); + } + }); + + AddAttr("offset", + "(float) " + "Anchor center offset, with a default of 0.5") + .SetDefault(0.5); + AddComment(R"DOC( +AnchorGenerator operator +Generates anchors for Faster RCNN, FPN etc. algorithm. +Each position of the input produce N anchors, N = + size(anchor_sizes) * size(aspect_ratios). + +Please get more information from the following papers: +https://arxiv.org/abs/1506.01497. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(anchor_generator, ops::AnchorGeneratorOp, + ops::AnchorGeneratorOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(anchor_generator, ops::AnchorGeneratorOpKernel, + ops::AnchorGeneratorOpKernel); diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cu b/paddle/fluid/operators/detection/anchor_generator_op.cu new file mode 100644 index 0000000000..3cc9bbeee1 --- /dev/null +++ b/paddle/fluid/operators/detection/anchor_generator_op.cu @@ -0,0 +1,132 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/anchor_generator_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void GenAnchors(T* out, const T* aspect_ratios, const int ar_num, + const T* anchor_sizes, const int as_num, + const T* stride, const int sd_num, const int height, + const int width, const T offset) { + int num_anchors = as_num * ar_num; + int box_num = height * width * num_anchors; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num; + i += blockDim.x * gridDim.x) { + int h_idx = i / (num_anchors * width); + int w_idx = (i / num_anchors) % width; + T stride_width = stride[0]; + T stride_height = stride[1]; + T x_ctr = (w_idx * stride_width) + offset * (stride_width - 1); + T y_ctr = (h_idx * stride_height) + offset * (stride_height - 1); + T area, area_ratios; + T base_w, base_h; + T scale_w, scale_h; + T anchor_width, anchor_height; + int anch_idx = i % num_anchors; + int ar_idx = anch_idx / as_num; + int as_idx = anch_idx % as_num; + T aspect_ratio = aspect_ratios[ar_idx]; + T anchor_size = anchor_sizes[as_idx]; + area = stride_width * stride_height; + area_ratios = area / aspect_ratio; + base_w = round(sqrt(area_ratios)); + base_h = round(base_w * aspect_ratio); + scale_w = anchor_size / stride_width; + scale_h = anchor_size / stride_height; + anchor_width = scale_w * base_w; + anchor_height = scale_h * base_h; + + T xmin = (x_ctr - 0.5 * (anchor_width - 1)); + T ymin = (y_ctr - 0.5 * (anchor_height - 1)); + T xmax = (x_ctr + 0.5 * (anchor_width - 1)); + T ymax = (y_ctr + 0.5 * (anchor_height - 1)); + out[i * 4] = xmin; + out[i * 4 + 1] = ymin; + out[i * 4 + 2] = xmax; + out[i * 4 + 3] = ymax; + } +} + +template +__global__ void SetVariance(T* out, const T* var, const int vnum, + const int num) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; + i += blockDim.x * gridDim.x) { + out[i] = var[i % vnum]; + } +} + +template +class AnchorGeneratorOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* anchors = ctx.Output("Anchors"); + auto* vars = ctx.Output("Variances"); + + auto anchor_sizes = ctx.Attr>("anchor_sizes"); + auto aspect_ratios = ctx.Attr>("aspect_ratios"); + auto stride = ctx.Attr>("stride"); + auto variances = ctx.Attr>("variances"); + + T offset = static_cast(ctx.Attr("offset")); + + auto width = input->dims()[3]; + auto height = input->dims()[2]; + + int num_anchors = aspect_ratios.size() * anchor_sizes.size(); + + int box_num = width * height * num_anchors; + + int block = 512; + int grid = (box_num + block - 1) / block; + + auto stream = + ctx.template device_context().stream(); + + anchors->mutable_data(ctx.GetPlace()); + vars->mutable_data(ctx.GetPlace()); + + framework::Tensor ar; + framework::TensorFromVector(aspect_ratios, ctx.device_context(), &ar); + + framework::Tensor as; + framework::TensorFromVector(anchor_sizes, ctx.device_context(), &as); + + framework::Tensor sd; + framework::TensorFromVector(stride, ctx.device_context(), &sd); + + GenAnchors<<>>( + anchors->data(), ar.data(), aspect_ratios.size(), as.data(), + anchor_sizes.size(), sd.data(), stride.size(), height, width, + offset); + + framework::Tensor v; + framework::TensorFromVector(variances, ctx.device_context(), &v); + grid = (box_num * 4 + block - 1) / block; + SetVariance<<>>(vars->data(), v.data(), + variances.size(), box_num * 4); + } +}; // namespace operators + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(anchor_generator, + ops::AnchorGeneratorOpCUDAKernel, + ops::AnchorGeneratorOpCUDAKernel); diff --git a/paddle/fluid/operators/detection/anchor_generator_op.h b/paddle/fluid/operators/detection/anchor_generator_op.h new file mode 100644 index 0000000000..e0e499d76a --- /dev/null +++ b/paddle/fluid/operators/detection/anchor_generator_op.h @@ -0,0 +1,109 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/transform.h" + +namespace paddle { +namespace operators { + +template +class AnchorGeneratorOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* anchors = ctx.Output("Anchors"); + auto* vars = ctx.Output("Variances"); + + auto anchor_sizes = ctx.Attr>("anchor_sizes"); + auto aspect_ratios = ctx.Attr>("aspect_ratios"); + auto stride = ctx.Attr>("stride"); + auto variances = ctx.Attr>("variances"); + + T offset = static_cast(ctx.Attr("offset")); + + auto feature_width = input->dims()[3]; + auto feature_height = input->dims()[2]; + + T stride_width, stride_height; + stride_width = stride[0]; + stride_height = stride[1]; + + int num_anchors = aspect_ratios.size() * anchor_sizes.size(); + + anchors->mutable_data(ctx.GetPlace()); + vars->mutable_data(ctx.GetPlace()); + + auto e_anchors = framework::EigenTensor::From(*anchors); + for (int h_idx = 0; h_idx < feature_height; ++h_idx) { + for (int w_idx = 0; w_idx < feature_width; ++w_idx) { + T x_ctr = (w_idx * stride_width) + offset * (stride_width - 1); + T y_ctr = (h_idx * stride_height) + offset * (stride_height - 1); + T area, area_ratios; + T base_w, base_h; + T scale_w, scale_h; + T anchor_width, anchor_height; + int idx = 0; + for (size_t r = 0; r < aspect_ratios.size(); ++r) { + auto ar = aspect_ratios[r]; + for (size_t s = 0; s < anchor_sizes.size(); ++s) { + auto anchor_size = anchor_sizes[s]; + area = stride_width * stride_height; + area_ratios = area / ar; + base_w = round(sqrt(area_ratios)); + base_h = round(base_w * ar); + scale_w = anchor_size / stride_width; + scale_h = anchor_size / stride_height; + anchor_width = scale_w * base_w; + anchor_height = scale_h * base_h; + e_anchors(h_idx, w_idx, idx, 0) = + (x_ctr - 0.5 * (anchor_width - 1)); + e_anchors(h_idx, w_idx, idx, 1) = + (y_ctr - 0.5 * (anchor_height - 1)); + e_anchors(h_idx, w_idx, idx, 2) = + (x_ctr + 0.5 * (anchor_width - 1)); + e_anchors(h_idx, w_idx, idx, 3) = + (y_ctr + 0.5 * (anchor_height - 1)); + idx++; + } + } + } + } + + framework::Tensor var_t; + var_t.mutable_data( + framework::make_ddim({1, static_cast(variances.size())}), + ctx.GetPlace()); + auto var_et = framework::EigenTensor::From(var_t); + for (size_t i = 0; i < variances.size(); ++i) { + var_et(0, i) = variances[i]; + } + + int anchor_num = feature_height * feature_width * num_anchors; + auto var_dim = vars->dims(); + vars->Resize({anchor_num, static_cast(variances.size())}); + + auto e_vars = framework::EigenMatrix::From(*vars); + e_vars = var_et.broadcast(Eigen::DSizes(anchor_num, 1)); + + vars->Resize(var_dim); + } +}; // namespace operators + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 200db87f17..6af01297df 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -30,6 +30,7 @@ __all__ = [ 'detection_output', 'ssd_loss', 'detection_map', + 'anchor_generator', ] __auto__ = [ @@ -998,3 +999,95 @@ def multi_box_head(inputs, box.stop_gradient = True var.stop_gradient = True return mbox_locs_concat, mbox_confs_concat, box, var + + +def anchor_generator(input, + anchor_sizes=None, + aspect_ratios=None, + variance=[0.1, 0.1, 0.2, 0.2], + stride=None, + offset=0.5, + name=None): + """ + **Anchor generator operator** + + Generate anchors for Faster RCNN algorithm. + Each position of the input produce N anchors, N = + size(anchor_sizes) * size(aspect_ratios). The order of generated anchors + is firstly aspect_ratios loop then anchor_sizes loop. + + Args: + input(Variable): The input feature map, the format is NCHW. + anchor_sizes(list|tuple|float): The anchor sizes of generated anchors, + given in absolute pixels e.g. [64., 128., 256., 512.]. + For instance, the anchor size of 64 means the area of this anchor equals to 64**2. + aspect_ratios(list|tuple|float): The height / width ratios of generated + anchors, e.g. [0.5, 1.0, 2.0]. + variance(list|tuple): The variances to be used in box regression deltas. + Default:[0.1, 0.1, 0.2, 0.2]. + stride(list|turple): The anchors stride across width and height, + e.g. [16.0, 16.0] + offset(float): Prior boxes center offset. Default: 0.5 + name(str): Name of the prior box op. Default: None. + + Returns: + Anchors(Variable): The output anchors with a layout of [H, W, num_anchors, 4]. + H is the height of input, W is the width of input, + num_anchors is the box count of each position. + Each anchor is in (xmin, ymin, xmax, ymax) format an unnormalized. + Variances(Variable): The expanded variances of anchors + with a layout of [H, W, num_priors, 4]. + H is the height of input, W is the width of input + num_anchors is the box count of each position. + Each variance is in (xcenter, ycenter, w, h) format. + + + Examples: + + .. code-block:: python + + anchor, var = anchor_generator( + input=conv1, + anchor_sizes=[64, 128, 256, 512], + aspect_ratios=[0.5, 1.0, 2.0], + variance=[0.1, 0.1, 0.2, 0.2], + stride=[16.0, 16.0], + offset=0.5) + """ + helper = LayerHelper("anchor_generator", **locals()) + dtype = helper.input_dtype() + + def _is_list_or_tuple_(data): + return (isinstance(data, list) or isinstance(data, tuple)) + + if not _is_list_or_tuple_(anchor_sizes): + anchor_sizes = [anchor_sizes] + if not _is_list_or_tuple_(aspect_ratios): + aspect_ratios = [aspect_ratios] + if not (_is_list_or_tuple_(stride) and len(stride) == 2): + raise ValueError('stride should be a list or tuple ', + 'with length 2, (stride_width, stride_height).') + + anchor_sizes = list(map(float, anchor_sizes)) + aspect_ratios = list(map(float, aspect_ratios)) + stride = list(map(float, stride)) + + attrs = { + 'anchor_sizes': anchor_sizes, + 'aspect_ratios': aspect_ratios, + 'variances': variance, + 'stride': stride, + 'offset': offset + } + + anchor = helper.create_tmp_variable(dtype) + var = helper.create_tmp_variable(dtype) + helper.append_op( + type="anchor_generator", + inputs={"Input": input}, + outputs={"Anchors": anchor, + "Variances": var}, + attrs=attrs, ) + anchor.stop_gradient = True + var.stop_gradient = True + return anchor, var diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 8569d838bd..2d70c986b1 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -127,6 +127,24 @@ class TestPriorBox(unittest.TestCase): assert box.shape[3] == 4 +class TestAnchorGenerator(unittest.TestCase): + def test_anchor_generator(self): + data_shape = [3, 224, 224] + images = fluid.layers.data( + name='pixel', shape=data_shape, dtype='float32') + conv1 = fluid.layers.conv2d(images, 3, 3, 2) + anchor, var = fluid.layers.anchor_generator( + input=conv1, + anchor_sizes=[64, 128, 256, 512], + aspect_ratios=[0.5, 1.0, 2.0], + variance=[0.1, 0.1, 0.2, 0.2], + stride=[16.0, 16.0], + offset=0.5) + assert len(anchor.shape) == 4 + assert anchor.shape == var.shape + assert anchor.shape[3] == 4 + + class TestMultiBoxHead(unittest.TestCase): def test_multi_box_head(self): data_shape = [3, 224, 224] diff --git a/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py b/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py new file mode 100644 index 0000000000..9c7d5d41f0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_anchor_generator_op.py @@ -0,0 +1,110 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://w_idxw.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import sys +import math +from op_test import OpTest + + +def anchor_generator_in_python(input_feat, anchor_sizes, aspect_ratios, + variances, stride, offset): + num_anchors = len(aspect_ratios) * len(anchor_sizes) + layer_h = input_feat.shape[2] + layer_w = input_feat.shape[3] + out_dim = (layer_h, layer_w, num_anchors, 4) + out_anchors = np.zeros(out_dim).astype('float32') + + for h_idx in range(layer_h): + for w_idx in range(layer_w): + x_ctr = (w_idx * stride[0]) + offset * (stride[0] - 1) + y_ctr = (h_idx * stride[1]) + offset * (stride[1] - 1) + idx = 0 + for r in range(len(aspect_ratios)): + ar = aspect_ratios[r] + for s in range(len(anchor_sizes)): + anchor_size = anchor_sizes[s] + area = stride[0] * stride[1] + area_ratios = area / ar + base_w = np.round(np.sqrt(area_ratios)) + base_h = np.round(base_w * ar) + scale_w = anchor_size / stride[0] + scale_h = anchor_size / stride[1] + w = scale_w * base_w + h = scale_h * base_h + out_anchors[h_idx, w_idx, idx, :] = [ + (x_ctr - 0.5 * (w - 1)), (y_ctr - 0.5 * (h - 1)), + (x_ctr + 0.5 * (w - 1)), (y_ctr + 0.5 * (h - 1)) + ] + idx += 1 + + # set the variance. + out_var = np.tile(variances, (layer_h, layer_w, num_anchors, 1)) + out_anchors = out_anchors.astype('float32') + out_var = out_var.astype('float32') + return out_anchors, out_var + + +class TestAnchorGeneratorOp(OpTest): + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = {'Input': self.input} + + self.attrs = { + 'anchor_sizes': self.anchor_sizes, + 'aspect_ratios': self.aspect_ratios, + 'stride': self.stride, + 'offset': self.offset, + 'variances': self.variances, + } + + self.outputs = {'Anchors': self.out_anchors, 'Variances': self.out_var} + + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "anchor_generator" + self.set_data() + + def init_test_params(self): + self.batch_size = 1 + self.input_channels = 2 + self.layer_h = 2 + self.layer_w = 2 + + self.anchor_sizes = [64., 128., 256., 512.] + self.aspect_ratios = [0.5, 1., 2.] + self.stride = [16., 16.] + + self.offset = 0.5 + + self.variances = [0.1, 0.1, 0.2, 0.2] + + def init_test_input(self): + self.input = np.random.random( + (self.batch_size, self.input_channels, self.layer_h, + self.layer_w)).astype('float32') + + def init_test_output(self): + self.out_anchors, self.out_var = anchor_generator_in_python( + self.input, self.anchor_sizes, self.aspect_ratios, self.variances, + self.stride, self.offset) + + +if __name__ == '__main__': + unittest.main() -- GitLab