提交 dff1bf33 编写于 作者: C chengduoZH

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/add_prior_box_py

......@@ -106,9 +106,11 @@ class Vector {
// std::vector iterator methods. Based on CPU data access method
size_t size() const { return size_; }
T* begin() { return &this->operator[](0); }
T* begin() { return capacity() == 0 ? &EmptyDummy() : &this->operator[](0); }
T* end() { return &this->operator[](size()); }
T* end() {
return capacity() == 0 ? &EmptyDummy() : &this->operator[](size());
}
T& front() { return *begin(); }
......@@ -118,8 +120,13 @@ class Vector {
return *it;
}
const T* begin() const { return &this->operator[](0); }
const T* end() const { return &this->operator[](size()); }
const T* begin() const {
return capacity() == 0 ? &EmptyDummy() : &this->operator[](0);
}
const T* end() const {
return capacity() == 0 ? &EmptyDummy() : &this->operator[](size());
}
const T* cbegin() const { return begin(); }
......@@ -358,6 +365,11 @@ class Vector {
}
}
static T& EmptyDummy() {
static T dummy = T();
return dummy;
}
mutable int flag_;
mutable Tensor cpu_vec_;
mutable Tensor cuda_vec_;
......
......@@ -98,3 +98,9 @@ TEST(mixed_vector, InitWithCount) {
ASSERT_EQ(vec[i], 10);
}
}
TEST(mixed_vector, ForEach) {
vec<int> tmp;
for (auto& v : tmp) {
}
}
......@@ -29,6 +29,6 @@ inference_test(image_classification ARGS vgg resnet)
inference_test(label_semantic_roles)
inference_test(recognize_digits ARGS mlp)
inference_test(recommender_system)
inference_test(rnn_encoder_decoder)
#inference_test(rnn_encoder_decoder)
inference_test(understand_sentiment)
inference_test(word2vec)
......@@ -19,7 +19,6 @@ from ..layer_helper import LayerHelper
from ..framework import Variable
from tensor import concat
from ops import reshape
from operator import mul
import math
__all__ = [
......@@ -143,43 +142,50 @@ def prior_box(inputs,
"""
**Prior_boxes**
Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
The details of this algorithm, please refer the section 2.2 of SSD paper
(SSD: Single Shot MultiBox Detector)<https://arxiv.org/abs/1512.02325>`_ .
Generate prior boxes for SSD(Single Shot MultiBox Detector)
algorithm. The details of this algorithm, please refer the
section 2.2 of SSD paper (SSD: Single Shot MultiBox Detector)
<https://arxiv.org/abs/1512.02325>`_ .
Args:
inputs(list): The list of input Variables, the format of all Variables is NCHW.
image(Variable): The input image data of PriorBoxOp, the layout is NCHW.
inputs(list): The list of input Variables, the format
of all Variables is NCHW.
image(Variable): The input image data of PriorBoxOp,
the layout is NCHW.
min_ratio(int): the min ratio of generated prior boxes.
max_ratio(int): the max ratio of generated prior boxes.
aspect_ratios(list): the aspect ratios of generated prior boxes.
The length of input and aspect_ratios must be equal.
base_size(int): the base_size is used to get min_size and max_size
according to min_ratio and max_ratio.
step_w(list, optional, default=None): Prior boxes step across width.
If step_w[i] == 0.0, the prior boxes step across width of the inputs[i]
will be automatically calculated.
step_h(list, optional, default=None): Prior boxes step across height,
If step_h[i] == 0.0, the prior boxes step across height of the inputs[i]
will be automatically calculated.
aspect_ratios(list): the aspect ratios of generated prior
boxes. The length of input and aspect_ratios must be equal.
base_size(int): the base_size is used to get min_size
and max_size according to min_ratio and max_ratio.
step_w(list, optional, default=None): Prior boxes step
across width. If step_w[i] == 0.0, the prior boxes step
across width of the inputs[i] will be automatically calculated.
step_h(list, optional, default=None): Prior boxes step
across height, If step_h[i] == 0.0, the prior boxes
step across height of the inputs[i] will be automatically calculated.
offset(float, optional, default=0.5): Prior boxes center offset.
variance(list, optional, default=[0.1, 0.1, 0.1, 0.1]): the variances
to be encoded in prior boxes.
flip(bool, optional, default=False): Whether to flip aspect ratios.
clip(bool, optional, default=False): Whether to clip out-of-boundary boxes.
min_sizes(list, optional, default=None): If `len(inputs) <=2`, min_sizes must
be set up, and the length of min_sizes should equal to the length of inputs.
max_sizes(list, optional, default=None): If `len(inputs) <=2`, max_sizes must
be set up, and the length of min_sizes should equal to the length of inputs.
flip(bool, optional, default=False): Whether to flip
aspect ratios.
clip(bool, optional, default=False): Whether to clip
out-of-boundary boxes.
min_sizes(list, optional, default=None): If `len(inputs) <=2`,
min_sizes must be set up, and the length of min_sizes
should equal to the length of inputs.
max_sizes(list, optional, default=None): If `len(inputs) <=2`,
max_sizes must be set up, and the length of min_sizes
should equal to the length of inputs.
name(str, optional, None): Name of the prior box layer.
Returns:
boxes(Variable): the output prior boxes of PriorBoxOp. The layout is
[num_priors, 4]. num_priors is the total box count of each
position of inputs.
Variances(Variable): the expanded variances of PriorBoxOp. The layout
is [num_priors, 4]. num_priors is the total box count of each
position of inputs
boxes(Variable): the output prior boxes of PriorBoxOp.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs.
Variances(Variable): the expanded variances of PriorBoxOp.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs
Examples:
.. code-block:: python
......@@ -235,10 +241,11 @@ def prior_box(inputs,
def _reshape_with_axis_(input, axis=1):
if not (axis > 0 and axis < len(input.shape)):
raise ValueError(
"The axis should be smaller than the arity of input and bigger than 0."
)
new_shape = [-1, reduce(mul, input.shape[axis:len(input.shape)], 1)]
raise ValueError("The axis should be smaller than "
"the arity of input and bigger than 0.")
new_shape = [
-1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
]
out = reshape(x=input, shape=new_shape)
return out
......
......@@ -54,8 +54,12 @@ class TestBook(unittest.TestCase):
class TestPriorBox(unittest.TestCase):
def test_prior_box(self):
self.check_prior_box(use_cuda=False)
self.check_prior_box(use_cuda=True)
data_shape = [3, 224, 224]
box, var = self.prior_box_output(data_shape)
assert len(box.shape) == 2
assert box.shape == var.shape
assert box.shape[1] == 4
def prior_box_output(self, data_shape):
images = fluid.layers.data(
......@@ -104,32 +108,6 @@ class TestPriorBox(unittest.TestCase):
clip=True)
return box, var
def check_prior_box(self, use_cuda):
if use_cuda: # prior_box only support CPU.
return
data_shape = [3, 224, 224]
box, var = self.prior_box_output(data_shape)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch = [4] # batch is not used in the prior_box.
assert box.shape[1] == 4
assert var.shape[1] == 4
assert box.shape == var.shape
assert len(box.shape) == 2
x = np.random.random(batch + data_shape).astype("float32")
tensor_x = core.LoDTensor()
tensor_x.set(x, place)
boxes, vars = exe.run(fluid.default_main_program(),
feed={'pixel': tensor_x},
fetch_list=[box, var])
assert vars.shape == var.shape
assert boxes.shape == box.shape
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册