Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a9fe09f8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a9fe09f8
编写于
4月 13, 2020
作者:
X
xiaoting
提交者:
GitHub
4月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish error message (#23696)
* polish error message, test=develop
上级
ff0ab756
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
191 addition
and
67 deletion
+191
-67
paddle/fluid/operators/detection/multiclass_nms_op.cc
paddle/fluid/operators/detection/multiclass_nms_op.cc
+25
-15
paddle/fluid/operators/detection/yolo_box_op.cc
paddle/fluid/operators/detection/yolo_box_op.cc
+35
-16
paddle/fluid/operators/detection/yolov3_loss_op.cc
paddle/fluid/operators/detection/yolov3_loss_op.cc
+86
-35
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+12
-1
python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
...on/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
+33
-0
未找到文件。
paddle/fluid/operators/detection/multiclass_nms_op.cc
浏览文件 @
a9fe09f8
...
...
@@ -26,12 +26,9 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BBoxes"
),
"Input(BBoxes) of MultiClassNMS should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Scores"
),
"Input(Scores) of MultiClassNMS should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of MultiClassNMS should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BBoxes"
),
"Input"
,
"BBoxes"
,
"MultiClassNMS"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scores"
),
"Input"
,
"Scores"
,
"MultiClassNMS"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"MultiClassNMS"
);
auto
box_dims
=
ctx
->
GetInputDim
(
"BBoxes"
);
auto
score_dims
=
ctx
->
GetInputDim
(
"Scores"
);
...
...
@@ -41,7 +38,10 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
score_size
==
2
||
score_size
==
3
,
"The rank of Input(Scores) must be 2 or 3"
);
PADDLE_ENFORCE_EQ
(
box_dims
.
size
(),
3
,
"The rank of Input(BBoxes) must be 3"
);
platform
::
errors
::
InvalidArgument
(
"The rank of Input(BBoxes) must be 3"
"But receive box_dims size(%s)"
,
box_dims
.
size
()));
if
(
score_size
==
3
)
{
PADDLE_ENFORCE
(
box_dims
[
2
]
==
4
||
box_dims
[
2
]
==
8
||
box_dims
[
2
]
==
16
||
box_dims
[
2
]
==
24
||
...
...
@@ -55,16 +55,26 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
"16 points: [xi, yi] i= 1,2,...,16"
);
PADDLE_ENFORCE_EQ
(
box_dims
[
1
],
score_dims
[
2
],
"The 2nd dimension of Input(BBoxes) must be equal to "
"last dimension of Input(Scores), which represents the "
"predicted bboxes."
);
platform
::
errors
::
InvalidArgument
(
"The 2nd dimension of Input(BBoxes) must be equal to "
"last dimension of Input(Scores), which represents the "
"predicted bboxes."
"But received box_dims[1](%s) != socre_dims[2](%s)"
,
box_dims
[
1
],
score_dims
[
2
]));
}
else
{
PADDLE_ENFORCE
(
box_dims
[
2
]
==
4
,
"The last dimension of Input(BBoxes) must be 4"
);
PADDLE_ENFORCE_EQ
(
box_dims
[
2
],
4
,
platform
::
errors
::
InvalidArgument
(
"The last dimension of Input(BBoxes) must be 4"
"But received box_dims[2](%s)."
,
box_dims
[
2
]));
PADDLE_ENFORCE_EQ
(
box_dims
[
1
],
score_dims
[
1
],
"The 2nd dimension of Input(BBoxes)"
"must be equal to the 2nd dimension"
" of Input(Scores)"
);
platform
::
errors
::
InvalidArgument
(
"The 2nd dimension of Input(BBoxes)"
"must be equal to the 2nd dimension"
" of Input(Scores)"
"But received box_dims[1](%s) != "
"score_dims[1](%s)"
,
box_dims
[
1
],
score_dims
[
1
]));
}
}
// Here the box_dims[0] is not the real dimension of output.
...
...
paddle/fluid/operators/detection/yolo_box_op.cc
浏览文件 @
a9fe09f8
...
...
@@ -21,14 +21,10 @@ class YoloBoxOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of YoloBoxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ImgSize"
),
"Input(ImgSize) of YoloBoxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Boxes"
),
"Output(Boxes) of YoloBoxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Scores"
),
"Output(Scores) of YoloBoxOp should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"YoloBoxOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ImgSize"
),
"Input"
,
"ImgSize"
,
"YoloBoxOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Boxes"
),
"Output"
,
"Boxes"
,
"YoloBoxOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Scores"
),
"Output"
,
"Scores"
,
"YoloBoxOp"
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
auto
dim_imgsize
=
ctx
->
GetInputDim
(
"ImgSize"
);
...
...
@@ -36,26 +32,49 @@ class YoloBoxOp : public framework::OperatorWithKernel {
int
anchor_num
=
anchors
.
size
()
/
2
;
auto
class_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"class_num"
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
"Input(X) should be a 4-D tensor."
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"Input(X) should be a 4-D tensor."
"But received X dimension(%s)"
,
dim_x
.
size
()));
PADDLE_ENFORCE_EQ
(
dim_x
[
1
],
anchor_num
*
(
5
+
class_num
),
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
);
platform
::
errors
::
InvalidArgument
(
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
"But received dim[1](%s) != (anchor_mask_number * "
"(5+class_num)(%s)."
,
dim_x
[
1
],
anchor_num
*
(
5
+
class_num
)));
PADDLE_ENFORCE_EQ
(
dim_imgsize
.
size
(),
2
,
"Input(ImgSize) should be a 2-D tensor."
);
platform
::
errors
::
InvalidArgument
(
"Input(ImgSize) should be a 2-D tensor."
"But received Imgsize size(%s)"
,
dim_imgsize
.
size
()));
if
((
dim_imgsize
[
0
]
>
0
&&
dim_x
[
0
]
>
0
)
||
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
dim_imgsize
[
0
],
dim_x
[
0
],
platform
::
errors
::
InvalidArgument
(
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same."
));
}
PADDLE_ENFORCE_EQ
(
dim_imgsize
[
1
],
2
,
"Input(ImgSize) dim[1] should be 2."
);
PADDLE_ENFORCE_EQ
(
dim_imgsize
[
1
],
2
,
platform
::
errors
::
InvalidArgument
(
"Input(ImgSize) dim[1] should be 2."
"But received imgsize dim[1](%s)."
,
dim_imgsize
[
1
]));
PADDLE_ENFORCE_GT
(
anchors
.
size
(),
0
,
"Attr(anchors) length should be greater than 0."
);
platform
::
errors
::
InvalidArgument
(
"Attr(anchors) length should be greater than 0."
"But received anchors length(%s)."
,
anchors
.
size
()));
PADDLE_ENFORCE_EQ
(
anchors
.
size
()
%
2
,
0
,
"Attr(anchors) length should be even integer."
);
platform
::
errors
::
InvalidArgument
(
"Attr(anchors) length should be even integer."
"But received anchors length (%s)"
,
anchors
.
size
()));
PADDLE_ENFORCE_GT
(
class_num
,
0
,
"Attr(class_num) should be an integer greater than 0."
);
platform
::
errors
::
InvalidArgument
(
"Attr(class_num) should be an integer greater than 0."
"But received class_num (%s)"
,
class_num
));
int
box_num
=
dim_x
[
2
]
*
dim_x
[
3
]
*
anchor_num
;
std
::
vector
<
int64_t
>
dim_boxes
({
dim_x
[
0
],
box_num
,
4
});
...
...
paddle/fluid/operators/detection/yolov3_loss_op.cc
浏览文件 @
a9fe09f8
...
...
@@ -23,19 +23,15 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of Yolov3LossOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"GTBox"
),
"Input(GTBox) of Yolov3LossOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"GTLabel"
),
"Input(GTLabel) of Yolov3LossOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Loss"
),
"Output(Loss) of Yolov3LossOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ObjectnessMask"
),
"Output(ObjectnessMask) of Yolov3LossOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"GTMatchMask"
),
"Output(GTMatchMask) of Yolov3LossOp should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"Yolov3LossOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"GTBox"
),
"Input"
,
"GTBox"
,
"Yolov3LossOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"GTLabel"
),
"Input"
,
"GTLabel"
,
"Yolov3LossOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Loss"
),
"Output"
,
"Loss"
,
"Yolov3LossOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"ObjectnessMask"
),
"Output"
,
"ObjectnessMask"
,
"Yolov3LossOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"GTMatchMask"
),
"Output"
,
"GTMatchMask"
,
"Yolov3LossOp"
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
auto
dim_gtbox
=
ctx
->
GetInputDim
(
"GTBox"
);
...
...
@@ -46,44 +42,96 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
int
mask_num
=
anchor_mask
.
size
();
auto
class_num
=
ctx
->
Attrs
().
Get
<
int
>
(
"class_num"
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
"Input(X) should be a 4-D tensor."
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"Input(X) should be a 4-D tensor. But received "
"X dimension size(%s)"
,
dim_x
.
size
()));
PADDLE_ENFORCE_EQ
(
dim_x
[
2
],
dim_x
[
3
],
"Input(X) dim[3] and dim[4] should be euqal."
);
platform
::
errors
::
InvalidArgument
(
"Input(X) dim[3] and dim[4] should be euqal."
"But received dim[3](%s) != dim[4](%s)"
,
dim_x
[
2
],
dim_x
[
3
]));
PADDLE_ENFORCE_EQ
(
dim_x
[
1
],
mask_num
*
(
5
+
class_num
),
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
);
platform
::
errors
::
InvalidArgument
(
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num))."
"But received dim[1](%s) != (anchor_mask_number * "
"(5+class_num)(%s)."
,
dim_x
[
1
],
mask_num
*
(
5
+
class_num
)));
PADDLE_ENFORCE_EQ
(
dim_gtbox
.
size
(),
3
,
"Input(GTBox) should be a 3-D tensor"
);
PADDLE_ENFORCE_EQ
(
dim_gtbox
[
2
],
4
,
"Input(GTBox) dim[2] should be 5"
);
PADDLE_ENFORCE_EQ
(
dim_gtlabel
.
size
(),
2
,
"Input(GTLabel) should be a 2-D tensor"
);
PADDLE_ENFORCE_EQ
(
dim_gtlabel
[
0
],
dim_gtbox
[
0
],
"Input(GTBox) and Input(GTLabel) dim[0] should be same"
);
PADDLE_ENFORCE_EQ
(
dim_gtlabel
[
1
],
dim_gtbox
[
1
],
"Input(GTBox) and Input(GTLabel) dim[1] should be same"
);
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) should be a 3-D tensor, but "
"received gtbox dimension size(%s)"
,
dim_gtbox
.
size
()));
PADDLE_ENFORCE_EQ
(
dim_gtbox
[
2
],
4
,
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) dim[2] should be 4"
,
"But receive dim[2](%s) != 5. "
,
dim_gtbox
[
2
]));
PADDLE_ENFORCE_EQ
(
dim_gtlabel
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"Input(GTLabel) should be a 2-D tensor,"
"But received Input(GTLabel) dimension size(%s) != 2."
,
dim_gtlabel
.
size
()));
PADDLE_ENFORCE_EQ
(
dim_gtlabel
[
0
],
dim_gtbox
[
0
],
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) dim[0] and Input(GTLabel) dim[0] should be same,"
"But received Input(GTLabel) dim[0](%s) != "
"Input(GTBox) dim[0](%s)"
,
dim_gtlabel
[
0
],
dim_gtbox
[
0
]));
PADDLE_ENFORCE_EQ
(
dim_gtlabel
[
1
],
dim_gtbox
[
1
],
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) and Input(GTLabel) dim[1] should be same,"
"But received Input(GTBox) dim[1](%s) != Input(GTLabel) "
"dim[1](%s)"
,
dim_gtbox
[
1
],
dim_gtlabel
[
1
]));
PADDLE_ENFORCE_GT
(
anchors
.
size
(),
0
,
"Attr(anchors) length should be greater then 0."
);
platform
::
errors
::
InvalidArgument
(
"Attr(anchors) length should be greater then 0."
"But received anchors length(%s)"
,
anchors
.
size
()));
PADDLE_ENFORCE_EQ
(
anchors
.
size
()
%
2
,
0
,
"Attr(anchors) length should be even integer."
);
platform
::
errors
::
InvalidArgument
(
"Attr(anchors) length should be even integer."
"But received anchors length(%s)"
,
anchors
.
size
()));
for
(
size_t
i
=
0
;
i
<
anchor_mask
.
size
();
i
++
)
{
PADDLE_ENFORCE_LT
(
anchor_mask
[
i
],
anchor_num
,
"Attr(anchor_mask) should not crossover Attr(anchors)."
);
platform
::
errors
::
InvalidArgument
(
"Attr(anchor_mask) should not crossover Attr(anchors)."
"But received anchor_mask[i](%s) > anchor_num(%s)"
,
anchor_mask
[
i
],
anchor_num
));
}
PADDLE_ENFORCE_GT
(
class_num
,
0
,
"Attr(class_num) should be an integer greater then 0."
);
platform
::
errors
::
InvalidArgument
(
"Attr(class_num) should be an integer greater then 0."
"But received class_num(%s) < 0"
,
class_num
));
if
(
ctx
->
HasInput
(
"GTScore"
))
{
auto
dim_gtscore
=
ctx
->
GetInputDim
(
"GTScore"
);
PADDLE_ENFORCE_EQ
(
dim_gtscore
.
size
(),
2
,
"Input(GTScore) should be a 2-D tensor"
);
platform
::
errors
::
InvalidArgument
(
"Input(GTScore) should be a 2-D tensor"
"But received GTScore dimension(%s)"
,
dim_gtbox
.
size
()));
PADDLE_ENFORCE_EQ
(
dim_gtscore
[
0
],
dim_gtbox
[
0
],
"Input(GTBox) and Input(GTScore) dim[0] should be same"
);
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) and Input(GTScore) dim[0] should be same"
"But received GTBox dim[0](%s) != GTScore dim[0](%s)"
,
dim_gtbox
[
0
],
dim_gtscore
[
0
]));
PADDLE_ENFORCE_EQ
(
dim_gtscore
[
1
],
dim_gtbox
[
1
],
"Input(GTBox) and Input(GTScore) dim[1] should be same"
);
platform
::
errors
::
InvalidArgument
(
"Input(GTBox) and Input(GTScore) dim[1] should be same"
"But received GTBox dim[1](%s) != GTScore dim[1](%s)"
,
dim_gtscore
[
1
],
dim_gtbox
[
1
]));
}
std
::
vector
<
int64_t
>
dim_out
({
dim_x
[
0
]});
...
...
@@ -245,9 +293,12 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
"Input(Loss@GRAD) should not be null"
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
platform
::
errors
::
NotFound
(
"Input(X) should not be null"
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Loss"
)),
true
,
platform
::
errors
::
NotFound
(
"Input(Loss@GRAD) should not be null"
));
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
dim_x
);
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
a9fe09f8
...
...
@@ -3178,6 +3178,18 @@ def multiclass_nms(bboxes,
keep_top_k=200,
normalized=False)
"""
check_variable_and_dtype
(
bboxes
,
'BBoxes'
,
[
'float32'
,
'float64'
],
'multiclass_nms'
)
check_variable_and_dtype
(
scores
,
'Scores'
,
[
'float32'
,
'float64'
],
'multiclass_nms'
)
check_type
(
score_threshold
,
'score_threshold'
,
float
,
'multicalss_nms'
)
check_type
(
nms_top_k
,
'nums_top_k'
,
int
,
'multiclass_nms'
)
check_type
(
keep_top_k
,
'keep_top_k'
,
int
,
'mutliclass_nms'
)
check_type
(
nms_threshold
,
'nms_threshold'
,
float
,
'multiclass_nms'
)
check_type
(
normalized
,
'normalized'
,
bool
,
'multiclass_nms'
)
check_type
(
nms_eta
,
'nms_eta'
,
float
,
'multiclass_nms'
)
check_type
(
background_label
,
'background_label'
,
int
,
'multiclass_nms'
)
helper
=
LayerHelper
(
'multiclass_nms'
,
**
locals
())
output
=
helper
.
create_variable_for_type_inference
(
dtype
=
bboxes
.
dtype
)
...
...
@@ -3192,7 +3204,6 @@ def multiclass_nms(bboxes,
'nms_threshold'
:
nms_threshold
,
'nms_eta'
:
nms_eta
,
'keep_top_k'
:
keep_top_k
,
'nms_eta'
:
nms_eta
,
'normalized'
:
normalized
},
outputs
=
{
'Out'
:
output
})
...
...
python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
浏览文件 @
a9fe09f8
...
...
@@ -17,6 +17,8 @@ import unittest
import
numpy
as
np
import
copy
from
op_test
import
OpTest
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
def
softmax
(
x
):
...
...
@@ -487,5 +489,36 @@ class TestMulticlassNMS2LoDNoOutput(TestMulticlassNMS2LoDInput):
self
.
score_threshold
=
2.0
class
TestMulticlassNMSError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
with
program_guard
(
Program
(),
Program
()):
M
=
1200
N
=
7
C
=
21
BOX_SIZE
=
4
boxes_np
=
np
.
random
.
random
((
M
,
C
,
BOX_SIZE
)).
astype
(
'float32'
)
scores
=
np
.
random
.
random
((
N
*
M
,
C
)).
astype
(
'float32'
)
scores
=
np
.
apply_along_axis
(
softmax
,
1
,
scores
)
scores
=
np
.
reshape
(
scores
,
(
N
,
M
,
C
))
scores_np
=
np
.
transpose
(
scores
,
(
0
,
2
,
1
))
boxes_data
=
fluid
.
data
(
name
=
'bboxes'
,
shape
=
[
M
,
C
,
BOX_SIZE
],
dtype
=
'float32'
)
scores_data
=
fluid
.
data
(
name
=
'scores'
,
shape
=
[
N
,
C
,
M
],
dtype
=
'float32'
)
def
test_bboxes_Variable
():
# the bboxes type must be Variable
fluid
.
layers
.
multiclass_nms
(
bboxes
=
boxes_np
,
scores
=
scores_data
)
def
test_scores_Variable
():
# the bboxes type must be Variable
fluid
.
layers
.
multiclass_nms
(
bboxes
=
boxes_data
,
scores
=
scores_np
)
self
.
assertRaises
(
TypeError
,
test_bboxes_Variable
)
self
.
assertRaises
(
TypeError
,
test_scores_Variable
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录