提交 ca2e96f2 编写于 作者: W wanghaox

update code

上级 534cf741
...@@ -98,16 +98,14 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -98,16 +98,14 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"the input image data of PriorBoxOp, The layout is NCHW."); "the input image data of PriorBoxOp, The layout is NCHW.");
AddOutput("Boxes", AddOutput("Boxes",
"(Tensor, default Tensor<float>), the output prior boxes of " "(Tensor, default Tensor<float>), the output prior boxes of "
"PriorBoxOp. The layout is [layer_height, layer_width, " "PriorBoxOp. The layout is [H, W, num_priors, 4]. "
"num_priors, 4]. layer_height is the height of input, " "H is the height of input, W is the width of input, num_priors "
"layer_width is the width of input, num_priors is the box " "is the box count of each position.");
"count of each position.");
AddOutput("Variances", AddOutput("Variances",
"(Tensor, default Tensor<float>), the expanded variances of " "(Tensor, default Tensor<float>), the expanded variances of "
"PriorBoxOp. The layout is [layer_height, layer_width, " "PriorBoxOp. The layout is [H, W, num_priors, 4]. "
"num_priors, 4]. layer_height is the height of input, " "H is the height of input, W is the width of input, num_priors "
"layer_width is the width of input, num_priors is the box " "is the box count of each position.");
"count of each position.");
AddAttr<std::vector<int>>("min_sizes", "(vector<int>) ", AddAttr<std::vector<int>>("min_sizes", "(vector<int>) ",
"List of min sizes of generated prior boxes."); "List of min sizes of generated prior boxes.");
AddAttr<std::vector<int>>("max_sizes", "(vector<int>) ", AddAttr<std::vector<int>>("max_sizes", "(vector<int>) ",
......
...@@ -77,13 +77,13 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -77,13 +77,13 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
auto img_width = image->dims()[3]; auto img_width = image->dims()[3];
auto img_height = image->dims()[2]; auto img_height = image->dims()[2];
auto layer_width = input->dims()[3]; auto feature_width = input->dims()[3];
auto layer_height = input->dims()[2]; auto feature_height = input->dims()[2];
T step_width, step_height; T step_width, step_height;
if (step_w == 0 || step_h == 0) { if (step_w == 0 || step_h == 0) {
step_width = static_cast<T>(img_width) / layer_width; step_width = static_cast<T>(img_width) / feature_width;
step_height = static_cast<T>(img_height) / layer_height; step_height = static_cast<T>(img_height) / feature_height;
} else { } else {
step_width = step_w; step_width = step_w;
step_height = step_h; step_height = step_h;
...@@ -98,8 +98,8 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -98,8 +98,8 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
vars->mutable_data<T>(ctx.GetPlace()); vars->mutable_data<T>(ctx.GetPlace());
auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes); auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes);
for (int h = 0; h < layer_height; ++h) { for (int h = 0; h < feature_height; ++h) {
for (int w = 0; w < layer_width; ++w) { for (int w = 0; w < feature_width; ++w) {
T center_x = (w + offset) * step_width; T center_x = (w + offset) * step_width;
T center_y = (h + offset) * step_height; T center_y = (h + offset) * step_height;
T box_width, box_height; T box_width, box_height;
...@@ -164,12 +164,16 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -164,12 +164,16 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
boxes->data<T>(), clip_func); boxes->data<T>(), clip_func);
} }
Eigen::Tensor<T, 2, Eigen::RowMajor> var_et(1, variances.size()); framework::Tensor var_t;
var_t.mutable_data<T>(
framework::make_ddim({1, static_cast<int>(variances.size())}),
ctx.GetPlace());
auto var_et = framework::EigenTensor<T, 2>::From(var_t);
for (size_t i = 0; i < variances.size(); ++i) { for (size_t i = 0; i < variances.size(); ++i) {
var_et(0, i) = variances[i]; var_et(0, i) = variances[i];
} }
int box_num = layer_height * layer_width * num_priors; int box_num = feature_height * feature_width * num_priors;
auto var_dim = vars->dims(); auto var_dim = vars->dims();
vars->Resize({box_num, static_cast<int>(variances.size())}); vars->Resize({box_num, static_cast<int>(variances.size())});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册