提交 bbff442e 编写于 作者: C chengduoZH

follow comments of qingqing

上级 d641d5ac
......@@ -19,15 +19,13 @@ from ..layer_helper import LayerHelper
from ..framework import Variable
from tensor import concat
from ops import reshape
from operator import mul
import math
__all__ = [
'prior_box',
'prior_boxes',
]
__all__ = ['prior_box', ]
def prior_boxes(inputs,
def prior_box(inputs,
image,
min_ratio,
max_ratio,
......@@ -140,9 +138,10 @@ def prior_boxes(inputs,
def _reshape_with_axis_(input, axis=1):
if not (axis > 0 and axis < len(input.shape)):
raise ValueError(
"The axis should be smaller than the arity of input's shape.")
"The axis should be smaller than the arity of input and bigger than 0."
)
new_shape = [-1, reduce(mul, input.shape[axis:len(input.shape)], 1)]
out = reshape([input], shape=new_shape)
out = reshape(x=input, shape=new_shape)
return out
assert isinstance(inputs, list), 'inputs should be a list.'
......
......@@ -33,7 +33,7 @@ def prior_box_output(data_shape):
conv5 = fluid.layers.conv2d(
input=conv4, num_filters=3, filter_size=3, stride=2, use_cudnn=False)
box, var = detection.prior_boxes(
box, var = detection.prior_box(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
image=images,
min_ratio=20,
......@@ -57,20 +57,22 @@ def main(use_cuda):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch = [128]
batch = [4] # batch is not used in the prior_box.
assert box.shape[1] == 4
assert var.shape[1] == 4
assert box.shape == var.shape
assert len(box.shape) == 2
for _ in range(1):
x = np.random.random(batch + data_shape).astype("float32")
tensor_x = core.LoDTensor()
tensor_x.set(x, place)
box, var = exe.run(fluid.default_main_program(),
boxes, vars = exe.run(fluid.default_main_program(),
feed={'pixel': tensor_x},
fetch_list=[box, var])
box_arr = np.array(box)
var_arr = np.array(var)
assert box_arr.shape[1] == 4
assert var_arr.shape[1] == 4
assert box_arr.shape[0] == var_arr.shape[0]
assert vars.shape == var.shape
assert boxes.shape == box.shape
class TestFitALine(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册