提交 ee0113af 编写于 作者: W wanghaox

Implement the prior box operator for SSD

上级 dcf3ffd9
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/prior_box_op.h"
namespace paddle {
namespace operators {
// Operator definition for prior_box.
//
// Validates the inputs and attributes and infers the shape of Output(Out),
// which is {1, 2, layer_height * layer_width * num_priors * 4}: one set of
// prior boxes (channel 0) and their variances (channel 1).
class PriorBoxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks every attribute the kernel relies on and sets the dims of
  // Output(Out). Raises through PADDLE_ENFORCE* on any violation.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(Input) of PriorBoxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Image"),
                   "Input(Image) of PriorBoxOp should not be null.");

    auto image_dims = ctx->GetInputDim("Image");
    auto input_dims = ctx->GetInputDim("Input");
    PADDLE_ENFORCE(image_dims.size() == 4,
                   "The format of input tensor is NCHW.");
    // input_dims[2] and input_dims[3] are read below, so the rank must be 4.
    PADDLE_ENFORCE(input_dims.size() == 4,
                   "The format of Input(Input) tensor is NCHW.");

    auto min_sizes = ctx->Attrs().Get<std::vector<int>>("min_sizes");
    auto max_sizes = ctx->Attrs().Get<std::vector<int>>("max_sizes");
    auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
    auto input_aspect_ratio =
        ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
    bool flip = ctx->Attrs().Get<bool>("flip");

    PADDLE_ENFORCE_GT(min_sizes.size(), 0, "must provide min_size.");
    for (size_t i = 0; i < min_sizes.size(); ++i) {
      PADDLE_ENFORCE_GT(min_sizes[i], 0, "min_sizes[%d] must be positive.", i);
    }

    std::vector<float> aspect_ratios;
    expand_aspect_ratios(input_aspect_ratio, flip, aspect_ratios);

    // One prior per (min_size, aspect_ratio) pair, plus one extra square
    // prior per max_size when max_sizes is given.
    int num_priors = static_cast<int>(aspect_ratios.size() * min_sizes.size());
    if (max_sizes.size() > 0) {
      PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(),
                        "The length of min_size and max_size must be equal.");
      for (size_t i = 0; i < min_sizes.size(); ++i) {
        PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i],
                          "max_size[%d] must be greater than min_size[%d].", i,
                          i);
      }
      num_priors += static_cast<int>(max_sizes.size());
    }

    // The kernel either broadcasts a single variance or cycles through four
    // of them; any other count (including none) would read out of bounds.
    PADDLE_ENFORCE(variances.size() == 1 || variances.size() == 4,
                   "Must provide either 1 or 4 variances.");
    for (size_t i = 0; i < variances.size(); ++i) {
      PADDLE_ENFORCE_GT(variances[i], 0.0,
                        "variance[%d] must be greater than 0.", i);
    }

    // 0 is the default of img_h/img_w/step_h/step_w and tells the kernel to
    // derive the value from the input tensors, so only negatives are invalid.
    const int img_h = ctx->Attrs().Get<int>("img_h");
    PADDLE_ENFORCE(img_h >= 0, "img_h should not be negative.");
    const int img_w = ctx->Attrs().Get<int>("img_w");
    PADDLE_ENFORCE(img_w >= 0, "img_w should not be negative.");
    const float step_h = ctx->Attrs().Get<float>("step_h");
    PADDLE_ENFORCE(step_h >= 0.0, "step_h should not be negative.");
    const float step_w = ctx->Attrs().Get<float>("step_w");
    PADDLE_ENFORCE(step_w >= 0.0, "step_w should not be negative.");

    // NCHW layout: dim 2 is the height, dim 3 is the width.
    const int layer_height = input_dims[2];
    const int layer_width = input_dims[3];

    std::vector<int64_t> dim_vec(3);
    // Since all images in a batch has same height and width, we only need to
    // generate one set of priors which can be shared across all images.
    dim_vec[0] = 1;
    // 2 channels. First channel stores the mean of each prior coordinate.
    // Second channel stores the variance of each prior coordinate.
    dim_vec[1] = 2;
    dim_vec[2] =
        static_cast<int64_t>(layer_width) * layer_height * num_priors * 4;
    PADDLE_ENFORCE_GT(dim_vec[2], 0,
                      "output_dim[2] must larger than 0."
                      "check your data dims");
    auto output_dim = framework::make_ddim(dim_vec);
    ctx->SetOutputDim("Out", output_dim);
  }

 protected:
  // Selects the kernel by the element type of Input(Image) and the device of
  // the execution context.
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::LoDTensor>("Image")->type()),
        ctx.device_context());
  }
};
// Declares the inputs, output and attributes of the prior_box operator.
//
// Each attribute description is a single string built from adjacent literals
// ("(type) " "text"). There must be no comma between the two parts, otherwise
// the description would be passed as an extra argument to AddAttr instead of
// being part of the comment.
class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  PriorBoxOpMaker(framework::OpProto* proto,
                  framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input",
             "(Tensor), "
             "the input feature data of PriorBoxOp.");
    AddInput("Image",
             "(Tensor), "
             "the input image data of PriorBoxOp.");
    AddOutput("Out", "(Tensor), the output prior boxes of PriorBoxOp.");
    AddAttr<std::vector<int>>("min_sizes",
                              "(vector<int>) "
                              "List of min sizes of generated prior boxes.");
    AddAttr<std::vector<int>>("max_sizes",
                              "(vector<int>) "
                              "List of max sizes of generated prior boxes.");
    AddAttr<std::vector<float>>(
        "aspect_ratios",
        "(vector<float>) "
        "List of aspect ratios of generated prior boxes.")
        .SetDefault({});
    AddAttr<std::vector<float>>(
        "variances",
        "(vector<float>) "
        "List of variances to be encoded in prior boxes.")
        .SetDefault({0.1});
    AddAttr<bool>("flip",
                  "(bool) "
                  "Whether to flip aspect ratios.")
        .SetDefault(true);
    AddAttr<bool>("clip",
                  "(bool) "
                  "Whether to clip out-of-boundary boxes.")
        .SetDefault(true);
    AddAttr<int>("img_w",
                 "(int) "
                 "Image width, 0 for reading it from Input(Image).")
        .SetDefault(0);
    AddAttr<int>("img_h",
                 "(int) "
                 "Image height, 0 for reading it from Input(Image).")
        .SetDefault(0);
    AddAttr<float>("step_w",
                   "Prior boxes step across width, 0 for auto calculation.")
        .SetDefault(0.0);
    AddAttr<float>("step_h",
                   "Prior boxes step across height, 0 for auto calculation.")
        .SetDefault(0.0);
    AddAttr<float>("offset",
                   "(float) "
                   "Prior boxes center offset.")
        .SetDefault(0.5);
    AddComment(R"DOC(
Prior box operator
Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
Please get more information from the following papers:
https://arxiv.org/abs/1512.02325.
)DOC");
  }
};
} // namespace operators
} // namespace paddle
// Short alias used by the registration macros below.
namespace ops = paddle::operators;
// Register prior_box together with its proto maker; the macro name indicates
// that no gradient operator is registered for it.
REGISTER_OP_WITHOUT_GRADIENT(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker);
// Register the CPUPlace kernels for float and double element types.
REGISTER_OP_CPU_KERNEL(
    prior_box, ops::PriorBoxOpKernel<paddle::platform::CPUPlace, float>,
    ops::PriorBoxOpKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/prior_box_op.h"
// Short alias used by the registration macro below.
namespace ops = paddle::operators;
// Register the GPUPlace kernels for float and double element types.
REGISTER_OP_GPU_KERNEL(
    prior_box, ops::PriorBoxOpKernel<paddle::platform::GPUPlace, float>,
    ops::PriorBoxOpKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
// #include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
// Builds the list of distinct aspect ratios used to generate prior boxes.
//
// The result always starts with 1. Every input ratio that is not already in
// the result (within 1e-6) is appended, followed by its reciprocal when
// `flip` is true.
//
// @param input_aspect_ratior   ratios requested by the user; taken by const
//                              reference to avoid copying the vector on every
//                              call. Must not alias `output_aspect_ratior`.
// @param flip                  whether to also append 1 / ratio.
// @param output_aspect_ratior  cleared and then filled with the result.
inline void expand_aspect_ratios(const std::vector<float>& input_aspect_ratior,
                                 bool flip,
                                 std::vector<float>& output_aspect_ratior) {
  constexpr float eps = 1e-6;
  output_aspect_ratior.clear();
  // Upper bound on the result size: the leading 1 plus each input ratio and,
  // when flipping, its reciprocal.
  output_aspect_ratior.reserve(1 +
                               input_aspect_ratior.size() * (flip ? 2 : 1));
  output_aspect_ratior.push_back(1.);
  for (const float ar : input_aspect_ratior) {
    bool already_exist = false;
    for (const float known : output_aspect_ratior) {
      if (fabs(ar - known) < eps) {
        already_exist = true;
        break;
      }
    }
    if (already_exist) {
      continue;
    }
    output_aspect_ratior.push_back(ar);
    if (flip) {
      output_aspect_ratior.push_back(1. / ar);
    }
  }
}
// Kernel that generates SSD prior boxes.
//
// Output(Out) has two channels of `dim = layer_height * layer_width *
// num_priors * 4` values each: channel 0 holds the normalized
// (xmin, ymin, xmax, ymax) of every prior box, channel 1 holds the variances.
// The boxes are always computed on the host; when the op runs on a GPU place
// the host buffer is copied to the device tensor at the end.
template <typename Place, typename T>
class PriorBoxOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<paddle::framework::Tensor>("Input");
    auto* image = ctx.Input<paddle::framework::Tensor>("Image");
    auto* out = ctx.Output<paddle::framework::Tensor>("Out");

    auto min_sizes = ctx.Attr<std::vector<int>>("min_sizes");
    auto max_sizes = ctx.Attr<std::vector<int>>("max_sizes");
    auto input_aspect_ratio = ctx.Attr<std::vector<float>>("aspect_ratios");
    auto variances = ctx.Attr<std::vector<float>>("variances");
    auto flip = ctx.Attr<bool>("flip");
    auto clip = ctx.Attr<bool>("clip");

    std::vector<float> aspect_ratios;
    expand_aspect_ratios(input_aspect_ratio, flip, aspect_ratios);

    auto img_w = ctx.Attr<int>("img_w");
    auto img_h = ctx.Attr<int>("img_h");
    auto step_w = ctx.Attr<float>("step_w");
    auto step_h = ctx.Attr<float>("step_h");
    auto offset = ctx.Attr<float>("offset");

    // The tensors are NCHW: dims()[2] is the height and dims()[3] the width.
    // If either image attribute is 0, both sizes come from Input(Image).
    int img_width = img_w;
    int img_height = img_h;
    if (img_h == 0 || img_w == 0) {
      img_height = image->dims()[2];
      img_width = image->dims()[3];
    }

    const int layer_height = input->dims()[2];
    const int layer_width = input->dims()[3];

    // If either step attribute is 0, both steps are derived from the ratio of
    // the image size to the feature map size.
    float step_width = step_w;
    float step_height = step_h;
    if (step_w == 0 || step_h == 0) {
      step_width = static_cast<float>(img_width) / layer_width;
      step_height = static_cast<float>(img_height) / layer_height;
    }

    int num_priors = aspect_ratios.size() * min_sizes.size();
    if (max_sizes.size() > 0) {
      num_priors += max_sizes.size();
    }

    const int dim = layer_height * layer_width * num_priors * 4;

    // Allocate the output on the target place; on GPU, fill a host staging
    // tensor of the same shape instead and copy it over at the end.
    framework::Tensor output_cpu;
    T* place_data = out->mutable_data<T>(ctx.GetPlace());
    const bool on_gpu = platform::is_gpu_place(ctx.GetPlace());
    T* output_data =
        on_gpu ? output_cpu.mutable_data<T>(out->dims(), platform::CPUPlace())
               : place_data;

    int idx = 0;
    // Appends one box as normalized (xmin, ymin, xmax, ymax).
    auto emit_box = [&](float center_x, float center_y, float box_width,
                        float box_height) {
      output_data[idx++] = (center_x - box_width / 2.) / img_width;
      output_data[idx++] = (center_y - box_height / 2.) / img_height;
      output_data[idx++] = (center_x + box_width / 2.) / img_width;
      output_data[idx++] = (center_y + box_height / 2.) / img_height;
    };

    for (int h = 0; h < layer_height; ++h) {
      for (int w = 0; w < layer_width; ++w) {
        const float center_x = (w + offset) * step_width;
        const float center_y = (h + offset) * step_height;
        for (size_t s = 0; s < min_sizes.size(); ++s) {
          const int min_size = min_sizes[s];
          // first prior: aspect_ratio = 1, size = min_size
          emit_box(center_x, center_y, min_size, min_size);

          if (max_sizes.size() > 0) {
            const int max_size = max_sizes[s];
            // second prior: aspect_ratio = 1,
            // size = sqrt(min_size * max_size)
            const float side = sqrt(min_size * max_size);
            emit_box(center_x, center_y, side, side);
          }

          // rest of priors; the ratio 1 was already emitted above.
          for (size_t r = 0; r < aspect_ratios.size(); ++r) {
            const float ar = aspect_ratios[r];
            if (fabs(ar - 1.) < 1e-6) {
              continue;
            }
            emit_box(center_x, center_y, min_size * sqrt(ar),
                     min_size / sqrt(ar));
          }
        }
      }
    }

    // clip the prior's coordinate such that it is within [0, 1]
    if (clip) {
      for (int d = 0; d < dim; ++d) {
        output_data[d] = std::min<T>(std::max<T>(output_data[d], 0.), 1.);
      }
    }

    // set the variance: the second channel starts one channel stride after
    // the box coordinates.
    auto output_stride = framework::stride(out->dims());
    T* variance_data = output_data + output_stride[1];
    if (variances.size() == 1) {
      for (int i = 0; i < dim; ++i) {
        variance_data[i] = variances[0];
      }
    } else {
      // Expects four variances, one per box coordinate, repeated per box.
      for (int i = 0; i < dim; ++i) {
        variance_data[i] = variances[i % 4];
      }
    }

    if (on_gpu) {
      framework::CopyFrom(output_cpu, platform::CPUPlace(),
                          ctx.device_context(), out);
    }
  }
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
import sys
import math
from op_test import OpTest
class TestPriorBoxOp(OpTest):
    """Checks the prior_box operator against a NumPy reference.

    The reference output has shape (1, 2, dim): channel 0 holds the
    normalized (xmin, ymin, xmax, ymax) of every prior box, channel 1 holds
    the variances.
    """

    def set_data(self):
        """Builds the inputs, attributes and expected output of the op."""
        self.init_test_params()
        self.init_test_input()
        self.init_test_output()
        self.inputs = {'Input': self.input, 'Image': self.image}

        self.attrs = {
            'min_sizes': self.min_sizes,
            'max_sizes': self.max_sizes,
            'aspect_ratios': self.aspect_ratios,
            'variances': self.variances,
            'flip': self.flip,
            'clip': self.clip,
            'step_w': self.step_w,
            'step_h': self.step_h,
            'img_w': self.image_w,
            'img_h': self.image_h,
            'offset': self.offset
        }

        self.outputs = {'Out': self.output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Nothing to check here; the test is intentionally a no-op.
        return

    def setUp(self):
        self.op_type = "prior_box"
        self.set_data()

    def init_test_params(self):
        """Sets the layer/image geometry and the operator attributes."""
        self.layer_w = 4
        self.layer_h = 4

        self.image_w = 20
        self.image_h = 20

        self.step_w = float(self.image_w) / float(self.layer_w)
        self.step_h = float(self.image_h) / float(self.layer_h)

        self.input_channels = 2
        self.image_channels = 3
        self.batch_size = 10

        self.min_sizes = np.array([2, 4]).astype('int64')
        self.max_sizes = np.array([5, 10]).astype('int64')

        self.flip = True
        # Expected expansion of aspect_ratios when flip is on: the implicit
        # ratio 1, then each ratio followed by its reciprocal.
        self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
        # np.float was only an alias of the builtin float (float64) and has
        # been removed from NumPy, so the dtype is spelled explicitly.
        self.aspect_ratios = np.array([2.0, 3.0], dtype=np.float64).flatten()
        self.variances = np.array(
            [0.1, 0.1, 0.2, 0.2], dtype=np.float64).flatten()

        self.clip = True

        self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
        # One extra prior per max size as soon as any max size is given; this
        # mirrors the `len(self.max_sizes) > 0` test in init_test_output.
        if len(self.max_sizes) > 0:
            self.num_priors += len(self.max_sizes)
        self.offset = 0.5

    def init_test_input(self):
        """Creates random image and feature-map inputs."""
        self.image = np.random.random(
            (self.batch_size, self.image_channels, self.image_w,
             self.image_h)).astype('float32')

        self.input = np.random.random(
            (self.batch_size, self.input_channels, self.layer_w,
             self.layer_h)).astype('float32')

    def init_test_output(self):
        """Computes the expected prior boxes and variances."""
        dim = self.layer_w * self.layer_h * self.num_priors * 4
        output = np.zeros((1, 2, dim)).astype('float32')

        # Flat list of box coordinates in emission order.
        coords = []

        def append_box(center_x, center_y, box_width, box_height):
            # xmin, ymin, xmax, ymax normalized by the image size
            coords.extend([
                (center_x - box_width / 2.) / self.image_w,
                (center_y - box_height / 2.) / self.image_h,
                (center_x + box_width / 2.) / self.image_w,
                (center_y + box_height / 2.) / self.image_h,
            ])

        for h in range(self.layer_h):
            for w in range(self.layer_w):
                center_x = (w + self.offset) * self.step_w
                center_y = (h + self.offset) * self.step_h

                for s in range(len(self.min_sizes)):
                    min_size = self.min_sizes[s]
                    # first prior: aspect_ratio = 1, size = min_size
                    append_box(center_x, center_y, min_size, min_size)

                    if len(self.max_sizes) > 0:
                        max_size = self.max_sizes[s]
                        # second prior: aspect_ratio = 1,
                        # size = sqrt(min_size * max_size)
                        side = math.sqrt(min_size * max_size)
                        append_box(center_x, center_y, side, side)

                    # rest of priors; the ratio 1 was already emitted above
                    for ar in self.real_aspect_ratios:
                        if math.fabs(ar - 1.) < 1e-6:
                            continue
                        append_box(center_x, center_y,
                                   min_size * math.sqrt(ar),
                                   min_size / math.sqrt(ar))

        output[0, 0, :] = coords

        # clip the prior's coordinate such that it is within [0, 1]
        if self.clip:
            output[0, 0, :] = np.clip(output[0, 0, :], 0.0, 1.0)

        # set the variance.
        if len(self.variances) == 1:
            output[0, 1, :] = self.variances[0]
        else:
            # Four variances, one per box coordinate, repeated for every box.
            output[0, 1, :] = np.tile(self.variances[0:4], dim // 4)

        self.output = output.astype('float32')
# Run the test cases when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册