未验证 提交 36f08eef 编写于 作者: Q qingqing01 提交者: GitHub

CUDA kernel for density_prior_box_op. (#14513)

* CUDA kernel for density_prior_box_op.
* Support flatten to 2D.
上级 dfbdece5
......@@ -276,7 +276,7 @@ paddle.fluid.layers.hard_shrink ArgSpec(args=['x', 'threshold'], varargs=None, k
paddle.fluid.layers.cumsum ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.thresholded_relu ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.prior_box ArgSpec(args=['input', 'image', 'min_sizes', 'max_sizes', 'aspect_ratios', 'variance', 'flip', 'clip', 'steps', 'offset', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, [1.0], [0.1, 0.1, 0.2, 0.2], False, False, [0.0, 0.0], 0.5, None, False))
paddle.fluid.layers.density_prior_box ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, None))
paddle.fluid.layers.density_prior_box ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'flatten_to_2d', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, False, None))
paddle.fluid.layers.multi_box_head ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False))
paddle.fluid.layers.bipartite_match ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
......
......@@ -252,6 +252,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
this->attrs_[name] = std::vector<int>();
break;
}
case proto::AttrType::LONGS: {
VLOG(110) << "SetAttr: " << Type() << ", " << name
<< " from LONGS to LONGS";
this->attrs_[name] = std::vector<int64_t>();
break;
}
case proto::AttrType::FLOATS: {
VLOG(110) << "SetAttr: " << Type() << ", " << name
<< " from INTS to FLOATS";
......
......@@ -22,7 +22,7 @@ iou_similarity_op.cu)
detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc poly_util.cc gpc.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
detection_library(density_prior_box_op SRCS density_prior_box_op.cc)
detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu)
detection_library(anchor_generator_op SRCS anchor_generator_op.cc
anchor_generator_op.cu)
detection_library(target_assign_op SRCS target_assign_op.cc
......
......@@ -39,24 +39,27 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
auto fixed_sizes = ctx->Attrs().Get<std::vector<float>>("fixed_sizes");
auto fixed_ratios = ctx->Attrs().Get<std::vector<float>>("fixed_ratios");
auto densities = ctx->Attrs().Get<std::vector<int>>("densities");
bool flatten = ctx->Attrs().Get<bool>("flatten_to_2d");
PADDLE_ENFORCE_EQ(fixed_sizes.size(), densities.size(),
"The number of fixed_sizes and densities must be equal.");
size_t num_priors = 0;
if ((fixed_sizes.size() > 0) && (densities.size() > 0)) {
for (size_t i = 0; i < densities.size(); ++i) {
if (fixed_ratios.size() > 0) {
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
}
}
for (size_t i = 0; i < densities.size(); ++i) {
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
}
if (!flatten) {
std::vector<int64_t> dim_vec(4);
dim_vec[0] = input_dims[2];
dim_vec[1] = input_dims[3];
dim_vec[2] = num_priors;
dim_vec[3] = 4;
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
} else {
int64_t dim0 = input_dims[2] * input_dims[3] * num_priors;
ctx->SetOutputDim("Boxes", {dim0, 4});
ctx->SetOutputDim("Variances", {dim0, 4});
}
std::vector<int64_t> dim_vec(4);
dim_vec[0] = input_dims[2];
dim_vec[1] = input_dims[3];
dim_vec[2] = num_priors;
dim_vec[3] = 4;
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
}
protected:
......@@ -64,7 +67,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
platform::CPUPlace());
ctx.GetPlace());
}
};
......@@ -101,7 +104,10 @@ class DensityPriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
});
AddAttr<bool>("clip", "(bool) Whether to clip out-of-boundary boxes.")
.SetDefault(true);
AddAttr<bool>("flatten_to_2d",
"(bool) Whether to flatten to 2D and "
"the second dim is 4.")
.SetDefault(false);
AddAttr<float>(
"step_w",
"Density prior boxes step across width, 0.0 for auto calculation.")
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/density_prior_box_op.h"
namespace paddle {
namespace operators {
template <typename T>
static __device__ inline T Clip(T in) {
return min(max(in, 0.), 1.);
}
template <typename T>
static __global__ void GenDensityPriorBox(
const int height, const int width, const int im_height, const int im_width,
const T offset, const T step_width, const T step_height,
const int num_priors, const T* ratios_shift, bool is_clip, const T var_xmin,
const T var_ymin, const T var_xmax, const T var_ymax, T* out, T* var) {
int gidx = blockIdx.x * blockDim.x + threadIdx.x;
int gidy = blockIdx.y * blockDim.y + threadIdx.y;
int step_x = blockDim.x * gridDim.x;
int step_y = blockDim.y * gridDim.y;
const T* width_ratio = ratios_shift;
const T* height_ratio = ratios_shift + num_priors;
const T* width_shift = ratios_shift + 2 * num_priors;
const T* height_shift = ratios_shift + 3 * num_priors;
for (int j = gidy; j < height; j += step_y) {
for (int i = gidx; i < width * num_priors; i += step_x) {
int h = j;
int w = i / num_priors;
int k = i % num_priors;
T center_x = (w + offset) * step_width;
T center_y = (h + offset) * step_height;
T center_x_temp = center_x + width_shift[k];
T center_y_temp = center_y + height_shift[k];
T box_width_ratio = width_ratio[k] / 2.;
T box_height_ratio = height_ratio[k] / 2.;
T xmin = max((center_x_temp - box_width_ratio) / im_width, 0.);
T ymin = max((center_y_temp - box_height_ratio) / im_height, 0.);
T xmax = min((center_x_temp + box_width_ratio) / im_width, 1.);
T ymax = min((center_y_temp + box_height_ratio) / im_height, 1.);
int out_offset = (j * width * num_priors + i) * 4;
out[out_offset] = is_clip ? Clip<T>(xmin) : xmin;
out[out_offset + 1] = is_clip ? Clip<T>(ymin) : ymin;
out[out_offset + 2] = is_clip ? Clip<T>(xmax) : xmax;
out[out_offset + 3] = is_clip ? Clip<T>(ymax) : ymax;
var[out_offset] = var_xmin;
var[out_offset + 1] = var_ymin;
var[out_offset + 2] = var_xmax;
var[out_offset + 3] = var_ymax;
}
}
}
template <typename T>
class DensityPriorBoxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<paddle::framework::Tensor>("Input");
auto* image = ctx.Input<paddle::framework::Tensor>("Image");
auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
auto variances = ctx.Attr<std::vector<float>>("variances");
auto is_clip = ctx.Attr<bool>("clip");
auto fixed_sizes = ctx.Attr<std::vector<float>>("fixed_sizes");
auto fixed_ratios = ctx.Attr<std::vector<float>>("fixed_ratios");
auto densities = ctx.Attr<std::vector<int>>("densities");
T step_w = static_cast<T>(ctx.Attr<float>("step_w"));
T step_h = static_cast<T>(ctx.Attr<float>("step_h"));
T offset = static_cast<T>(ctx.Attr<float>("offset"));
auto img_width = image->dims()[3];
auto img_height = image->dims()[2];
auto feature_width = input->dims()[3];
auto feature_height = input->dims()[2];
T step_width, step_height;
if (step_w == 0 || step_h == 0) {
step_width = static_cast<T>(img_width) / feature_width;
step_height = static_cast<T>(img_height) / feature_height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = 0;
for (size_t i = 0; i < densities.size(); ++i) {
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
}
int step_average = static_cast<int>((step_width + step_height) * 0.5);
framework::Tensor h_temp;
T* tdata = h_temp.mutable_data<T>({num_priors * 4}, platform::CPUPlace());
int idx = 0;
for (size_t s = 0; s < fixed_sizes.size(); ++s) {
auto fixed_size = fixed_sizes[s];
int density = densities[s];
for (size_t r = 0; r < fixed_ratios.size(); ++r) {
float ar = fixed_ratios[r];
int shift = step_average / density;
float box_width_ratio = fixed_size * sqrt(ar);
float box_height_ratio = fixed_size / sqrt(ar);
for (int di = 0; di < density; ++di) {
for (int dj = 0; dj < density; ++dj) {
float center_x_temp = shift / 2. + dj * shift - step_average / 2.;
float center_y_temp = shift / 2. + di * shift - step_average / 2.;
tdata[idx] = box_width_ratio;
tdata[num_priors + idx] = box_height_ratio;
tdata[2 * num_priors + idx] = center_x_temp;
tdata[3 * num_priors + idx] = center_y_temp;
idx++;
}
}
}
}
boxes->mutable_data<T>(ctx.GetPlace());
vars->mutable_data<T>(ctx.GetPlace());
framework::Tensor d_temp;
framework::TensorCopySync(h_temp, ctx.GetPlace(), &d_temp);
// At least use 32 threads, at most 512 threads.
// blockx is multiple of 32.
int blockx = std::min(((feature_width * num_priors + 31) >> 5) << 5, 512L);
int gridx = (feature_width * num_priors + blockx - 1) / blockx;
dim3 threads(blockx, 1);
dim3 grids(gridx, feature_height);
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
GenDensityPriorBox<T><<<grids, threads, 0, stream>>>(
feature_height, feature_width, img_height, img_width, offset,
step_width, step_height, num_priors, d_temp.data<T>(), is_clip,
variances[0], variances[1], variances[2], variances[3],
boxes->data<T>(), vars->data<T>());
}
}; // namespace operators
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(density_prior_box,
ops::DensityPriorBoxOpCUDAKernel<float>,
ops::DensityPriorBoxOpCUDAKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -52,18 +52,16 @@ class DensityPriorBoxOpKernel : public framework::OpKernel<T> {
step_height = step_h;
}
int num_priors = 0;
if (fixed_sizes.size() > 0 && densities.size() > 0) {
for (size_t i = 0; i < densities.size(); ++i) {
if (fixed_ratios.size() > 0) {
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
}
}
for (size_t i = 0; i < densities.size(); ++i) {
num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
}
boxes->mutable_data<T>(ctx.GetPlace());
vars->mutable_data<T>(ctx.GetPlace());
auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes).setConstant(0.0);
auto box_dim = vars->dims();
boxes->Resize({feature_height, feature_width, num_priors, 4});
auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes).setConstant(0.0);
int step_average = static_cast<int>((step_width + step_height) * 0.5);
for (int h = 0; h < feature_height; ++h) {
......@@ -76,36 +74,34 @@ class DensityPriorBoxOpKernel : public framework::OpKernel<T> {
auto fixed_size = fixed_sizes[s];
int density = densities[s];
// Generate density prior boxes with fixed ratios.
if (fixed_ratios.size() > 0) {
for (size_t r = 0; r < fixed_ratios.size(); ++r) {
float ar = fixed_ratios[r];
int shift = step_average / density;
float box_width_ratio = fixed_size * sqrt(ar);
float box_height_ratio = fixed_size / sqrt(ar);
for (int di = 0; di < density; ++di) {
for (int dj = 0; dj < density; ++dj) {
float center_x_temp =
center_x - step_average / 2. + shift / 2. + dj * shift;
float center_y_temp =
center_y - step_average / 2. + shift / 2. + di * shift;
e_boxes(h, w, idx, 0) =
(center_x_temp - box_width_ratio / 2.) / img_width >= 0
? (center_x_temp - box_width_ratio / 2.) / img_width
: 0;
e_boxes(h, w, idx, 1) =
(center_y_temp - box_height_ratio / 2.) / img_height >= 0
? (center_y_temp - box_height_ratio / 2.) / img_height
: 0;
e_boxes(h, w, idx, 2) =
(center_x_temp + box_width_ratio / 2.) / img_width <= 1
? (center_x_temp + box_width_ratio / 2.) / img_width
: 1;
e_boxes(h, w, idx, 3) =
(center_y_temp + box_height_ratio / 2.) / img_height <= 1
? (center_y_temp + box_height_ratio / 2.) / img_height
: 1;
idx++;
}
for (size_t r = 0; r < fixed_ratios.size(); ++r) {
float ar = fixed_ratios[r];
int shift = step_average / density;
float box_width_ratio = fixed_size * sqrt(ar);
float box_height_ratio = fixed_size / sqrt(ar);
for (int di = 0; di < density; ++di) {
for (int dj = 0; dj < density; ++dj) {
float center_x_temp =
center_x - step_average / 2. + shift / 2. + dj * shift;
float center_y_temp =
center_y - step_average / 2. + shift / 2. + di * shift;
e_boxes(h, w, idx, 0) =
(center_x_temp - box_width_ratio / 2.) / img_width >= 0
? (center_x_temp - box_width_ratio / 2.) / img_width
: 0;
e_boxes(h, w, idx, 1) =
(center_y_temp - box_height_ratio / 2.) / img_height >= 0
? (center_y_temp - box_height_ratio / 2.) / img_height
: 0;
e_boxes(h, w, idx, 2) =
(center_x_temp + box_width_ratio / 2.) / img_width <= 1
? (center_x_temp + box_width_ratio / 2.) / img_width
: 1;
e_boxes(h, w, idx, 3) =
(center_y_temp + box_height_ratio / 2.) / img_height <= 1
? (center_y_temp + box_height_ratio / 2.) / img_height
: 1;
idx++;
}
}
}
......@@ -139,6 +135,7 @@ class DensityPriorBoxOpKernel : public framework::OpKernel<T> {
e_vars = var_et.broadcast(Eigen::DSizes<int, 2>(box_num, 1));
vars->Resize(var_dim);
boxes->Resize(box_dim);
}
}; // namespace operators
......
......@@ -1029,6 +1029,7 @@ def density_prior_box(input,
clip=False,
steps=[0.0, 0.0],
offset=0.5,
flatten_to_2d=False,
name=None):
"""
**Density Prior Box Operator**
......@@ -1065,22 +1066,24 @@ def density_prior_box(input,
height/weight of the input will be automatically calculated.
Default: [0., 0.]
offset(float): Prior boxes center offset. Default: 0.5
flatten_to_2d(bool): Whether to flatten output prior boxes and variance
to 2D shape, the second dim is 4. Default: False.
name(str): Name of the density prior box op. Default: None.
Returns:
tuple: A tuple with two Variable (boxes, variances)
boxes: the output density prior boxes of PriorBox.
The layout is [H, W, num_priors, 4].
H is the height of input, W is the width of input,
num_priors is the total
box count of each position of input.
The layout is [H, W, num_priors, 4] when flatten_to_2d is False.
The layout is [H * W * num_priors, 4] when flatten_to_2d is True.
H is the height of input, W is the width of input,
num_priors is the total box count of each position of input.
variances: the expanded variances of PriorBox.
The layout is [H, W, num_priors, 4].
H is the height of input, W is the width of input
num_priors is the total
box count of each position of input
The layout is [H, W, num_priors, 4] when flatten_to_2d is False.
The layout is [H * W * num_priors, 4] when flatten_to_2d is True.
H is the height of input, W is the width of input
num_priors is the total box count of each position of input.
Examples:
......@@ -1089,14 +1092,11 @@ def density_prior_box(input,
box, var = fluid.layers.density_prior_box(
input=conv1,
image=images,
min_sizes=[100.],
max_sizes=[200.],
aspect_ratios=[1.0, 1.0 / 2.0, 2.0],
densities=[3, 4],
fixed_sizes=[50., 60.],
fixed_ratios=[1.0, 3.0, 1.0 / 3.0],
flip=True,
clip=True)
densities=[4, 2, 1],
fixed_sizes=[32.0, 64.0, 128.0],
fixed_ratios=[1.],
clip=True,
flatten_to_2d=True)
"""
helper = LayerHelper("density_prior_box", **locals())
dtype = helper.input_dtype()
......@@ -1127,14 +1127,11 @@ def density_prior_box(input,
'step_w': steps[0],
'step_h': steps[1],
'offset': offset,
'densities': densities,
'fixed_sizes': fixed_sizes,
'fixed_ratios': fixed_ratios,
'flatten_to_2d': flatten_to_2d,
}
if densities is not None and len(densities) > 0:
attrs['densities'] = densities
if fixed_sizes is not None and len(fixed_sizes) > 0:
attrs['fixed_sizes'] = fixed_sizes
if fixed_ratios is not None and len(fixed_ratios) > 0:
attrs['fixed_ratios'] = fixed_ratios
box = helper.create_variable_for_type_inference(dtype)
var = helper.create_variable_for_type_inference(dtype)
helper.append_op(
......
......@@ -112,38 +112,42 @@ class TestDetection(unittest.TestCase):
class TestPriorBox(unittest.TestCase):
def test_prior_box(self):
data_shape = [3, 224, 224]
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(images, 3, 3, 2)
box, var = layers.prior_box(
input=conv1,
image=images,
min_sizes=[100.0],
aspect_ratios=[1.],
flip=True,
clip=True)
assert len(box.shape) == 4
assert box.shape == var.shape
assert box.shape[3] == 4
program = Program()
with program_guard(program):
data_shape = [3, 224, 224]
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(images, 3, 3, 2)
box, var = layers.prior_box(
input=conv1,
image=images,
min_sizes=[100.0],
aspect_ratios=[1.],
flip=True,
clip=True)
assert len(box.shape) == 4
assert box.shape == var.shape
assert box.shape[3] == 4
class TestDensityPriorBox(unittest.TestCase):
def test_density_prior_box(self):
data_shape = [3, 224, 224]
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(images, 3, 3, 2)
box, var = layers.density_prior_box(
input=conv1,
image=images,
densities=[3, 4],
fixed_sizes=[50., 60.],
fixed_ratios=[1.0],
clip=True)
assert len(box.shape) == 4
assert box.shape == var.shape
assert box.shape[3] == 4
program = Program()
with program_guard(program):
data_shape = [3, 224, 224]
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(images, 3, 3, 2)
box, var = layers.density_prior_box(
input=conv1,
image=images,
densities=[3, 4],
fixed_sizes=[50., 60.],
fixed_ratios=[1.0],
clip=True)
assert len(box.shape) == 4
assert box.shape == var.shape
assert box.shape[-1] == 4
class TestAnchorGenerator(unittest.TestCase):
......
......@@ -36,7 +36,8 @@ class TestDensityPriorBoxOp(OpTest):
'offset': self.offset,
'densities': self.densities,
'fixed_sizes': self.fixed_sizes,
'fixed_ratios': self.fixed_ratios
'fixed_ratios': self.fixed_ratios,
'flatten_to_2d': self.flatten_to_2d
}
self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
......@@ -48,16 +49,17 @@ class TestDensityPriorBoxOp(OpTest):
self.set_data()
def set_density(self):
self.densities = []
self.fixed_sizes = []
self.fixed_ratios = []
self.densities = [4, 2, 1]
self.fixed_sizes = [32.0, 64.0, 128.0]
self.fixed_ratios = [1.0]
self.layer_w = 17
self.layer_h = 17
self.image_w = 533
self.image_h = 533
self.flatten_to_2d = False
def init_test_params(self):
self.layer_w = 32
self.layer_h = 32
self.image_w = 40
self.image_h = 40
self.set_density()
self.step_w = float(self.image_w) / float(self.layer_w)
self.step_h = float(self.image_h) / float(self.layer_h)
......@@ -69,8 +71,6 @@ class TestDensityPriorBoxOp(OpTest):
self.variances = [0.1, 0.1, 0.2, 0.2]
self.variances = np.array(self.variances, dtype=np.float).flatten()
self.set_density()
self.clip = True
self.num_priors = 0
if len(self.fixed_sizes) > 0 and len(self.densities) > 0:
......@@ -129,6 +129,9 @@ class TestDensityPriorBoxOp(OpTest):
(self.layer_h, self.layer_w, self.num_priors, 1))
self.out_boxes = out_boxes.astype('float32')
self.out_var = out_var.astype('float32')
if self.flatten_to_2d:
self.out_boxes = self.out_boxes.reshape((-1, 4))
self.out_var = self.out_var.reshape((-1, 4))
class TestDensityPriorBox(TestDensityPriorBoxOp):
......@@ -136,6 +139,11 @@ class TestDensityPriorBox(TestDensityPriorBoxOp):
self.densities = [3, 4]
self.fixed_sizes = [1.0, 2.0]
self.fixed_ratios = [1.0]
self.layer_w = 32
self.layer_h = 32
self.image_w = 40
self.image_h = 40
self.flatten_to_2d = True
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册