Commit 99a6c5d4 authored by wanghaox

change output shape to [2, layer_height, layer_width, num_priors, 4]

Parent 7297e6ff
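For reference, the output changes from the old flat shape (1, 2, layer_height * layer_width * num_priors * 4) to a five-dimensional tensor. Below is a minimal NumPy sketch of how the new layout is indexed; the sizes are made up for illustration and `out` is only a stand-in for the operator's output, not PaddlePaddle API.

import numpy as np

# Made-up feature-map sizes, for illustration only.
layer_h, layer_w, num_priors = 4, 4, 6

# New layout: index 0 on the first axis holds the prior-box coordinates,
# index 1 holds the per-coordinate variances.
out = np.zeros((2, layer_h, layer_w, num_priors, 4), dtype=np.float32)

# Prior 0 at feature-map cell (h=2, w=3): [xmin, ymin, xmax, ymax],
# normalized by the image width/height.
box = out[0, 2, 3, 0, :]
# The matching variances for that prior.
var = out[1, 2, 3, 0, :]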
@@ -93,17 +93,12 @@ class PriorBoxOp : public framework::OperatorWithKernel {
     const int layer_height = input_dims[2];
     const int layer_width = input_dims[3];
-    std::vector<int64_t> dim_vec(3);
-    // Since all images in a batch has same height and width, we only need to
-    // generate one set of priors which can be shared across all images.
-    dim_vec[0] = 1;
-    // 2 channels. First channel stores the mean of each prior coordinate.
-    // Second channel stores the variance of each prior coordinate.
-    dim_vec[1] = 2;
-    dim_vec[2] = layer_width * layer_height * num_priors * 4;
-    PADDLE_ENFORCE_GT(dim_vec[2], 0,
-                      "output_dim[2] must larger than 0."
-                      "check your data dims");
+    std::vector<int64_t> dim_vec(5);
+    dim_vec[0] = 2;
+    dim_vec[1] = layer_height;
+    dim_vec[2] = layer_width;
+    dim_vec[3] = num_priors;
+    dim_vec[4] = 4;
     auto output_dim = framework::make_ddim(dim_vec);
     ctx->SetOutputDim("Out", output_dim);
   }
@@ -130,7 +125,8 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
              "the input image data of PriorBoxOp, The format is NCHW.");
     AddOutput("Out",
               "(Tensor, default Tensor<float>), the output prior boxes of "
-              "PriorBoxOp.");
+              "PriorBoxOp. The format is [2, layer_height, layer_width, "
+              "num_priors, 4]");
     AddAttr<std::vector<int>>("min_sizes", "(vector<int>) ",
                               "List of min sizes of generated prior boxes.");
     AddAttr<std::vector<int>>("max_sizes", "(vector<int>) ",
......
@@ -15,7 +15,6 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
-// #include "paddle/operators/strided_memcpy.h"
 
 namespace paddle {
 namespace operators {
@@ -94,50 +93,52 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
       num_priors += max_sizes.size();
     }
 
-    int dim = layer_height * layer_width * num_priors * 4;
     T* output_data = nullptr;
     framework::Tensor output_cpu;
+    framework::Tensor* output_tensor;
     out->mutable_data<T>(ctx.GetPlace());
     if (platform::is_gpu_place(ctx.GetPlace())) {
-      output_data =
-          output_cpu.mutable_data<T>(out->dims(), platform::CPUPlace());
+      output_cpu.mutable_data<T>(out->dims(), platform::CPUPlace());
+      output_tensor = &output_cpu;
     } else {
-      output_data = out->mutable_data<T>(ctx.GetPlace());
+      output_tensor = out;
     }
 
-    int idx = 0;
+    auto e_out = framework::EigenTensor<T, 5>::From(*output_tensor);
     for (int h = 0; h < layer_height; ++h) {
       for (int w = 0; w < layer_width; ++w) {
         float center_x = (w + offset) * step_width;
         float center_y = (h + offset) * step_height;
         float box_width, box_height;
+        int idx = 0;
         for (size_t s = 0; s < min_sizes.size(); ++s) {
           int min_size = min_sizes[s];
           // first prior: aspect_ratio = 1, size = min_size
           box_width = box_height = min_size;
           // xmin
-          output_data[idx++] = (center_x - box_width / 2.) / img_width;
+          e_out(0, h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
           // ymin
-          output_data[idx++] = (center_y - box_height / 2.) / img_height;
+          e_out(0, h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
           // xmax
-          output_data[idx++] = (center_x + box_width / 2.) / img_width;
+          e_out(0, h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
           // ymax
-          output_data[idx++] = (center_y + box_height / 2.) / img_height;
+          e_out(0, h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
+          idx++;
           if (max_sizes.size() > 0) {
             int max_size = max_sizes[s];
             // second prior: aspect_ratio = 1,
             // size = sqrt(min_size * max_size)
             box_width = box_height = sqrt(min_size * max_size);
             // xmin
-            output_data[idx++] = (center_x - box_width / 2.) / img_width;
+            e_out(0, h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
             // ymin
-            output_data[idx++] = (center_y - box_height / 2.) / img_height;
+            e_out(0, h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
             // xmax
-            output_data[idx++] = (center_x + box_width / 2.) / img_width;
+            e_out(0, h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
             // ymax
-            output_data[idx++] = (center_y + box_height / 2.) / img_height;
+            e_out(0, h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
+            idx++;
           }
           // rest of priors
@@ -149,13 +150,14 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
             box_width = min_size * sqrt(ar);
             box_height = min_size / sqrt(ar);
             // xmin
-            output_data[idx++] = (center_x - box_width / 2.) / img_width;
+            e_out(0, h, w, idx, 0) = (center_x - box_width / 2.) / img_width;
             // ymin
-            output_data[idx++] = (center_y - box_height / 2.) / img_height;
+            e_out(0, h, w, idx, 1) = (center_y - box_height / 2.) / img_height;
             // xmax
-            output_data[idx++] = (center_x + box_width / 2.) / img_width;
+            e_out(0, h, w, idx, 2) = (center_x + box_width / 2.) / img_width;
             // ymax
-            output_data[idx++] = (center_y + box_height / 2.) / img_height;
+            e_out(0, h, w, idx, 3) = (center_y + box_height / 2.) / img_height;
+            idx++;
           }
         }
       }
@@ -163,26 +165,31 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
     // clip the prior's coordidate such that it is within [0, 1]
     if (clip) {
-      for (int d = 0; d < dim; ++d) {
-        output_data[d] = std::min<T>(std::max<T>(output_data[d], 0.), 1.);
-      }
+      for (int h = 0; h < layer_height; ++h) {
+        for (int w = 0; w < layer_width; ++w) {
+          for (int i = 0; i < num_priors; ++i) {
+            for (int j = 0; j < 4; ++j) {
+              e_out(0, h, w, i, j) =
+                  std::min<T>(std::max<T>(e_out(0, h, w, i, j), 0.), 1.);
+            }
+          }
+        }
+      }
     }
     // set the variance.
     auto output_stride = framework::stride(out->dims());
-    output_data += output_stride[1];
     if (variances.size() == 1) {
-      for (int i = 0; i < dim; ++i) {
-        output_data[i] = variances[0];
-      }
-    } else {
-      int count = 0;
-      for (int h = 0; h < layer_height; ++h) {
-        for (int w = 0; w < layer_width; ++w) {
-          for (int i = 0; i < num_priors; ++i) {
-            for (int j = 0; j < 4; ++j) {
-              output_data[count] = variances[j];
-              ++count;
+      variances.resize(4);
+      variances[1] = variances[0];
+      variances[2] = variances[0];
+      variances[3] = variances[0];
+    }
+    for (int h = 0; h < layer_height; ++h) {
+      for (int w = 0; w < layer_width; ++w) {
+        for (int i = 0; i < num_priors; ++i) {
+          for (int j = 0; j < 4; ++j) {
+            e_out(1, h, w, i, j) = variances[j];
           }
         }
       }
......
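As an aside, the new variance handling amounts to repeating a single variance value across all four coordinates and then filling channel 1 for every cell and prior. A NumPy sketch under assumed sizes (the names below are placeholders, not the operator code):

import numpy as np

layer_h, layer_w, num_priors = 4, 4, 6      # placeholder sizes
out = np.zeros((2, layer_h, layer_w, num_priors, 4), dtype=np.float32)

variances = [0.1]                           # the attribute may hold 1 or 4 values
if len(variances) == 1:
    variances = variances * 4               # mirrors variances.resize(4) above
# Broadcasting writes the 4 variances into channel 1 for every (h, w, prior).
out[1] = np.asarray(variances, dtype=np.float32)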
@@ -81,8 +81,7 @@ class TestPriorBoxOp(OpTest):
                                     self.layer_h)).astype('float32')
 
     def init_test_output(self):
-        dim = self.layer_w * self.layer_h * self.num_priors * 4
-        out_dim = (1, 2, dim)
+        out_dim = (2, self.layer_h, self.layer_w, self.num_priors, 4)
         output = np.zeros(out_dim).astype('float32')
 
         idx = 0
@@ -90,24 +89,22 @@ class TestPriorBoxOp(OpTest):
             for w in range(self.layer_w):
                 center_x = (w + self.offset) * self.step_w
                 center_y = (h + self.offset) * self.step_h
+                idx = 0
                 for s in range(len(self.min_sizes)):
                     min_size = self.min_sizes[s]
                     # first prior: aspect_ratio = 1, size = min_size
                     box_width = box_height = min_size
                     # xmin
-                    output[0, 0, idx] = (
+                    output[0, h, w, idx, 0] = (
                         center_x - box_width / 2.) / self.image_w
-                    idx += 1
                     # ymin
-                    output[0, 0, idx] = (
+                    output[0, h, w, idx, 1] = (
                         center_y - box_height / 2.) / self.image_h
-                    idx += 1
                     # xmax
-                    output[0, 0, idx] = (
+                    output[0, h, w, idx, 2] = (
                         center_x + box_width / 2.) / self.image_w
-                    idx += 1
                     # ymax
-                    output[0, 0, idx] = (
+                    output[0, h, w, idx, 3] = (
                         center_y + box_height / 2.) / self.image_h
                     idx += 1
@@ -117,19 +114,16 @@ class TestPriorBoxOp(OpTest):
                         # size = sqrt(min_size * max_size)
                         box_width = box_height = math.sqrt(min_size * max_size)
                         # xmin
-                        output[0, 0, idx] = (
+                        output[0, h, w, idx, 0] = (
                             center_x - box_width / 2.) / self.image_w
-                        idx += 1
                         # ymin
-                        output[0, 0, idx] = (
+                        output[0, h, w, idx, 1] = (
                             center_y - box_height / 2.) / self.image_h
-                        idx += 1
                         # xmax
-                        output[0, 0, idx] = (
+                        output[0, h, w, idx, 2] = (
                             center_x + box_width / 2.) / self.image_w
-                        idx += 1
                         # ymax
-                        output[0, 0, idx] = (
+                        output[0, h, w, idx, 3] = (
                             center_y + box_height / 2.) / self.image_h
                         idx += 1
@@ -141,37 +135,35 @@ class TestPriorBoxOp(OpTest):
                        box_width = min_size * math.sqrt(ar)
                        box_height = min_size / math.sqrt(ar)
                        # xmin
-                       output[0, 0, idx] = (
+                       output[0, h, w, idx, 0] = (
                            center_x - box_width / 2.) / self.image_w
-                       idx += 1
                        # ymin
-                       output[0, 0, idx] = (
+                       output[0, h, w, idx, 1] = (
                            center_y - box_height / 2.) / self.image_h
-                       idx += 1
                        # xmax
-                       output[0, 0, idx] = (
+                       output[0, h, w, idx, 2] = (
                            center_x + box_width / 2.) / self.image_w
-                       idx += 1
                        # ymax
-                       output[0, 0, idx] = (
+                       output[0, h, w, idx, 3] = (
                            center_y + box_height / 2.) / self.image_h
                        idx += 1
        # clip the prior's coordidate such that it is within[0, 1]
        if self.clip:
-           for d in range(dim):
-               output[0, 0, d] = min(max(output[0, 0, d], 0), 1)
-       # set the variance.
-       if len(self.variances) == 1:
-           for i in range(dim):
-               output[0, 1, i] = self.variances[0]
-       else:
-           count = 0
            for h in range(self.layer_h):
                for w in range(self.layer_w):
                    for i in range(self.num_priors):
                        for j in range(4):
-                           output[0, 1, count] = self.variances[j]
-                           count += 1
+                           output[0, h, w, i, j] = min(
+                               max(output[0, h, w, i, j], 0), 1)
+       # set the variance.
+       for h in range(self.layer_h):
+           for w in range(self.layer_w):
+               for i in range(self.num_priors):
+                   for j in range(4):
+                       if len(self.variances) == 1:
+                           output[1, h, w, i, j] = self.variances[0]
+                       else:
+                           output[1, h, w, i, j] = self.variances[j]
 
        self.output = output.astype('float32')
......
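The per-element clip loops in the test could also be written with NumPy broadcasting; here is a small sketch with a hypothetical helper name that clamps only the box-coordinate channel of the new layout.

import numpy as np

def clip_boxes(output):
    """Clamp channel 0 (box coordinates) of a (2, H, W, num_priors, 4)
    array to [0, 1], leaving the variance channel untouched."""
    output[0] = np.clip(output[0], 0.0, 1.0)
    return output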