提交 4a55fb5f 编写于 作者: R ruri 提交者: qingqing01

Add density_prior_box_op (#14226)

Density prior box operator for image detection model.
上级 9a6e2392
......@@ -274,6 +274,7 @@ paddle.fluid.layers.hard_shrink ArgSpec(args=['x', 'threshold'], varargs=None, k
paddle.fluid.layers.cumsum ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.thresholded_relu ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.prior_box ArgSpec(args=['input', 'image', 'min_sizes', 'max_sizes', 'aspect_ratios', 'variance', 'flip', 'clip', 'steps', 'offset', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, [1.0], [0.1, 0.1, 0.2, 0.2], False, False, [0.0, 0.0], 0.5, None, False))
paddle.fluid.layers.density_prior_box ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, None))
paddle.fluid.layers.multi_box_head ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False))
paddle.fluid.layers.bipartite_match ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
......
......@@ -22,6 +22,7 @@ iou_similarity_op.cu)
detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc poly_util.cc gpc.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
detection_library(density_prior_box_op SRCS density_prior_box_op.cc)
detection_library(anchor_generator_op SRCS anchor_generator_op.cc
anchor_generator_op.cu)
detection_library(target_assign_op SRCS target_assign_op.cc
......
/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/density_prior_box_op.h"
namespace paddle {
namespace operators {
class DensityPriorBoxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of DensityPriorBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Image"),
"Input(Image) of DensityPriorBoxOp should not be null.");
auto image_dims = ctx->GetInputDim("Image");
auto input_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW.");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
PADDLE_ENFORCE_LT(input_dims[2], image_dims[2],
"The height of input must smaller than image.");
PADDLE_ENFORCE_LT(input_dims[3], image_dims[3],
"The width of input must smaller than image.");
auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
auto fixed_sizes = ctx->Attrs().Get<std::vector<float>>("fixed_sizes");
auto fixed_ratios = ctx->Attrs().Get<std::vector<float>>("fixed_ratios");
auto densities = ctx->Attrs().Get<std::vector<int>>("densities");
PADDLE_ENFORCE_EQ(fixed_sizes.size(), densities.size(),
"The number of fixed_sizes and densities must be equal.");
size_t num_priors = 0;
if ((fixed_sizes.size() > 0) && (densities.size() > 0)) {
for (size_t i = 0; i < densities.size(); ++i) {
if (fixed_ratios.size() > 0) {
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
}
}
}
std::vector<int64_t> dim_vec(4);
dim_vec[0] = input_dims[2];
dim_vec[1] = input_dims[3];
dim_vec[2] = num_priors;
dim_vec[3] = 4;
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
platform::CPUPlace());
}
};
class DensityPriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Input",
"(Tensor, default Tensor<float>), "
"the input feature data of DensityPriorBoxOp, the layout is NCHW.");
AddInput("Image",
"(Tensor, default Tensor<float>), "
"the input image data of DensityPriorBoxOp, the layout is NCHW.");
AddOutput("Boxes",
"(Tensor, default Tensor<float>), the output prior boxes of "
"DensityPriorBoxOp. The layout is [H, W, num_priors, 4]. "
"H is the height of input, W is the width of input, num_priors "
"is the box count of each position.");
AddOutput("Variances",
"(Tensor, default Tensor<float>), the expanded variances of "
"DensityPriorBoxOp. The layout is [H, W, num_priors, 4]. "
"H is the height of input, W is the width of input, num_priors "
"is the box count of each position.");
AddAttr<std::vector<float>>("variances",
"(vector<float>) List of variances to be "
"encoded in density prior boxes.")
.AddCustomChecker([](const std::vector<float>& variances) {
PADDLE_ENFORCE_EQ(variances.size(), 4,
"Must and only provide 4 variance.");
for (size_t i = 0; i < variances.size(); ++i) {
PADDLE_ENFORCE_GT(variances[i], 0.0,
"variance[%d] must be greater than 0.", i);
}
});
AddAttr<bool>("clip", "(bool) Whether to clip out-of-boundary boxes.")
.SetDefault(true);
AddAttr<float>(
"step_w",
"Density prior boxes step across width, 0.0 for auto calculation.")
.SetDefault(0.0)
.AddCustomChecker([](const float& step_w) {
PADDLE_ENFORCE_GE(step_w, 0.0, "step_w should be larger than 0.");
});
AddAttr<float>(
"step_h",
"Density prior boxes step across height, 0.0 for auto calculation.")
.SetDefault(0.0)
.AddCustomChecker([](const float& step_h) {
PADDLE_ENFORCE_GE(step_h, 0.0, "step_h should be larger than 0.");
});
AddAttr<float>("offset",
"(float) "
"Density prior boxes center offset.")
.SetDefault(0.5);
AddAttr<std::vector<float>>("fixed_sizes",
"(vector<float>) List of fixed sizes "
"of generated density prior boxes.")
.SetDefault(std::vector<float>{})
.AddCustomChecker([](const std::vector<float>& fixed_sizes) {
for (size_t i = 0; i < fixed_sizes.size(); ++i) {
PADDLE_ENFORCE_GT(fixed_sizes[i], 0.0,
"fixed_sizes[%d] should be larger than 0.", i);
}
});
AddAttr<std::vector<float>>("fixed_ratios",
"(vector<float>) List of fixed ratios "
"of generated density prior boxes.")
.SetDefault(std::vector<float>{})
.AddCustomChecker([](const std::vector<float>& fixed_ratios) {
for (size_t i = 0; i < fixed_ratios.size(); ++i) {
PADDLE_ENFORCE_GT(fixed_ratios[i], 0.0,
"fixed_ratios[%d] should be larger than 0.", i);
}
});
AddAttr<std::vector<int>>("densities",
"(vector<float>) List of densities "
"of generated density prior boxes.")
.SetDefault(std::vector<int>{})
.AddCustomChecker([](const std::vector<int>& densities) {
for (size_t i = 0; i < densities.size(); ++i) {
PADDLE_ENFORCE_GT(densities[i], 0,
"densities[%d] should be larger than 0.", i);
}
});
AddComment(R"DOC(
Density Prior box operator
Each position of the input produce N density prior boxes, N is determined by
the count of fixed_ratios, densities, the calculation of N is as follows:
for density in densities:
N += size(fixed_ratios)*density^2
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(density_prior_box, ops::DensityPriorBoxOp,
ops::DensityPriorBoxOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(density_prior_box, ops::DensityPriorBoxOpKernel<float>,
ops::DensityPriorBoxOpKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/detection/prior_box_op.h"
namespace paddle {
namespace operators {
template <typename T>
class DensityPriorBoxOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
auto* image = ctx.Input<paddle::framework::Tensor>("Image");
auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
auto variances = ctx.Attr<std::vector<float>>("variances");
auto clip = ctx.Attr<bool>("clip");
auto fixed_sizes = ctx.Attr<std::vector<float>>("fixed_sizes");
auto fixed_ratios = ctx.Attr<std::vector<float>>("fixed_ratios");
auto densities = ctx.Attr<std::vector<int>>("densities");
T step_w = static_cast<T>(ctx.Attr<float>("step_w"));
T step_h = static_cast<T>(ctx.Attr<float>("step_h"));
T offset = static_cast<T>(ctx.Attr<float>("offset"));
auto img_width = image->dims()[3];
auto img_height = image->dims()[2];
auto feature_width = input->dims()[3];
auto feature_height = input->dims()[2];
T step_width, step_height;
if (step_w == 0 || step_h == 0) {
step_width = static_cast<T>(img_width) / feature_width;
step_height = static_cast<T>(img_height) / feature_height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = 0;
if (fixed_sizes.size() > 0 && densities.size() > 0) {
for (size_t i = 0; i < densities.size(); ++i) {
if (fixed_ratios.size() > 0) {
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
}
}
}
boxes->mutable_data<T>(ctx.GetPlace());
vars->mutable_data<T>(ctx.GetPlace());
auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes).setConstant(0.0);
int step_average = static_cast<int>((step_width + step_height) * 0.5);
for (int h = 0; h < feature_height; ++h) {
for (int w = 0; w < feature_width; ++w) {
T center_x = (w + offset) * step_width;
T center_y = (h + offset) * step_height;
int idx = 0;
// Generate density prior boxes with fixed sizes.
for (size_t s = 0; s < fixed_sizes.size(); ++s) {
auto fixed_size = fixed_sizes[s];
int density = densities[s];
// Generate density prior boxes with fixed ratios.
if (fixed_ratios.size() > 0) {
for (size_t r = 0; r < fixed_ratios.size(); ++r) {
float ar = fixed_ratios[r];
int shift = step_average / density;
float box_width_ratio = fixed_size * sqrt(ar);
float box_height_ratio = fixed_size / sqrt(ar);
for (int di = 0; di < density; ++di) {
for (int dj = 0; dj < density; ++dj) {
float center_x_temp =
center_x - step_average / 2. + shift / 2. + dj * shift;
float center_y_temp =
center_y - step_average / 2. + shift / 2. + di * shift;
e_boxes(h, w, idx, 0) =
(center_x_temp - box_width_ratio / 2.) / img_width >= 0
? (center_x_temp - box_width_ratio / 2.) / img_width
: 0;
e_boxes(h, w, idx, 1) =
(center_y_temp - box_height_ratio / 2.) / img_height >= 0
? (center_y_temp - box_height_ratio / 2.) / img_height
: 0;
e_boxes(h, w, idx, 2) =
(center_x_temp + box_width_ratio / 2.) / img_width <= 1
? (center_x_temp + box_width_ratio / 2.) / img_width
: 1;
e_boxes(h, w, idx, 3) =
(center_y_temp + box_height_ratio / 2.) / img_height <= 1
? (center_y_temp + box_height_ratio / 2.) / img_height
: 1;
idx++;
}
}
}
}
}
}
}
if (clip) {
platform::Transform<platform::CPUDeviceContext> trans;
ClipFunctor<T> clip_func;
trans(ctx.template device_context<platform::CPUDeviceContext>(),
boxes->data<T>(), boxes->data<T>() + boxes->numel(),
boxes->data<T>(), clip_func);
}
framework::Tensor var_t;
var_t.mutable_data<T>(
framework::make_ddim({1, static_cast<int>(variances.size())}),
ctx.GetPlace());
auto var_et = framework::EigenTensor<T, 2>::From(var_t);
for (size_t i = 0; i < variances.size(); ++i) {
var_et(0, i) = variances[i];
}
int box_num = feature_height * feature_width * num_priors;
auto var_dim = vars->dims();
vars->Resize({box_num, static_cast<int>(variances.size())});
auto e_vars = framework::EigenMatrix<T, Eigen::RowMajor>::From(*vars);
e_vars = var_et.broadcast(Eigen::DSizes<int, 2>(box_num, 1));
vars->Resize(var_dim);
}
}; // namespace operators
} // namespace operators
} // namespace paddle
......@@ -31,6 +31,7 @@ from functools import reduce
__all__ = [
'prior_box',
'density_prior_box',
'multi_box_head',
'bipartite_match',
'target_assign',
......@@ -1023,6 +1024,135 @@ def prior_box(input,
return box, var
def density_prior_box(input,
image,
densities=None,
fixed_sizes=None,
fixed_ratios=None,
variance=[0.1, 0.1, 0.2, 0.2],
clip=False,
steps=[0.0, 0.0],
offset=0.5,
name=None):
"""
**Density Prior Box Operator**
Generate density prior boxes for SSD(Single Shot MultiBox Detector)
algorithm. Each position of the input produce N prior boxes, N is
determined by the count of densities, fixed_sizes and fixed_ratios.
Boxes center at grid points around each input position is generated by
this operator, and the grid points is determined by densities and
the count of density prior box is determined by fixed_sizes and fixed_ratios.
Obviously, the number of fixed_sizes is equal to the number of densities.
For densities_i in densities:
N_density_prior_box =sum(N_fixed_ratios * densities_i^2),
Args:
input(Variable): The Input Variables, the format is NCHW.
image(Variable): The input image data of PriorBoxOp,
the layout is NCHW.
densities(list|tuple|None): the densities of generated density prior
boxes, this attribute should be a list or tuple of integers.
Default: None.
fixed_sizes(list|tuple|None): the fixed sizes of generated density
prior boxes, this attribute should a list or tuple of same
length with :attr:`densities`. Default: None.
fixed_ratios(list|tuple|None): the fixed ratios of generated density
prior boxes, if this attribute is not set and :attr:`densities`
and :attr:`fix_sizes` is set, :attr:`aspect_ratios` will be used
to generate density prior boxes.
variance(list|tuple): the variances to be encoded in density prior boxes.
Default:[0.1, 0.1, 0.2, 0.2].
clip(bool): Whether to clip out-of-boundary boxes. Default: False.
step(list|turple): Prior boxes step across width and height, If
step[0] == 0.0/step[1] == 0.0, the density prior boxes step across
height/weight of the input will be automatically calculated.
Default: [0., 0.]
offset(float): Prior boxes center offset. Default: 0.5
name(str): Name of the density prior box op. Default: None.
Returns:
tuple: A tuple with two Variable (boxes, variances)
boxes: the output density prior boxes of PriorBox.
The layout is [H, W, num_priors, 4].
H is the height of input, W is the width of input,
num_priors is the total
box count of each position of input.
variances: the expanded variances of PriorBox.
The layout is [H, W, num_priors, 4].
H is the height of input, W is the width of input
num_priors is the total
box count of each position of input
Examples:
.. code-block:: python
box, var = fluid.layers.density_prior_box(
input=conv1,
image=images,
min_sizes=[100.],
max_sizes=[200.],
aspect_ratios=[1.0, 1.0 / 2.0, 2.0],
densities=[3, 4],
fixed_sizes=[50., 60.],
fixed_ratios=[1.0, 3.0, 1.0 / 3.0],
flip=True,
clip=True)
"""
helper = LayerHelper("density_prior_box", **locals())
dtype = helper.input_dtype()
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
if not _is_list_or_tuple_(densities):
raise TypeError('densities should be a list or a tuple or None.')
if not _is_list_or_tuple_(fixed_sizes):
raise TypeError('fixed_sizes should be a list or a tuple or None.')
if not _is_list_or_tuple_(fixed_ratios):
raise TypeError('fixed_ratios should be a list or a tuple or None.')
if len(densities) != len(fixed_sizes):
raise ValueError('densities and fixed_sizes length should be euqal.')
if not (_is_list_or_tuple_(steps) and len(steps) == 2):
raise ValueError('steps should be a list or tuple ',
'with length 2, (step_width, step_height).')
densities = list(map(int, densities))
fixed_sizes = list(map(float, fixed_sizes))
fixed_ratios = list(map(float, fixed_ratios))
steps = list(map(float, steps))
attrs = {
'variances': variance,
'clip': clip,
'step_w': steps[0],
'step_h': steps[1],
'offset': offset,
}
if densities is not None and len(densities) > 0:
attrs['densities'] = densities
if fixed_sizes is not None and len(fixed_sizes) > 0:
attrs['fixed_sizes'] = fixed_sizes
if fixed_ratios is not None and len(fixed_ratios) > 0:
attrs['fixed_ratios'] = fixed_ratios
box = helper.create_variable_for_type_inference(dtype)
var = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="density_prior_box",
inputs={"Input": input,
"Image": image},
outputs={"Boxes": box,
"Variances": var},
attrs=attrs, )
box.stop_gradient = True
var.stop_gradient = True
return box, var
def multi_box_head(inputs,
image,
base_size,
......
......@@ -128,6 +128,24 @@ class TestPriorBox(unittest.TestCase):
assert box.shape[3] == 4
class TestDensityPriorBox(unittest.TestCase):
def test_density_prior_box(self):
data_shape = [3, 224, 224]
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(images, 3, 3, 2)
box, var = layers.density_prior_box(
input=conv1,
image=images,
densities=[3, 4],
fixed_sizes=[50., 60.],
fixed_ratios=[1.0],
clip=True)
assert len(box.shape) == 4
assert box.shape == var.shape
assert box.shape[3] == 4
class TestAnchorGenerator(unittest.TestCase):
def test_anchor_generator(self):
data_shape = [3, 224, 224]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
import math
from op_test import OpTest
class TestDensityPriorBoxOp(OpTest):
def set_data(self):
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {'Input': self.input, 'Image': self.image}
self.attrs = {
'variances': self.variances,
'clip': self.clip,
'step_w': self.step_w,
'step_h': self.step_h,
'offset': self.offset,
'densities': self.densities,
'fixed_sizes': self.fixed_sizes,
'fixed_ratios': self.fixed_ratios
}
self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "density_prior_box"
self.set_data()
def set_density(self):
self.densities = []
self.fixed_sizes = []
self.fixed_ratios = []
def init_test_params(self):
self.layer_w = 32
self.layer_h = 32
self.image_w = 40
self.image_h = 40
self.step_w = float(self.image_w) / float(self.layer_w)
self.step_h = float(self.image_h) / float(self.layer_h)
self.input_channels = 2
self.image_channels = 3
self.batch_size = 10
self.variances = [0.1, 0.1, 0.2, 0.2]
self.variances = np.array(self.variances, dtype=np.float).flatten()
self.set_density()
self.clip = True
self.num_priors = 0
if len(self.fixed_sizes) > 0 and len(self.densities) > 0:
for density in self.densities:
if len(self.fixed_ratios) > 0:
self.num_priors += len(self.fixed_ratios) * (pow(density,
2))
self.offset = 0.5
def init_test_input(self):
self.image = np.random.random(
(self.batch_size, self.image_channels, self.image_w,
self.image_h)).astype('float32')
self.input = np.random.random(
(self.batch_size, self.input_channels, self.layer_w,
self.layer_h)).astype('float32')
def init_test_output(self):
out_dim = (self.layer_h, self.layer_w, self.num_priors, 4)
out_boxes = np.zeros(out_dim).astype('float32')
out_var = np.zeros(out_dim).astype('float32')
step_average = int((self.step_w + self.step_h) * 0.5)
for h in range(self.layer_h):
for w in range(self.layer_w):
idx = 0
c_x = (w + self.offset) * self.step_w
c_y = (h + self.offset) * self.step_h
# Generate density prior boxes with fixed size
for density, fixed_size in zip(self.densities,
self.fixed_sizes):
if (len(self.fixed_ratios) > 0):
for ar in self.fixed_ratios:
shift = int(step_average / density)
box_width_ratio = fixed_size * math.sqrt(ar)
box_height_ratio = fixed_size / math.sqrt(ar)
for di in range(density):
for dj in range(density):
c_x_temp = c_x - step_average / 2.0 + shift / 2.0 + dj * shift
c_y_temp = c_y - step_average / 2.0 + shift / 2.0 + di * shift
out_boxes[h, w, idx, :] = [
max((c_x_temp - box_width_ratio / 2.0) /
self.image_w, 0),
max((c_y_temp - box_height_ratio / 2.0)
/ self.image_h, 0),
min((c_x_temp + box_width_ratio / 2.0) /
self.image_w, 1),
min((c_y_temp + box_height_ratio / 2.0)
/ self.image_h, 1)
]
idx += 1
if self.clip:
out_boxes = np.clip(out_boxes, 0.0, 1.0)
out_var = np.tile(self.variances,
(self.layer_h, self.layer_w, self.num_priors, 1))
self.out_boxes = out_boxes.astype('float32')
self.out_var = out_var.astype('float32')
class TestDensityPriorBox(TestDensityPriorBoxOp):
def set_density(self):
self.densities = [3, 4]
self.fixed_sizes = [1.0, 2.0]
self.fixed_ratios = [1.0]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册