Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a9fe09f8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a9fe09f8
编写于
4月 13, 2020
作者:
X
xiaoting
提交者:
GitHub
4月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish error message (#23696)
* polish error message, test=develop
上级
ff0ab756
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
191 addition
and
67 deletion
+191
-67
paddle/fluid/operators/detection/multiclass_nms_op.cc
paddle/fluid/operators/detection/multiclass_nms_op.cc
+25
-15
paddle/fluid/operators/detection/yolo_box_op.cc
paddle/fluid/operators/detection/yolo_box_op.cc
+35
-16
paddle/fluid/operators/detection/yolov3_loss_op.cc
paddle/fluid/operators/detection/yolov3_loss_op.cc
+86
-35
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+12
-1
python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
...on/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
+33
-0
未找到文件。
paddle/fluid/operators/detection/multiclass_nms_op.cc
浏览文件 @
a9fe09f8
...
...
@@ -26,12 +26,9 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BBoxes"
),
"Input(BBoxes) of MultiClassNMS should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Scores"
),
"Input(Scores) of MultiClassNMS should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of MultiClassNMS should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BBoxes"
),
"Input"
,
"BBoxes"
,
"MultiClassNMS"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scores"
),
"Input"
,
"Scores"
,
"MultiClassNMS"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"MultiClassNMS"
);
auto
box_dims
=
ctx
->
GetInputDim
(
"BBoxes"
);
auto
score_dims
=
ctx
->
GetInputDim
(
"Scores"
);
...
...
@@ -41,7 +38,10 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
score_size
==
2
||
score_size
==
3
,
"The rank of Input(Scores) must be 2 or 3"
);
PADDLE_ENFORCE_EQ
(
box_dims
.
size
(),
3
,
"The rank of Input(BBoxes) must be 3"
);
platform
::
errors
::
InvalidArgument
(
"The rank of Input(BBoxes) must be 3"
"But receive box_dims size(%s)"
,
box_dims
.
size
()));
if
(
score_size
==
3
)
{
PADDLE_ENFORCE
(
box_dims
[
2
]
==
4
||
box_dims
[
2
]
==
8
||
box_dims
[
2
]
==
16
||
box_dims
[
2
]
==
24
||
...
...
@@ -55,16 +55,26 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
"16 points: [xi, yi] i= 1,2,...,16"
);
PADDLE_ENFORCE_EQ
(
box_dims
[
1
],
score_dims
[
2
],
"The 2nd dimension of Input(BBoxes) must be equal to "
"last dimension of Input(Scores), which represents the "
"predicted bboxes."
);
platform
::
errors
::
InvalidArgument
(
"The 2nd dimension of Input(BBoxes) must be equal to "
"last dimension of Input(Scores), which represents the "
"predicted bboxes."
"But received box_dims[1](%s) != socre_dims[2](%s)"
,
box_dims
[
1
],
score_dims
[
2
]));
}
else
{
PADDLE_ENFORCE
(
box_dims
[
2
]
==
4
,
"The last dimension of Input(BBoxes) must be 4"
);
PADDLE_ENFORCE_EQ
(
box_dims
[
2
],
4
,
platform
::
errors
::
InvalidArgument
(
"The last dimension of Input(BBoxes) must be 4"
"But received box_dims[2](%s)."
,
box_dims
[
2
]));
PADDLE_ENFORCE_EQ
(
box_dims
[
1
],
score_dims
[
1
],
"The 2nd dimension of Input(BBoxes)"
"must be equal to the 2nd dimension"
" of Input(Scores)"
);
platform
::
errors
::
InvalidArgument
(
"The 2nd dimension of Input(BBoxes)"
"must be equal to the 2nd dimension"
" of Input(Scores)"
"But received box_dims[1](%s) != "
"score_dims[1](%s)"
,
box_dims
[
1
],
score_dims
[
1
]));
}
}
// Here the box_dims[0] is not the real dimension of output.
...
...
paddle/fluid/operators/detection/yolo_box_op.cc
浏览文件 @
a9fe09f8
...
...
@@ -21,14 +21,10 @@ class YoloBoxOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of YoloBoxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ImgSize"
),
"Input(ImgSize) of YoloBoxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Boxes"
),
"Output(Boxes) of YoloBoxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Scores"
),
"Output(Scores) of YoloBoxOp should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"YoloBoxOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ImgSize"
),
"Input"
,
"ImgSize"
,
"YoloBoxOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Boxes"
),
"Output"
,
"Boxes"
,
"YoloBoxOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Scores"
),
"Output"
,
"Scores"
,
"YoloBoxOp"
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
auto
dim_imgsize
=
ctx
->
GetInputDim
(
"ImgSize"
);
...
...
@@ -36,26 +32,49 @@ class YoloBoxOp : public framework::OperatorWithKernel {
int
anchor_num
=
anchors
.
size
()
/
2
;
auto
class_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"class_num"
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
"Input(X) should be a 4-D tensor."
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"Input(X) should be a 4-D tensor."
"But received X dimension(%s)"
,
dim_x
.
size
()));
PADDLE_ENFORCE_EQ
(
dim_x
[
1
],
anchor_num
*
(
5
+
class_num
),
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
);
platform
::
errors
::
InvalidArgument
(
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
"But received dim[1](%s) != (anchor_mask_number * "
"(5+class_num)(%s)."
,
dim_x
[
1
],
anchor_num
*
(
5
+
class_num
)));
PADDLE_ENFORCE_EQ
(
dim_imgsize
.
size
(),
2
,
"Input(ImgSize) should be a 2-D tensor."
);
platform
::
errors
::
InvalidArgument
(
"Input(ImgSize) should be a 2-D tensor."
"But received Imgsize size(%s)"
,
dim_imgsize
.
size
()));
if
((
dim_imgsize
[
0
]
>
0
&&
dim_x
[
0
]
>
0
)
||
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
dim_imgsize
[
0
],
dim_x
[
0
],
platform
::
errors
::
InvalidArgument
(
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same."
));
}
PADDLE_ENFORCE_EQ
(
dim_imgsize
[
1
],
2
,
"Input(ImgSize) dim[1] should be 2."
);
PADDLE_ENFORCE_EQ
(
dim_imgsize
[
1
],
2
,
platform
::
errors
::
InvalidArgument
(
"Input(ImgSize) dim[1] should be 2."
"But received imgsize dim[1](%s)."
,
dim_imgsize
[
1
]));
PADDLE_ENFORCE_GT
(
anchors
.
size
(),
0
,
"Attr(anchors) length should be greater than 0."
);
platform
::
errors
::
InvalidArgument
(
"Attr(anchors) length should be greater than 0."
"But received anchors length(%s)."
,
anchors
.
size
()));
PADDLE_ENFORCE_EQ
(
anchors
.
size
()
%
2
,
0
,
"Attr(anchors) length should be even integer."
);
platform
::
errors
::
InvalidArgument
(
"Attr(anchors) length should be even integer."
"But received anchors length (%s)"
,
anchors
.
size
()));
PADDLE_ENFORCE_GT
(
class_num
,
0
,
"Attr(class_num) should be an integer greater than 0."
);
platform
::
errors
::
InvalidArgument
(
"Attr(class_num) should be an integer greater than 0."
"But received class_num (%s)"
,
class_num
));
int
box_num
=
dim_x
[
2
]
*
dim_x
[
3
]
*
anchor_num
;
std
::
vector
<
int64_t
>
dim_boxes
({
dim_x
[
0
],
box_num
,
4
});
...
...
paddle/fluid/operators/detection/yolov3_loss_op.cc
浏览文件 @
a9fe09f8
...
...
@@ -23,19 +23,15 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of Yolov3LossOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"GTBox"
),
"Input(GTBox) of Yolov3LossOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"GTLabel"
),
"Input(GTLabel) of Yolov3LossOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Loss"
),
"Output(Loss) of Yolov3LossOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ObjectnessMask"
),
"Output(ObjectnessMask) of Yolov3LossOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"GTMatchMask"
),
"Output(GTMatchMask) of Yolov3LossOp should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"Yolov3LossOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"GTBox"
),
"Input"
,
"GTBox"
,
"Yolov3LossOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"GTLabel"
),
"Input"
,
"GTLabel"
,
"Yolov3LossOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Loss"
),
"Output"
,
"Loss"
,
"Yolov3LossOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"ObjectnessMask"
),
"Output"
,
"ObjectnessMask"
,
"Yolov3LossOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"GTMatchMask"
),
"Output"
,
"GTMatchMask"
,
"Yolov3LossOp"
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
auto
dim_gtbox
=
ctx
->
GetInputDim
(
"GTBox"
);
...
...
@@ -46,44 +42,96 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
int
mask_num
=
anchor_mask
.
size
();
auto
class_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"class_num"
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
"Input(X) should be a 4-D tensor."
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"Input(X) should be a 4-D tensor. But received "
"X dimension size(%s)"
,
dim_x
.
size
()));
PADDLE_ENFORCE_EQ
(
dim_x
[
2
],
dim_x
[
3
],
"Input(X) dim[3] and dim[4] should be euqal."
);
platform
::
errors
::
InvalidArgument
(
"Input(X) dim[3] and dim[4] should be euqal."
"But received dim[3](%s) != dim[4](%s)"
,
dim_x
[
2
],
dim_x
[
3
]));
PADDLE_ENFORCE_EQ
(
dim_x
[
1
],
mask_num
*
(
5
+
class_num
),
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
);
platform
::
errors
::
InvalidArgument
(
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
"But received dim[1](%s) != (anchor_mask_number * "
"(5+class_num)(%s)."
,
dim_x
[
1
],
mask_num
*
(
5
+
class_num
)));
PADDLE_ENFORCE_EQ
(
dim_gtbox
.
size
(),
3
,
"Input(GTBox) should be a 3-D tensor"
);
PADDLE_ENFORCE_EQ
(
dim_gtbox
[
2
],
4
,
"Input(GTBox) dim[2] should be 5"
);
PADDLE_ENFORCE_EQ
(
dim_gtlabel
.
size
(),
2
,
"Input(GTLabel) should be a 2-D tensor"
);
PADDLE_ENFORCE_EQ
(
dim_gtlabel
[
0
],
dim_gtbox
[
0
],
"Input(GTBox) and Input(GTLabel) dim[0] should be same"
);
PADDLE_ENFORCE_EQ
(
dim_gtlabel
[
1
],
dim_gtbox
[
1
],
"Input(GTBox) and Input(GTLabel) dim[1] should be same"
);
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) should be a 3-D tensor, but "
"received gtbox dimension size(%s)"
,
dim_gtbox
.
size
()));
PADDLE_ENFORCE_EQ
(
dim_gtbox
[
2
],
4
,
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) dim[2] should be 4"
,
"But receive dim[2](%s) != 5. "
,
dim_gtbox
[
2
]));
PADDLE_ENFORCE_EQ
(
dim_gtlabel
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"Input(GTLabel) should be a 2-D tensor,"
"But received Input(GTLabel) dimension size(%s) != 2."
,
dim_gtlabel
.
size
()));
PADDLE_ENFORCE_EQ
(
dim_gtlabel
[
0
],
dim_gtbox
[
0
],
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) dim[0] and Input(GTLabel) dim[0] should be same,"
"But received Input(GTLabel) dim[0](%s) != "
"Input(GTBox) dim[0](%s)"
,
dim_gtlabel
[
0
],
dim_gtbox
[
0
]));
PADDLE_ENFORCE_EQ
(
dim_gtlabel
[
1
],
dim_gtbox
[
1
],
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) and Input(GTLabel) dim[1] should be same,"
"But received Input(GTBox) dim[1](%s) != Input(GTLabel) "
"dim[1](%s)"
,
dim_gtbox
[
1
],
dim_gtlabel
[
1
]));
PADDLE_ENFORCE_GT
(
anchors
.
size
(),
0
,
"Attr(anchors) length should be greater then 0."
);
platform
::
errors
::
InvalidArgument
(
"Attr(anchors) length should be greater then 0."
"But received anchors length(%s)"
,
anchors
.
size
()));
PADDLE_ENFORCE_EQ
(
anchors
.
size
()
%
2
,
0
,
"Attr(anchors) length should be even integer."
);
platform
::
errors
::
InvalidArgument
(
"Attr(anchors) length should be even integer."
"But received anchors length(%s)"
,
anchors
.
size
()));
for
(
size_t
i
=
0
;
i
<
anchor_mask
.
size
();
i
++
)
{
PADDLE_ENFORCE_LT
(
anchor_mask
[
i
],
anchor_num
,
"Attr(anchor_mask) should not crossover Attr(anchors)."
);
platform
::
errors
::
InvalidArgument
(
"Attr(anchor_mask) should not crossover Attr(anchors)."
"But received anchor_mask[i](%s) > anchor_num(%s)"
,
anchor_mask
[
i
],
anchor_num
));
}
PADDLE_ENFORCE_GT
(
class_num
,
0
,
"Attr(class_num) should be an integer greater then 0."
);
platform
::
errors
::
InvalidArgument
(
"Attr(class_num) should be an integer greater then 0."
"But received class_num(%s) < 0"
,
class_num
));
if
(
ctx
->
HasInput
(
"GTScore"
))
{
auto
dim_gtscore
=
ctx
->
GetInputDim
(
"GTScore"
);
PADDLE_ENFORCE_EQ
(
dim_gtscore
.
size
(),
2
,
"Input(GTScore) should be a 2-D tensor"
);
platform
::
errors
::
InvalidArgument
(
"Input(GTScore) should be a 2-D tensor"
"But received GTScore dimension(%s)"
,
dim_gtbox
.
size
()));
PADDLE_ENFORCE_EQ
(
dim_gtscore
[
0
],
dim_gtbox
[
0
],
"Input(GTBox) and Input(GTScore) dim[0] should be same"
);
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) and Input(GTScore) dim[0] should be same"
"But received GTBox dim[0](%s) != GTScore dim[0](%s)"
,
dim_gtbox
[
0
],
dim_gtscore
[
0
]));
PADDLE_ENFORCE_EQ
(
dim_gtscore
[
1
],
dim_gtbox
[
1
],
"Input(GTBox) and Input(GTScore) dim[1] should be same"
);
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) and Input(GTScore) dim[1] should be same"
"But received GTBox dim[1](%s) != GTScore dim[1](%s)"
,
dim_gtscore
[
1
],
dim_gtbox
[
1
]));
}
std
::
vector
<
int64_t
>
dim_out
({
dim_x
[
0
]});
...
...
@@ -245,9 +293,12 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
"Input(Loss@GRAD) should not be null"
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
platform
::
errors
::
NotFound
(
"Input(X) should not be null"
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
true
,
platform
::
errors
::
NotFound
(
"Input(Loss@GRAD) should not be null"
));
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
dim_x
);
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
a9fe09f8
...
...
@@ -3178,6 +3178,18 @@ def multiclass_nms(bboxes,
keep_top_k=200,
normalized=False)
"""
check_variable_and_dtype
(
bboxes
,
'BBoxes'
,
[
'float32'
,
'float64'
],
'multiclass_nms'
)
check_variable_and_dtype
(
scores
,
'Scores'
,
[
'float32'
,
'float64'
],
'multiclass_nms'
)
check_type
(
score_threshold
,
'score_threshold'
,
float
,
'multicalss_nms'
)
check_type
(
nms_top_k
,
'nums_top_k'
,
int
,
'multiclass_nms'
)
check_type
(
keep_top_k
,
'keep_top_k'
,
int
,
'mutliclass_nms'
)
check_type
(
nms_threshold
,
'nms_threshold'
,
float
,
'multiclass_nms'
)
check_type
(
normalized
,
'normalized'
,
bool
,
'multiclass_nms'
)
check_type
(
nms_eta
,
'nms_eta'
,
float
,
'multiclass_nms'
)
check_type
(
background_label
,
'background_label'
,
int
,
'multiclass_nms'
)
helper
=
LayerHelper
(
'multiclass_nms'
,
**
locals
())
output
=
helper
.
create_variable_for_type_inference
(
dtype
=
bboxes
.
dtype
)
...
...
@@ -3192,7 +3204,6 @@ def multiclass_nms(bboxes,
'nms_threshold'
:
nms_threshold
,
'nms_eta'
:
nms_eta
,
'keep_top_k'
:
keep_top_k
,
'nms_eta'
:
nms_eta
,
'normalized'
:
normalized
},
outputs
=
{
'Out'
:
output
})
...
...
python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
浏览文件 @
a9fe09f8
...
...
@@ -17,6 +17,8 @@ import unittest
import
numpy
as
np
import
copy
from
op_test
import
OpTest
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
def
softmax
(
x
):
...
...
@@ -487,5 +489,36 @@ class TestMulticlassNMS2LoDNoOutput(TestMulticlassNMS2LoDInput):
self
.
score_threshold
=
2.0
class
TestMulticlassNMSError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
with
program_guard
(
Program
(),
Program
()):
M
=
1200
N
=
7
C
=
21
BOX_SIZE
=
4
boxes_np
=
np
.
random
.
random
((
M
,
C
,
BOX_SIZE
)).
astype
(
'float32'
)
scores
=
np
.
random
.
random
((
N
*
M
,
C
)).
astype
(
'float32'
)
scores
=
np
.
apply_along_axis
(
softmax
,
1
,
scores
)
scores
=
np
.
reshape
(
scores
,
(
N
,
M
,
C
))
scores_np
=
np
.
transpose
(
scores
,
(
0
,
2
,
1
))
boxes_data
=
fluid
.
data
(
name
=
'bboxes'
,
shape
=
[
M
,
C
,
BOX_SIZE
],
dtype
=
'float32'
)
scores_data
=
fluid
.
data
(
name
=
'scores'
,
shape
=
[
N
,
C
,
M
],
dtype
=
'float32'
)
def
test_bboxes_Variable
():
# the bboxes type must be Variable
fluid
.
layers
.
multiclass_nms
(
bboxes
=
boxes_np
,
scores
=
scores_data
)
def
test_scores_Variable
():
# the bboxes type must be Variable
fluid
.
layers
.
multiclass_nms
(
bboxes
=
boxes_data
,
scores
=
scores_np
)
self
.
assertRaises
(
TypeError
,
test_bboxes_Variable
)
self
.
assertRaises
(
TypeError
,
test_scores_Variable
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录