未验证 提交 81be9cef 编写于 作者: W Wang Hao 提交者: GitHub

Merge pull request #6150 from wanghaox/prior_box

prior box operator for ssd
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/prior_box_op.h"
namespace paddle {
namespace operators {
class PriorBoxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of PriorBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Image"),
"Input(Image) of PriorBoxOp should not be null.");
auto image_dims = ctx->GetInputDim("Image");
auto input_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE(image_dims.size() == 4, "The layout of image is NCHW.");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
PADDLE_ENFORCE_LT(input_dims[2], image_dims[2],
"The height of input must smaller than image.");
PADDLE_ENFORCE_LT(input_dims[3], image_dims[3],
"The width of input must smaller than image.");
auto min_sizes = ctx->Attrs().Get<std::vector<int>>("min_sizes");
auto max_sizes = ctx->Attrs().Get<std::vector<int>>("max_sizes");
auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
auto aspect_ratios = ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
bool flip = ctx->Attrs().Get<bool>("flip");
PADDLE_ENFORCE_GT(min_sizes.size(), 0,
"Size of min_sizes must be at least 1.");
for (size_t i = 0; i < min_sizes.size(); ++i) {
PADDLE_ENFORCE_GT(min_sizes[i], 0, "min_sizes[%d] must be positive.", i);
}
std::vector<float> aspect_ratios_vec;
ExpandAspectRatios(aspect_ratios, flip, aspect_ratios_vec);
int num_priors = aspect_ratios_vec.size() * min_sizes.size();
if (max_sizes.size() > 0) {
PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(),
"The number of min_size and max_size must be equal.");
for (size_t i = 0; i < min_sizes.size(); ++i) {
PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i],
"max_size[%d] must be greater than min_size[%d].", i,
i);
num_priors += 1;
}
}
PADDLE_ENFORCE_EQ(variances.size(), 4, "Must and only provide 4 variance.");
for (size_t i = 0; i < variances.size(); ++i) {
PADDLE_ENFORCE_GT(variances[i], 0.0,
"variance[%d] must be greater than 0.", i);
}
const float step_h = ctx->Attrs().Get<float>("step_h");
PADDLE_ENFORCE_GT(step_h, 0.0, "step_h should be larger than 0.");
const float step_w = ctx->Attrs().Get<float>("step_w");
PADDLE_ENFORCE_GT(step_w, 0.0, "step_w should be larger than 0.");
std::vector<int64_t> dim_vec(4);
dim_vec[0] = input_dims[2];
dim_vec[1] = input_dims[3];
dim_vec[2] = num_priors;
dim_vec[3] = 4;
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
}
};
class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PriorBoxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(Tensor, default Tensor<float>), "
"the input feature data of PriorBoxOp, The layout is NCHW.");
AddInput("Image",
"(Tensor, default Tensor<float>), "
"the input image data of PriorBoxOp, The layout is NCHW.");
AddOutput("Boxes",
"(Tensor, default Tensor<float>), the output prior boxes of "
"PriorBoxOp. The layout is [H, W, num_priors, 4]. "
"H is the height of input, W is the width of input, num_priors "
"is the box count of each position.");
AddOutput("Variances",
"(Tensor, default Tensor<float>), the expanded variances of "
"PriorBoxOp. The layout is [H, W, num_priors, 4]. "
"H is the height of input, W is the width of input, num_priors "
"is the box count of each position.");
AddAttr<std::vector<int>>("min_sizes", "(vector<int>) ",
"List of min sizes of generated prior boxes.");
AddAttr<std::vector<int>>("max_sizes", "(vector<int>) ",
"List of max sizes of generated prior boxes.");
AddAttr<std::vector<float>>(
"aspect_ratios", "(vector<float>) ",
"List of aspect ratios of generated prior boxes.");
AddAttr<std::vector<float>>(
"variances", "(vector<float>) ",
"List of variances to be encoded in prior boxes.");
AddAttr<bool>("flip", "(bool) ", "Whether to flip aspect ratios.")
.SetDefault(true);
AddAttr<bool>("clip", "(bool) ", "Whether to clip out-of-boundary boxes.")
.SetDefault(true);
AddAttr<float>("step_w",
"Prior boxes step across width, 0 for auto calculation.")
.SetDefault(0.0);
AddAttr<float>("step_h",
"Prior boxes step across height, 0 for auto calculation.")
.SetDefault(0.0);
AddAttr<float>("offset",
"(float) "
"Prior boxes center offset.")
.SetDefault(0.5);
AddComment(R"DOC(
Prior box operator
Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
Each position of the input produce N prior boxes, N is determined by
the count of min_sizes, max_sizes and aspect_ratios, The size of the
box is in range(min_size, max_size) interval, which is generated in
sequence according to the aspect_ratios.
Please get more information from the following papers:
https://arxiv.org/abs/1512.02325.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker);
REGISTER_OP_CPU_KERNEL(
prior_box, ops::PriorBoxOpKernel<paddle::platform::CPUPlace, float>,
ops::PriorBoxOpKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/transform.h"
namespace paddle {
namespace operators {
inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
bool flip,
std::vector<float>& output_aspect_ratior) {
constexpr float epsilon = 1e-6;
output_aspect_ratior.clear();
output_aspect_ratior.push_back(1.);
for (size_t i = 0; i < input_aspect_ratior.size(); ++i) {
float ar = input_aspect_ratior[i];
bool already_exist = false;
for (size_t j = 0; j < output_aspect_ratior.size(); ++j) {
if (fabs(ar - output_aspect_ratior[j]) < epsilon) {
already_exist = true;
break;
}
}
if (!already_exist) {
output_aspect_ratior.push_back(ar);
if (flip) {
output_aspect_ratior.push_back(1. / ar);
}
}
}
}
template <typename T>
struct ClipFunctor {
HOSTDEVICE T operator()(T in) const {
return std::min<T>(std::max<T>(in, 0.), 1.);
}
};
template <typename Place, typename T>
class PriorBoxOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
auto* image = ctx.Input<paddle::framework::Tensor>("Image");
auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
auto min_sizes = ctx.Attr<std::vector<int>>("min_sizes");
auto max_sizes = ctx.Attr<std::vector<int>>("max_sizes");
auto input_aspect_ratio = ctx.Attr<std::vector<float>>("aspect_ratios");
auto variances = ctx.Attr<std::vector<float>>("variances");
auto flip = ctx.Attr<bool>("flip");
auto clip = ctx.Attr<bool>("clip");
std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, aspect_ratios);
T step_w = static_cast<T>(ctx.Attr<float>("step_w"));
T step_h = static_cast<T>(ctx.Attr<float>("step_h"));
T offset = static_cast<T>(ctx.Attr<float>("offset"));
auto img_width = image->dims()[3];
auto img_height = image->dims()[2];
auto feature_width = input->dims()[3];
auto feature_height = input->dims()[2];
T step_width, step_height;
if (step_w == 0 || step_h == 0) {
step_width = static_cast<T>(img_width) / feature_width;
step_height = static_cast<T>(img_height) / feature_height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = aspect_ratios.size() * min_sizes.size();
if (max_sizes.size() > 0) {
num_priors += max_sizes.size();
}
boxes->mutable_data<T>(ctx.GetPlace());
vars->mutable_data<T>(ctx.GetPlace());
auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes);
for (int h = 0; h < feature_height; ++h) {
for (int w = 0; w < feature_width; ++w) {
T center_x = (w + offset) * step_width;
T center_y = (h + offset) * step_height;
T box_width, box_height;
int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) {
int min_size = min_sizes[s];
// first prior: aspect_ratio = 1, size = min_size
box_width = box_height = min_size;
// xmin
e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
// ymin
e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
// xmax
e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
// ymax
e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
idx++;
if (max_sizes.size() > 0) {
int max_size = max_sizes[s];
// second prior: aspect_ratio = 1,
// size = sqrt(min_size * max_size)
box_width = box_height = sqrt(min_size * max_size);
// xmin
e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
// ymin
e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
// xmax
e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
// ymax
e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
idx++;
}
// rest of priors
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
float ar = aspect_ratios[r];
if (fabs(ar - 1.) < 1e-6) {
continue;
}
box_width = min_size * sqrt(ar);
box_height = min_size / sqrt(ar);
// xmin
e_boxes(h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
// ymin
e_boxes(h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
// xmax
e_boxes(h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
// ymax
e_boxes(h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
idx++;
}
}
}
}
if (clip) {
platform::Transform<platform::CPUDeviceContext> trans;
ClipFunctor<T> clip_func;
trans(ctx.template device_context<platform::CPUDeviceContext>(),
boxes->data<T>(), boxes->data<T>() + boxes->numel(),
boxes->data<T>(), clip_func);
}
framework::Tensor var_t;
var_t.mutable_data<T>(
framework::make_ddim({1, static_cast<int>(variances.size())}),
ctx.GetPlace());
auto var_et = framework::EigenTensor<T, 2>::From(var_t);
for (size_t i = 0; i < variances.size(); ++i) {
var_et(0, i) = variances[i];
}
int box_num = feature_height * feature_width * num_priors;
auto var_dim = vars->dims();
vars->Resize({box_num, static_cast<int>(variances.size())});
auto e_vars = framework::EigenMatrix<T, Eigen::RowMajor>::From(*vars);
e_vars = var_et.broadcast(Eigen::DSizes<int, 2>(box_num, 1));
vars->Resize(var_dim);
}
}; // namespace operators
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
import math
from op_test import OpTest
class TestPriorBoxOp(OpTest):
def set_data(self):
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {'Input': self.input, 'Image': self.image}
self.attrs = {
'min_sizes': self.min_sizes,
'max_sizes': self.max_sizes,
'aspect_ratios': self.aspect_ratios,
'variances': self.variances,
'flip': self.flip,
'clip': self.clip,
'step_w': self.step_w,
'step_h': self.step_h,
'offset': self.offset
}
self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
return
def setUp(self):
self.op_type = "prior_box"
self.set_data()
def init_test_params(self):
self.layer_w = 4
self.layer_h = 4
self.image_w = 20
self.image_h = 20
self.step_w = float(self.image_w) / float(self.layer_w)
self.step_h = float(self.image_h) / float(self.layer_h)
self.input_channels = 2
self.image_channels = 3
self.batch_size = 10
self.min_sizes = [2, 4]
self.min_sizes = np.array(self.min_sizes).astype('int64')
self.max_sizes = [5, 10]
self.max_sizes = np.array(self.max_sizes).astype('int64')
self.aspect_ratios = [2.0, 3.0]
self.flip = True
self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
self.aspect_ratios = np.array(
self.aspect_ratios, dtype=np.float).flatten()
self.variances = [0.1, 0.1, 0.2, 0.2]
self.variances = np.array(self.variances, dtype=np.float).flatten()
self.clip = True
self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
if len(self.max_sizes) > 1:
self.num_priors += len(self.max_sizes)
self.offset = 0.5
def init_test_input(self):
self.image = np.random.random(
(self.batch_size, self.image_channels, self.image_w,
self.image_h)).astype('float32')
self.input = np.random.random(
(self.batch_size, self.input_channels, self.layer_w,
self.layer_h)).astype('float32')
def init_test_output(self):
out_dim = (self.layer_h, self.layer_w, self.num_priors, 4)
out_boxes = np.zeros(out_dim).astype('float32')
out_var = np.zeros(out_dim).astype('float32')
idx = 0
for h in range(self.layer_h):
for w in range(self.layer_w):
c_x = (w + self.offset) * self.step_w
c_y = (h + self.offset) * self.step_h
idx = 0
for s in range(len(self.min_sizes)):
min_size = self.min_sizes[s]
c_w = c_h = min_size / 2.
out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w, (c_y + c_h) / self.image_h
]
idx += 1
if len(self.max_sizes) > 0:
max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h]
idx += 1
# rest of priors
for r in range(len(self.real_aspect_ratios)):
ar = self.real_aspect_ratios[r]
if math.fabs(ar - 1.) < 1e-6:
continue
c_w = min_size * math.sqrt(ar) / 2
c_h = (min_size / math.sqrt(ar)) / 2
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h]
idx += 1
# clip the prior's coordidate such that it is within[0, 1]
if self.clip:
out_boxes = np.clip(out_boxes, 0.0, 1.0)
# set the variance.
out_var = np.tile(self.variances, (self.layer_h, self.layer_w,
self.num_priors, 1))
self.out_boxes = out_boxes.astype('float32')
self.out_var = out_var.astype('float32')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册