diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 541c4db1fa0914b657b3553ea20114f4bbe74464..50114bf3df0ac5ef861f1d2280729263cd6cbf92 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -276,7 +276,7 @@ paddle.fluid.layers.hard_shrink ArgSpec(args=['x', 'threshold'], varargs=None, k paddle.fluid.layers.cumsum ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.thresholded_relu ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prior_box ArgSpec(args=['input', 'image', 'min_sizes', 'max_sizes', 'aspect_ratios', 'variance', 'flip', 'clip', 'steps', 'offset', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, [1.0], [0.1, 0.1, 0.2, 0.2], False, False, [0.0, 0.0], 0.5, None, False)) -paddle.fluid.layers.density_prior_box ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, None)) +paddle.fluid.layers.density_prior_box ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'flatten_to_2d', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, False, None)) paddle.fluid.layers.multi_box_head ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False)) paddle.fluid.layers.bipartite_match ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index fbaa169df6324761ef9136aa173dce4e2182ed38..362cda3f2329bef1abaa93b4529e506d41f07606 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -252,6 +252,12 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) { this->attrs_[name] = std::vector(); break; } + case proto::AttrType::LONGS: { + VLOG(110) << "SetAttr: " << Type() << ", " << name + << " from LONGS to LONGS"; + this->attrs_[name] = std::vector(); + break; + } case proto::AttrType::FLOATS: { VLOG(110) << "SetAttr: " << Type() << ", " << name << " from INTS to FLOATS"; diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 58f6f48467310ffb2429ad440f58fcd823edf079..6c85f1577e0c49d00f4ccf7fa7be0974eb62bdf3 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -22,7 +22,7 @@ iou_similarity_op.cu) detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc) detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc poly_util.cc gpc.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) -detection_library(density_prior_box_op SRCS density_prior_box_op.cc) +detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu) detection_library(anchor_generator_op SRCS anchor_generator_op.cc anchor_generator_op.cu) detection_library(target_assign_op SRCS target_assign_op.cc diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cc b/paddle/fluid/operators/detection/density_prior_box_op.cc index 99df15c3226b4305a28a3912398d6d1c766daa73..1012ba3652ddfb06af9292e8061684481f1dbef3 100644 --- a/paddle/fluid/operators/detection/density_prior_box_op.cc +++ b/paddle/fluid/operators/detection/density_prior_box_op.cc @@ -39,24 +39,27 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { auto fixed_sizes = ctx->Attrs().Get>("fixed_sizes"); auto fixed_ratios = ctx->Attrs().Get>("fixed_ratios"); auto densities = ctx->Attrs().Get>("densities"); + bool flatten = ctx->Attrs().Get("flatten_to_2d"); PADDLE_ENFORCE_EQ(fixed_sizes.size(), densities.size(), "The number of fixed_sizes and densities must be equal."); size_t num_priors = 0; - if ((fixed_sizes.size() > 0) && (densities.size() > 0)) { - for (size_t i = 0; i < densities.size(); ++i) { - if (fixed_ratios.size() > 0) { - num_priors += (fixed_ratios.size()) * (pow(densities[i], 2)); - } - } + for (size_t i = 0; i < densities.size(); ++i) { + num_priors += (fixed_ratios.size()) * (pow(densities[i], 2)); + } + if (!flatten) { + std::vector dim_vec(4); + dim_vec[0] = input_dims[2]; + dim_vec[1] = input_dims[3]; + dim_vec[2] = num_priors; + dim_vec[3] = 4; + ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec)); + ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); + } else { + int64_t dim0 = input_dims[2] * input_dims[3] * num_priors; + ctx->SetOutputDim("Boxes", {dim0, 4}); + ctx->SetOutputDim("Variances", {dim0, 4}); } - std::vector dim_vec(4); - dim_vec[0] = input_dims[2]; - dim_vec[1] = input_dims[3]; - dim_vec[2] = num_priors; - dim_vec[3] = 4; - ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec)); - ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); } protected: @@ -64,7 +67,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), - platform::CPUPlace()); + ctx.GetPlace()); } }; @@ -101,7 +104,10 @@ class DensityPriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { }); AddAttr("clip", "(bool) Whether to clip out-of-boundary boxes.") .SetDefault(true); - + AddAttr("flatten_to_2d", + "(bool) Whether to flatten to 2D and " + "the second dim is 4.") + .SetDefault(false); AddAttr( "step_w", "Density prior boxes step across width, 0.0 for auto calculation.") diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cu b/paddle/fluid/operators/detection/density_prior_box_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..3b7c781795f02b9d9c9f2ead51034193ceb2a745 --- /dev/null +++ b/paddle/fluid/operators/detection/density_prior_box_op.cu @@ -0,0 +1,170 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/density_prior_box_op.h" + +namespace paddle { +namespace operators { + +template +static __device__ inline T Clip(T in) { + return min(max(in, 0.), 1.); +} + +template +static __global__ void GenDensityPriorBox( + const int height, const int width, const int im_height, const int im_width, + const T offset, const T step_width, const T step_height, + const int num_priors, const T* ratios_shift, bool is_clip, const T var_xmin, + const T var_ymin, const T var_xmax, const T var_ymax, T* out, T* var) { + int gidx = blockIdx.x * blockDim.x + threadIdx.x; + int gidy = blockIdx.y * blockDim.y + threadIdx.y; + int step_x = blockDim.x * gridDim.x; + int step_y = blockDim.y * gridDim.y; + + const T* width_ratio = ratios_shift; + const T* height_ratio = ratios_shift + num_priors; + const T* width_shift = ratios_shift + 2 * num_priors; + const T* height_shift = ratios_shift + 3 * num_priors; + + for (int j = gidy; j < height; j += step_y) { + for (int i = gidx; i < width * num_priors; i += step_x) { + int h = j; + int w = i / num_priors; + int k = i % num_priors; + + T center_x = (w + offset) * step_width; + T center_y = (h + offset) * step_height; + + T center_x_temp = center_x + width_shift[k]; + T center_y_temp = center_y + height_shift[k]; + + T box_width_ratio = width_ratio[k] / 2.; + T box_height_ratio = height_ratio[k] / 2.; + + T xmin = max((center_x_temp - box_width_ratio) / im_width, 0.); + T ymin = max((center_y_temp - box_height_ratio) / im_height, 0.); + T xmax = min((center_x_temp + box_width_ratio) / im_width, 1.); + T ymax = min((center_y_temp + box_height_ratio) / im_height, 1.); + + int out_offset = (j * width * num_priors + i) * 4; + out[out_offset] = is_clip ? Clip(xmin) : xmin; + out[out_offset + 1] = is_clip ? Clip(ymin) : ymin; + out[out_offset + 2] = is_clip ? Clip(xmax) : xmax; + out[out_offset + 3] = is_clip ? Clip(ymax) : ymax; + + var[out_offset] = var_xmin; + var[out_offset + 1] = var_ymin; + var[out_offset + 2] = var_xmax; + var[out_offset + 3] = var_ymax; + } + } +} + +template +class DensityPriorBoxOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* image = ctx.Input("Image"); + auto* boxes = ctx.Output("Boxes"); + auto* vars = ctx.Output("Variances"); + + auto variances = ctx.Attr>("variances"); + auto is_clip = ctx.Attr("clip"); + + auto fixed_sizes = ctx.Attr>("fixed_sizes"); + auto fixed_ratios = ctx.Attr>("fixed_ratios"); + auto densities = ctx.Attr>("densities"); + + T step_w = static_cast(ctx.Attr("step_w")); + T step_h = static_cast(ctx.Attr("step_h")); + T offset = static_cast(ctx.Attr("offset")); + + auto img_width = image->dims()[3]; + auto img_height = image->dims()[2]; + + auto feature_width = input->dims()[3]; + auto feature_height = input->dims()[2]; + + T step_width, step_height; + if (step_w == 0 || step_h == 0) { + step_width = static_cast(img_width) / feature_width; + step_height = static_cast(img_height) / feature_height; + } else { + step_width = step_w; + step_height = step_h; + } + + int num_priors = 0; + for (size_t i = 0; i < densities.size(); ++i) { + num_priors += (fixed_ratios.size()) * (pow(densities[i], 2)); + } + int step_average = static_cast((step_width + step_height) * 0.5); + + framework::Tensor h_temp; + T* tdata = h_temp.mutable_data({num_priors * 4}, platform::CPUPlace()); + int idx = 0; + for (size_t s = 0; s < fixed_sizes.size(); ++s) { + auto fixed_size = fixed_sizes[s]; + int density = densities[s]; + for (size_t r = 0; r < fixed_ratios.size(); ++r) { + float ar = fixed_ratios[r]; + int shift = step_average / density; + float box_width_ratio = fixed_size * sqrt(ar); + float box_height_ratio = fixed_size / sqrt(ar); + for (int di = 0; di < density; ++di) { + for (int dj = 0; dj < density; ++dj) { + float center_x_temp = shift / 2. + dj * shift - step_average / 2.; + float center_y_temp = shift / 2. + di * shift - step_average / 2.; + tdata[idx] = box_width_ratio; + tdata[num_priors + idx] = box_height_ratio; + tdata[2 * num_priors + idx] = center_x_temp; + tdata[3 * num_priors + idx] = center_y_temp; + idx++; + } + } + } + } + + boxes->mutable_data(ctx.GetPlace()); + vars->mutable_data(ctx.GetPlace()); + + framework::Tensor d_temp; + framework::TensorCopySync(h_temp, ctx.GetPlace(), &d_temp); + + // At least use 32 threads, at most 512 threads. + // blockx is multiple of 32. + int blockx = std::min(((feature_width * num_priors + 31) >> 5) << 5, 512L); + int gridx = (feature_width * num_priors + blockx - 1) / blockx; + dim3 threads(blockx, 1); + dim3 grids(gridx, feature_height); + + auto stream = + ctx.template device_context().stream(); + GenDensityPriorBox<<>>( + feature_height, feature_width, img_height, img_width, offset, + step_width, step_height, num_priors, d_temp.data(), is_clip, + variances[0], variances[1], variances[2], variances[3], + boxes->data(), vars->data()); + } +}; // namespace operators + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(density_prior_box, + ops::DensityPriorBoxOpCUDAKernel, + ops::DensityPriorBoxOpCUDAKernel); diff --git a/paddle/fluid/operators/detection/density_prior_box_op.h b/paddle/fluid/operators/detection/density_prior_box_op.h index 9a52077e9cf90b278549a077af161bd4e282d972..ed2f5df80cf4d7a5a44af9b09f3b048b1b14cdb9 100644 --- a/paddle/fluid/operators/detection/density_prior_box_op.h +++ b/paddle/fluid/operators/detection/density_prior_box_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -52,18 +52,16 @@ class DensityPriorBoxOpKernel : public framework::OpKernel { step_height = step_h; } int num_priors = 0; - if (fixed_sizes.size() > 0 && densities.size() > 0) { - for (size_t i = 0; i < densities.size(); ++i) { - if (fixed_ratios.size() > 0) { - num_priors += (fixed_ratios.size()) * (pow(densities[i], 2)); - } - } + for (size_t i = 0; i < densities.size(); ++i) { + num_priors += (fixed_ratios.size()) * (pow(densities[i], 2)); } boxes->mutable_data(ctx.GetPlace()); vars->mutable_data(ctx.GetPlace()); - auto e_boxes = framework::EigenTensor::From(*boxes).setConstant(0.0); + auto box_dim = vars->dims(); + boxes->Resize({feature_height, feature_width, num_priors, 4}); + auto e_boxes = framework::EigenTensor::From(*boxes).setConstant(0.0); int step_average = static_cast((step_width + step_height) * 0.5); for (int h = 0; h < feature_height; ++h) { @@ -76,36 +74,34 @@ class DensityPriorBoxOpKernel : public framework::OpKernel { auto fixed_size = fixed_sizes[s]; int density = densities[s]; // Generate density prior boxes with fixed ratios. - if (fixed_ratios.size() > 0) { - for (size_t r = 0; r < fixed_ratios.size(); ++r) { - float ar = fixed_ratios[r]; - int shift = step_average / density; - float box_width_ratio = fixed_size * sqrt(ar); - float box_height_ratio = fixed_size / sqrt(ar); - for (int di = 0; di < density; ++di) { - for (int dj = 0; dj < density; ++dj) { - float center_x_temp = - center_x - step_average / 2. + shift / 2. + dj * shift; - float center_y_temp = - center_y - step_average / 2. + shift / 2. + di * shift; - e_boxes(h, w, idx, 0) = - (center_x_temp - box_width_ratio / 2.) / img_width >= 0 - ? (center_x_temp - box_width_ratio / 2.) / img_width - : 0; - e_boxes(h, w, idx, 1) = - (center_y_temp - box_height_ratio / 2.) / img_height >= 0 - ? (center_y_temp - box_height_ratio / 2.) / img_height - : 0; - e_boxes(h, w, idx, 2) = - (center_x_temp + box_width_ratio / 2.) / img_width <= 1 - ? (center_x_temp + box_width_ratio / 2.) / img_width - : 1; - e_boxes(h, w, idx, 3) = - (center_y_temp + box_height_ratio / 2.) / img_height <= 1 - ? (center_y_temp + box_height_ratio / 2.) / img_height - : 1; - idx++; - } + for (size_t r = 0; r < fixed_ratios.size(); ++r) { + float ar = fixed_ratios[r]; + int shift = step_average / density; + float box_width_ratio = fixed_size * sqrt(ar); + float box_height_ratio = fixed_size / sqrt(ar); + for (int di = 0; di < density; ++di) { + for (int dj = 0; dj < density; ++dj) { + float center_x_temp = + center_x - step_average / 2. + shift / 2. + dj * shift; + float center_y_temp = + center_y - step_average / 2. + shift / 2. + di * shift; + e_boxes(h, w, idx, 0) = + (center_x_temp - box_width_ratio / 2.) / img_width >= 0 + ? (center_x_temp - box_width_ratio / 2.) / img_width + : 0; + e_boxes(h, w, idx, 1) = + (center_y_temp - box_height_ratio / 2.) / img_height >= 0 + ? (center_y_temp - box_height_ratio / 2.) / img_height + : 0; + e_boxes(h, w, idx, 2) = + (center_x_temp + box_width_ratio / 2.) / img_width <= 1 + ? (center_x_temp + box_width_ratio / 2.) / img_width + : 1; + e_boxes(h, w, idx, 3) = + (center_y_temp + box_height_ratio / 2.) / img_height <= 1 + ? (center_y_temp + box_height_ratio / 2.) / img_height + : 1; + idx++; } } } @@ -139,6 +135,7 @@ class DensityPriorBoxOpKernel : public framework::OpKernel { e_vars = var_et.broadcast(Eigen::DSizes(box_num, 1)); vars->Resize(var_dim); + boxes->Resize(box_dim); } }; // namespace operators diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 3f17400a1432bb799e09accf2600ab6ec85e05a7..4843af8340310e0f47964d41708b13216fcd2161 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -1029,6 +1029,7 @@ def density_prior_box(input, clip=False, steps=[0.0, 0.0], offset=0.5, + flatten_to_2d=False, name=None): """ **Density Prior Box Operator** @@ -1065,22 +1066,24 @@ def density_prior_box(input, height/weight of the input will be automatically calculated. Default: [0., 0.] offset(float): Prior boxes center offset. Default: 0.5 + flatten_to_2d(bool): Whether to flatten output prior boxes and variance + to 2D shape, the second dim is 4. Default: False. name(str): Name of the density prior box op. Default: None. Returns: tuple: A tuple with two Variable (boxes, variances) boxes: the output density prior boxes of PriorBox. - The layout is [H, W, num_priors, 4]. - H is the height of input, W is the width of input, - num_priors is the total - box count of each position of input. + The layout is [H, W, num_priors, 4] when flatten_to_2d is False. + The layout is [H * W * num_priors, 4] when flatten_to_2d is True. + H is the height of input, W is the width of input, + num_priors is the total box count of each position of input. variances: the expanded variances of PriorBox. - The layout is [H, W, num_priors, 4]. - H is the height of input, W is the width of input - num_priors is the total - box count of each position of input + The layout is [H, W, num_priors, 4] when flatten_to_2d is False. + The layout is [H * W * num_priors, 4] when flatten_to_2d is True. + H is the height of input, W is the width of input + num_priors is the total box count of each position of input. Examples: @@ -1089,14 +1092,11 @@ def density_prior_box(input, box, var = fluid.layers.density_prior_box( input=conv1, image=images, - min_sizes=[100.], - max_sizes=[200.], - aspect_ratios=[1.0, 1.0 / 2.0, 2.0], - densities=[3, 4], - fixed_sizes=[50., 60.], - fixed_ratios=[1.0, 3.0, 1.0 / 3.0], - flip=True, - clip=True) + densities=[4, 2, 1], + fixed_sizes=[32.0, 64.0, 128.0], + fixed_ratios=[1.], + clip=True, + flatten_to_2d=True) """ helper = LayerHelper("density_prior_box", **locals()) dtype = helper.input_dtype() @@ -1127,14 +1127,11 @@ def density_prior_box(input, 'step_w': steps[0], 'step_h': steps[1], 'offset': offset, + 'densities': densities, + 'fixed_sizes': fixed_sizes, + 'fixed_ratios': fixed_ratios, + 'flatten_to_2d': flatten_to_2d, } - if densities is not None and len(densities) > 0: - attrs['densities'] = densities - if fixed_sizes is not None and len(fixed_sizes) > 0: - attrs['fixed_sizes'] = fixed_sizes - if fixed_ratios is not None and len(fixed_ratios) > 0: - attrs['fixed_ratios'] = fixed_ratios - box = helper.create_variable_for_type_inference(dtype) var = helper.create_variable_for_type_inference(dtype) helper.append_op( diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 982d29180141d052e25ea3dcba6e3e7ce4181c48..a2eca5541a152ca99804a7f87c9b0bc3d12d4eee 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -112,38 +112,42 @@ class TestDetection(unittest.TestCase): class TestPriorBox(unittest.TestCase): def test_prior_box(self): - data_shape = [3, 224, 224] - images = fluid.layers.data( - name='pixel', shape=data_shape, dtype='float32') - conv1 = fluid.layers.conv2d(images, 3, 3, 2) - box, var = layers.prior_box( - input=conv1, - image=images, - min_sizes=[100.0], - aspect_ratios=[1.], - flip=True, - clip=True) - assert len(box.shape) == 4 - assert box.shape == var.shape - assert box.shape[3] == 4 + program = Program() + with program_guard(program): + data_shape = [3, 224, 224] + images = fluid.layers.data( + name='pixel', shape=data_shape, dtype='float32') + conv1 = fluid.layers.conv2d(images, 3, 3, 2) + box, var = layers.prior_box( + input=conv1, + image=images, + min_sizes=[100.0], + aspect_ratios=[1.], + flip=True, + clip=True) + assert len(box.shape) == 4 + assert box.shape == var.shape + assert box.shape[3] == 4 class TestDensityPriorBox(unittest.TestCase): def test_density_prior_box(self): - data_shape = [3, 224, 224] - images = fluid.layers.data( - name='pixel', shape=data_shape, dtype='float32') - conv1 = fluid.layers.conv2d(images, 3, 3, 2) - box, var = layers.density_prior_box( - input=conv1, - image=images, - densities=[3, 4], - fixed_sizes=[50., 60.], - fixed_ratios=[1.0], - clip=True) - assert len(box.shape) == 4 - assert box.shape == var.shape - assert box.shape[3] == 4 + program = Program() + with program_guard(program): + data_shape = [3, 224, 224] + images = fluid.layers.data( + name='pixel', shape=data_shape, dtype='float32') + conv1 = fluid.layers.conv2d(images, 3, 3, 2) + box, var = layers.density_prior_box( + input=conv1, + image=images, + densities=[3, 4], + fixed_sizes=[50., 60.], + fixed_ratios=[1.0], + clip=True) + assert len(box.shape) == 4 + assert box.shape == var.shape + assert box.shape[-1] == 4 class TestAnchorGenerator(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_density_prior_box_op.py b/python/paddle/fluid/tests/unittests/test_density_prior_box_op.py index 79d1fd3d7171e06a88a75cf50b6a51ef4da51f07..4b0bc1dcf85fbb384eea09ee286d35ec248aae70 100644 --- a/python/paddle/fluid/tests/unittests/test_density_prior_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_density_prior_box_op.py @@ -36,7 +36,8 @@ class TestDensityPriorBoxOp(OpTest): 'offset': self.offset, 'densities': self.densities, 'fixed_sizes': self.fixed_sizes, - 'fixed_ratios': self.fixed_ratios + 'fixed_ratios': self.fixed_ratios, + 'flatten_to_2d': self.flatten_to_2d } self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var} @@ -48,16 +49,17 @@ class TestDensityPriorBoxOp(OpTest): self.set_data() def set_density(self): - self.densities = [] - self.fixed_sizes = [] - self.fixed_ratios = [] + self.densities = [4, 2, 1] + self.fixed_sizes = [32.0, 64.0, 128.0] + self.fixed_ratios = [1.0] + self.layer_w = 17 + self.layer_h = 17 + self.image_w = 533 + self.image_h = 533 + self.flatten_to_2d = False def init_test_params(self): - self.layer_w = 32 - self.layer_h = 32 - - self.image_w = 40 - self.image_h = 40 + self.set_density() self.step_w = float(self.image_w) / float(self.layer_w) self.step_h = float(self.image_h) / float(self.layer_h) @@ -69,8 +71,6 @@ class TestDensityPriorBoxOp(OpTest): self.variances = [0.1, 0.1, 0.2, 0.2] self.variances = np.array(self.variances, dtype=np.float).flatten() - self.set_density() - self.clip = True self.num_priors = 0 if len(self.fixed_sizes) > 0 and len(self.densities) > 0: @@ -129,6 +129,9 @@ class TestDensityPriorBoxOp(OpTest): (self.layer_h, self.layer_w, self.num_priors, 1)) self.out_boxes = out_boxes.astype('float32') self.out_var = out_var.astype('float32') + if self.flatten_to_2d: + self.out_boxes = self.out_boxes.reshape((-1, 4)) + self.out_var = self.out_var.reshape((-1, 4)) class TestDensityPriorBox(TestDensityPriorBoxOp): @@ -136,6 +139,11 @@ class TestDensityPriorBox(TestDensityPriorBoxOp): self.densities = [3, 4] self.fixed_sizes = [1.0, 2.0] self.fixed_ratios = [1.0] + self.layer_w = 32 + self.layer_h = 32 + self.image_w = 40 + self.image_h = 40 + self.flatten_to_2d = True if __name__ == '__main__':