提交 bbff442e 编写于 作者: C chengduoZH

follow comments of qingqing

上级 d641d5ac
...@@ -19,30 +19,28 @@ from ..layer_helper import LayerHelper ...@@ -19,30 +19,28 @@ from ..layer_helper import LayerHelper
from ..framework import Variable from ..framework import Variable
from tensor import concat from tensor import concat
from ops import reshape from ops import reshape
from operator import mul
import math import math
__all__ = [ __all__ = ['prior_box', ]
'prior_box',
'prior_boxes',
] def prior_box(inputs,
image,
min_ratio,
def prior_boxes(inputs, max_ratio,
image, aspect_ratios,
min_ratio, base_size,
max_ratio, steps=None,
aspect_ratios, step_w=None,
base_size, step_h=None,
steps=None, offset=0.5,
step_w=None, variance=[0.1, 0.1, 0.1, 0.1],
step_h=None, flip=False,
offset=0.5, clip=False,
variance=[0.1, 0.1, 0.1, 0.1], min_sizes=None,
flip=False, max_sizes=None,
clip=False, name=None):
min_sizes=None,
max_sizes=None,
name=None):
""" """
**Prior_boxes** **Prior_boxes**
...@@ -140,9 +138,10 @@ def prior_boxes(inputs, ...@@ -140,9 +138,10 @@ def prior_boxes(inputs,
def _reshape_with_axis_(input, axis=1): def _reshape_with_axis_(input, axis=1):
if not (axis > 0 and axis < len(input.shape)): if not (axis > 0 and axis < len(input.shape)):
raise ValueError( raise ValueError(
"The axis should be smaller than the arity of input's shape.") "The axis should be smaller than the arity of input and bigger than 0."
)
new_shape = [-1, reduce(mul, input.shape[axis:len(input.shape)], 1)] new_shape = [-1, reduce(mul, input.shape[axis:len(input.shape)], 1)]
out = reshape([input], shape=new_shape) out = reshape(x=input, shape=new_shape)
return out return out
assert isinstance(inputs, list), 'inputs should be a list.' assert isinstance(inputs, list), 'inputs should be a list.'
......
...@@ -33,7 +33,7 @@ def prior_box_output(data_shape): ...@@ -33,7 +33,7 @@ def prior_box_output(data_shape):
conv5 = fluid.layers.conv2d( conv5 = fluid.layers.conv2d(
input=conv4, num_filters=3, filter_size=3, stride=2, use_cudnn=False) input=conv4, num_filters=3, filter_size=3, stride=2, use_cudnn=False)
box, var = detection.prior_boxes( box, var = detection.prior_box(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5], inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
image=images, image=images,
min_ratio=20, min_ratio=20,
...@@ -57,20 +57,22 @@ def main(use_cuda): ...@@ -57,20 +57,22 @@ def main(use_cuda):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
batch = [128] batch = [4] # batch is not used in the prior_box.
assert box.shape[1] == 4
assert var.shape[1] == 4
assert box.shape == var.shape
assert len(box.shape) == 2
for _ in range(1): for _ in range(1):
x = np.random.random(batch + data_shape).astype("float32") x = np.random.random(batch + data_shape).astype("float32")
tensor_x = core.LoDTensor() tensor_x = core.LoDTensor()
tensor_x.set(x, place) tensor_x.set(x, place)
box, var = exe.run(fluid.default_main_program(), boxes, vars = exe.run(fluid.default_main_program(),
feed={'pixel': tensor_x}, feed={'pixel': tensor_x},
fetch_list=[box, var]) fetch_list=[box, var])
box_arr = np.array(box) assert vars.shape == var.shape
var_arr = np.array(var) assert boxes.shape == box.shape
assert box_arr.shape[1] == 4
assert var_arr.shape[1] == 4
assert box_arr.shape[0] == var_arr.shape[0]
class TestFitALine(unittest.TestCase): class TestFitALine(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册