Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
72ee3c62
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
72ee3c62
编写于
1月 30, 2019
作者:
J
jerrywgz
提交者:
GitHub
1月 30, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15398 from jerrywgz/add_axis_for_boxcoder
Add axis for boxcoder
上级
d3eeb92b
cee2e1b0
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
533 addition
and
213 deletion
+533
-213
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-1
paddle/fluid/operators/detection/box_coder_op.cc
paddle/fluid/operators/detection/box_coder_op.cc
+63
-18
paddle/fluid/operators/detection/box_coder_op.cu
paddle/fluid/operators/detection/box_coder_op.cu
+98
-54
paddle/fluid/operators/detection/box_coder_op.h
paddle/fluid/operators/detection/box_coder_op.h
+82
-46
paddle/fluid/operators/slice_op.cc
paddle/fluid/operators/slice_op.cc
+3
-0
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+110
-15
python/paddle/fluid/tests/test_detection.py
python/paddle/fluid/tests/test_detection.py
+13
-0
python/paddle/fluid/tests/unittests/test_box_coder_op.py
python/paddle/fluid/tests/unittests/test_box_coder_op.py
+163
-79
未找到文件。
paddle/fluid/API.spec
浏览文件 @
72ee3c62
...
...
@@ -322,7 +322,7 @@ paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_class
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', 'is_crowd', 'gt_segms', 'rois', 'labels_int32', 'num_classes', 'resolution'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'
], varargs=None, keywords=None, defaults=('encode_center_size', True, None
))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'
, 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0
))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
...
...
paddle/fluid/operators/detection/box_coder_op.cc
浏览文件 @
72ee3c62
...
...
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include <vector>
namespace
paddle
{
namespace
operators
{
...
...
@@ -32,32 +33,57 @@ class BoxCoderOp : public framework::OperatorWithKernel {
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
prior_box_dims
.
size
(),
2
,
"The rank of Input
of PriorBoxVar
must be 2"
);
"The rank of Input
PriorBox
must be 2"
);
PADDLE_ENFORCE_EQ
(
prior_box_dims
[
1
],
4
,
"The shape of PriorBox is [N, 4]"
);
if
(
ctx
->
HasInput
(
"PriorBoxVar"
))
{
auto
prior_box_var_dims
=
ctx
->
GetInputDim
(
"PriorBoxVar"
);
PADDLE_ENFORCE_EQ
(
prior_box_dims
,
prior_box_var_dims
);
PADDLE_ENFORCE
(
prior_box_var_dims
.
size
()
==
1
||
prior_box_var_dims
.
size
()
==
2
,
"Input(PriorBoxVar) of BoxCoderOp should be 1 or 2."
);
if
(
prior_box_var_dims
.
size
()
==
1
)
{
PADDLE_ENFORCE_EQ
(
prior_box_var_dims
[
0
],
4
,
"The 1st dimension of Input(PriorBoxVar) should be 4"
"when the rank is 1."
);
}
else
{
PADDLE_ENFORCE_EQ
(
prior_box_dims
,
prior_box_var_dims
,
"The dimension of Input(PriorBoxVar) should be equal to"
"the dimension of Input(PriorBox when the rank is 2.)"
);
}
}
}
auto
code_type
=
GetBoxCodeType
(
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"code_type"
));
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
.
size
(),
2
,
"The rank of Input of TargetBox must be 2"
);
PADDLE_ENFORCE_EQ
(
target_box_dims
[
1
],
4
,
"The shape of TargetBox is [M, 4]"
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
.
size
(),
3
,
"The rank of Input of TargetBox must be 3"
);
auto
code_type
=
GetBoxCodeType
(
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"code_type"
));
int
axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
);
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
.
size
(),
2
,
"The rank of Input TargetBox must be 2"
);
PADDLE_ENFORCE_EQ
(
target_box_dims
[
1
],
4
,
"The shape of TargetBox is [M, 4]"
);
ctx
->
SetOutputDim
(
"OutputBox"
,
framework
::
make_ddim
({
target_box_dims
[
0
],
prior_box_dims
[
0
],
4
}));
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
.
size
(),
3
,
"The rank of Input TargetBox must be 3"
);
if
(
axis
==
0
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
[
1
],
prior_box_dims
[
0
]);
PADDLE_ENFORCE_EQ
(
target_box_dims
[
2
],
prior_box_dims
[
1
]);
}
else
if
(
axis
==
1
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
[
0
],
prior_box_dims
[
0
]);
}
else
{
PADDLE_THROW
(
"axis must be 0 or 1."
);
}
PADDLE_ENFORCE_EQ
(
target_box_dims
[
2
],
prior_box_dims
[
1
]);
ctx
->
ShareDim
(
"TargetBox"
,
/*->*/
"OutputBox"
);
}
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
&&
axis
==
1
)
{
ctx
->
ShareLoD
(
"PriorBox"
,
/*->*/
"OutputBox"
);
}
else
{
ctx
->
ShareLoD
(
"TargetBox"
,
/*->*/
"OutputBox"
);
}
ctx
->
SetOutputDim
(
"OutputBox"
,
framework
::
make_ddim
({
target_box_dims
[
0
],
prior_box_dims
[
0
],
4
}));
ctx
->
ShareLoD
(
"TargetBox"
,
/*->*/
"OutputBox"
);
}
};
...
...
@@ -100,6 +126,21 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default true) "
"whether treat the priorbox as a noramlized box"
)
.
SetDefault
(
true
);
AddAttr
<
int
>
(
"axis"
,
"(int, default 0)"
"which axis in PriorBox to broadcast for box decode,"
"for example, if axis is 0 and TargetBox has shape"
"[N, M, 4] and PriorBox has shape [M, 4], then PriorBox "
"will broadcast to [N, M, 4] for decoding. It is only valid"
"when code type is decode_center_size"
)
.
SetDefault
(
0
)
.
InEnum
({
0
,
1
});
AddAttr
<
std
::
vector
<
float
>>
(
"variance"
,
"(vector<float>, default {}),"
"variance of prior box with shape [4]. PriorBoxVar and variance can"
"not be provided at the same time."
)
.
SetDefault
(
std
::
vector
<
float
>
{});
AddOutput
(
"OutputBox"
,
"(LoDTensor or Tensor) "
"When code_type is 'encode_center_size', the output tensor of "
...
...
@@ -138,7 +179,11 @@ where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width
and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the
priorbox's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`,
`phv` denote the variance of the priorbox and `ox`, `oy`, `ow`, `oh` denote the
encoded/decoded coordinates, width and height.
encoded/decoded coordinates, width and height.
During Box Decoding, two modes for broadcast are supported. Say target box has
shape [N, M, 4], and the shape of prior box can be [N, 4] or [M, 4]. Then prior
box will broadcast to target box along the assigned axis.
)DOC"
);
}
};
...
...
paddle/fluid/operators/detection/box_coder_op.cu
浏览文件 @
72ee3c62
...
...
@@ -9,6 +9,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
...
@@ -16,11 +19,11 @@ namespace paddle {
namespace
operators
{
template
<
typename
T
>
__global__
void
EncodeCenterSizeKernel
(
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
const
int
row
,
const
int
col
,
const
int
len
,
const
bool
normalized
,
T
*
output
)
{
__global__
void
EncodeCenterSizeKernel
(
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
const
int
row
,
const
int
col
,
const
int
len
,
const
bool
normalized
,
const
T
prior_box_var_size
,
const
float
*
variance
,
const
int
var_size
,
T
*
output
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
row
*
col
)
{
const
int
row_idx
=
idx
/
col
;
...
...
@@ -30,11 +33,9 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
T
prior_box_height
=
prior_box_data
[
col_idx
*
len
+
3
]
-
prior_box_data
[
col_idx
*
len
+
1
]
+
(
normalized
==
false
);
T
prior_box_center_x
=
(
prior_box_data
[
col_idx
*
len
+
2
]
+
prior_box_data
[
col_idx
*
len
])
/
2
;
T
prior_box_center_y
=
(
prior_box_data
[
col_idx
*
len
+
3
]
+
prior_box_data
[
col_idx
*
len
+
1
])
/
2
;
T
prior_box_center_x
=
prior_box_data
[
col_idx
*
len
]
+
prior_box_width
/
2
;
T
prior_box_center_y
=
prior_box_data
[
col_idx
*
len
+
1
]
+
prior_box_height
/
2
;
T
target_box_center_x
=
(
target_box_data
[
row_idx
*
len
+
2
]
+
target_box_data
[
row_idx
*
len
])
/
...
...
@@ -55,58 +56,73 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
output
[
idx
*
len
+
2
]
=
log
(
fabs
(
target_box_width
/
prior_box_width
));
output
[
idx
*
len
+
3
]
=
log
(
fabs
(
target_box_height
/
prior_box_height
));
if
(
prior_box_var_data
)
{
output
[
idx
*
len
]
/=
prior_box_var_data
[
col_idx
*
len
];
output
[
idx
*
len
+
1
]
/=
prior_box_var_data
[
col_idx
*
len
+
1
];
output
[
idx
*
len
+
2
]
/=
prior_box_var_data
[
col_idx
*
len
+
2
];
output
[
idx
*
len
+
3
]
/=
prior_box_var_data
[
col_idx
*
len
+
3
];
int
prior_var_offset
=
0
;
if
(
prior_box_var_size
==
2
)
{
prior_var_offset
=
col_idx
*
len
;
}
output
[
idx
*
len
]
/=
prior_box_var_data
[
prior_var_offset
];
output
[
idx
*
len
+
1
]
/=
prior_box_var_data
[
prior_var_offset
+
1
];
output
[
idx
*
len
+
2
]
/=
prior_box_var_data
[
prior_var_offset
+
2
];
output
[
idx
*
len
+
3
]
/=
prior_box_var_data
[
prior_var_offset
+
3
];
}
else
if
(
var_size
==
4
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
output
[
idx
*
len
+
k
]
/=
static_cast
<
T
>
(
variance
[
k
]);
}
}
}
}
template
<
typename
T
>
__global__
void
DecodeCenterSizeKernel
(
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
const
int
row
,
const
int
col
,
const
int
len
,
const
bool
normalized
,
T
*
output
)
{
__global__
void
DecodeCenterSizeKernel
(
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
const
int
row
,
const
int
col
,
const
int
len
,
const
bool
normalized
,
const
T
prior_box_var_size
,
const
float
*
variance
,
const
int
var_size
,
const
int
axis
,
T
*
output
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
prior_box_offset
=
0
;
if
(
idx
<
row
*
col
)
{
const
int
col_idx
=
idx
%
col
;
T
prior_box_width
=
prior_box_data
[
col_idx
*
len
+
2
]
-
prior_box_data
[
col_idx
*
len
]
+
(
normalized
==
false
);
T
prior_box_height
=
prior_box_data
[
col_idx
*
len
+
3
]
-
prior_box_data
[
col_idx
*
len
+
1
]
+
const
int
row_idx
=
idx
/
col
;
prior_box_offset
=
axis
==
0
?
col_idx
*
len
:
row_idx
*
len
;
T
prior_box_width
=
prior_box_data
[
prior_box_offset
+
2
]
-
prior_box_data
[
prior_box_offset
]
+
(
normalized
==
false
);
T
prior_box_height
=
prior_box_data
[
prior_box_offset
+
3
]
-
prior_box_data
[
prior_box_offset
+
1
]
+
(
normalized
==
false
);
T
prior_box_center_x
=
(
prior_box_data
[
col_idx
*
len
+
2
]
+
prior_box_data
[
col_idx
*
len
])
/
2
;
T
prior_box_center_y
=
(
prior_box_data
[
col_idx
*
len
+
3
]
+
prior_box_data
[
col_idx
*
len
+
1
])
/
2
;
prior_box_data
[
prior_box_offset
]
+
prior_box_width
/
2
;
T
prior_box_center_y
=
prior_box_data
[
prior_box_offset
+
1
]
+
prior_box_height
/
2
;
T
target_box_width
,
target_box_height
;
T
target_box_center_x
,
target_box_center_y
;
T
box_var_x
=
T
(
1
),
box_var_y
=
T
(
1
);
T
box_var_w
=
T
(
1
),
box_var_h
=
T
(
1
);
if
(
prior_box_var_data
)
{
target_box_width
=
exp
(
prior_box_var_data
[
col_idx
*
len
+
2
]
*
target_box_data
[
idx
*
len
+
2
])
*
prior_box_width
;
target_box_height
=
exp
(
prior_box_var_data
[
col_idx
*
len
+
3
]
*
target_box_data
[
idx
*
len
+
3
])
*
prior_box_height
;
target_box_center_x
=
prior_box_var_data
[
col_idx
*
len
]
*
target_box_data
[
idx
*
len
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
prior_box_var_data
[
col_idx
*
len
+
1
]
*
target_box_data
[
idx
*
len
+
1
]
*
prior_box_height
+
prior_box_center_y
;
}
else
{
target_box_width
=
exp
(
target_box_data
[
idx
*
len
+
2
])
*
prior_box_width
;
target_box_height
=
exp
(
target_box_data
[
idx
*
len
+
3
])
*
prior_box_height
;
target_box_center_x
=
target_box_data
[
idx
*
len
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
target_box_data
[
idx
*
len
+
1
]
*
prior_box_height
+
prior_box_center_y
;
int
prior_var_offset
=
0
;
if
(
prior_box_var_size
==
2
)
{
prior_var_offset
=
axis
==
0
?
col_idx
*
len
:
row_idx
*
len
;
}
box_var_x
=
prior_box_var_data
[
prior_var_offset
];
box_var_y
=
prior_box_var_data
[
prior_var_offset
+
1
];
box_var_w
=
prior_box_var_data
[
prior_var_offset
+
2
];
box_var_h
=
prior_box_var_data
[
prior_var_offset
+
3
];
}
else
if
(
var_size
==
4
)
{
box_var_x
=
static_cast
<
T
>
(
variance
[
0
]);
box_var_y
=
static_cast
<
T
>
(
variance
[
1
]);
box_var_w
=
static_cast
<
T
>
(
variance
[
2
]);
box_var_h
=
static_cast
<
T
>
(
variance
[
3
]);
}
target_box_width
=
exp
(
box_var_w
*
target_box_data
[
idx
*
len
+
2
])
*
prior_box_width
;
target_box_height
=
exp
(
box_var_h
*
target_box_data
[
idx
*
len
+
3
])
*
prior_box_height
;
target_box_center_x
=
box_var_x
*
target_box_data
[
idx
*
len
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
box_var_y
*
target_box_data
[
idx
*
len
+
1
]
*
prior_box_height
+
prior_box_center_y
;
output
[
idx
*
len
]
=
target_box_center_x
-
target_box_width
/
2
;
output
[
idx
*
len
+
1
]
=
target_box_center_y
-
target_box_height
/
2
;
...
...
@@ -127,36 +143,64 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
auto
*
prior_box_var
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBoxVar"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
output_box
=
context
.
Output
<
framework
::
Tensor
>
(
"OutputBox"
);
std
::
vector
<
float
>
variance
=
context
.
Attr
<
std
::
vector
<
float
>>
(
"variance"
);
const
T
*
prior_box_data
=
prior_box
->
data
<
T
>
();
const
T
*
target_box_data
=
target_box
->
data
<
T
>
();
const
T
*
prior_box_var_data
=
nullptr
;
if
(
prior_box_var
)
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
auto
prior_box_var_size
=
0
;
if
(
prior_box_var
)
{
PADDLE_ENFORCE
(
variance
.
empty
(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time."
);
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
prior_box_var_size
=
prior_box_var
->
dims
().
size
();
}
if
(
!
(
variance
.
empty
()))
{
PADDLE_ENFORCE
(
static_cast
<
int
>
(
variance
.
size
())
==
4
,
"Size of attribute 'variance' should be 4"
);
}
if
(
target_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1
,
"Only support 1 level of LoD."
);
}
const
int
var_size
=
static_cast
<
int
>
(
variance
.
size
());
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
int
axis
=
context
.
Attr
<
int
>
(
"axis"
);
auto
row
=
target_box
->
dims
()[
0
];
auto
col
=
prior_box
->
dims
()[
0
];
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
col
=
target_box
->
dims
()[
1
];
}
auto
len
=
prior_box
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
row
*
col
+
block
-
1
)
/
block
;
auto
&
device_ctx
=
context
.
cuda_device_context
();
auto
&
allocator
=
platform
::
DeviceTemporaryAllocator
::
Instance
().
Get
(
device_ctx
);
int
bytes
=
var_size
*
sizeof
(
float
);
auto
dev_var
=
allocator
.
Allocate
(
bytes
);
float
*
dev_var_data
=
reinterpret_cast
<
float
*>
(
dev_var
->
ptr
());
auto
cplace
=
platform
::
CPUPlace
();
const
auto
gplace
=
boost
::
get
<
platform
::
CUDAPlace
>
(
context
.
GetPlace
());
memory
::
Copy
(
gplace
,
dev_var_data
,
cplace
,
&
variance
[
0
],
bytes
,
device_ctx
.
stream
());
output_box
->
mutable_data
<
T
>
({
row
,
col
,
len
},
context
.
GetPlace
());
T
*
output
=
output_box
->
data
<
T
>
();
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSizeKernel
<
T
><<<
grid
,
block
,
0
,
device_ctx
.
stream
()
>>>
(
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
len
,
normalized
,
output
);
normalized
,
prior_box_var_size
,
dev_var_data
,
var_size
,
output
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
DecodeCenterSizeKernel
<
T
><<<
grid
,
block
,
0
,
device_ctx
.
stream
()
>>>
(
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
len
,
normalized
,
output
);
normalized
,
prior_box_var_size
,
dev_var_data
,
var_size
,
axis
,
output
);
}
}
};
...
...
paddle/fluid/operators/detection/box_coder_op.h
浏览文件 @
72ee3c62
...
...
@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
...
...
@@ -34,7 +35,8 @@ class BoxCoderKernel : public framework::OpKernel<T> {
void
EncodeCenterSize
(
const
framework
::
Tensor
*
target_box
,
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
*
prior_box_var
,
const
bool
normalized
,
T
*
output
)
const
{
const
bool
normalized
,
const
std
::
vector
<
float
>
variance
,
T
*
output
)
const
{
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
prior_box
->
dims
()[
0
];
int64_t
len
=
prior_box
->
dims
()[
1
];
...
...
@@ -53,10 +55,9 @@ class BoxCoderKernel : public framework::OpKernel<T> {
T
prior_box_height
=
prior_box_data
[
j
*
len
+
3
]
-
prior_box_data
[
j
*
len
+
1
]
+
(
normalized
==
false
);
T
prior_box_center_x
=
(
prior_box_data
[
j
*
len
+
2
]
+
prior_box_data
[
j
*
len
])
/
2
;
T
prior_box_center_x
=
prior_box_data
[
j
*
len
]
+
prior_box_width
/
2
;
T
prior_box_center_y
=
(
prior_box_data
[
j
*
len
+
3
]
+
prior_box_data
[
j
*
len
+
1
])
/
2
;
prior_box_data
[
j
*
len
+
1
]
+
prior_box_height
/
2
;
T
target_box_center_x
=
(
target_box_data
[
i
*
len
+
2
]
+
target_box_data
[
i
*
len
])
/
2
;
...
...
@@ -78,10 +79,18 @@ class BoxCoderKernel : public framework::OpKernel<T> {
output
[
offset
+
3
]
=
std
::
log
(
std
::
fabs
(
target_box_height
/
prior_box_height
));
if
(
prior_box_var
)
{
output
[
offset
]
/=
prior_box_var_data
[
j
*
len
];
output
[
offset
+
1
]
/=
prior_box_var_data
[
j
*
len
+
1
];
output
[
offset
+
2
]
/=
prior_box_var_data
[
j
*
len
+
2
];
output
[
offset
+
3
]
/=
prior_box_var_data
[
j
*
len
+
3
];
int
prior_var_offset
=
0
;
if
(
prior_box_var
->
dims
().
size
()
==
2
)
{
prior_var_offset
=
j
*
len
;
}
output
[
offset
]
/=
prior_box_var_data
[
prior_var_offset
];
output
[
offset
+
1
]
/=
prior_box_var_data
[
prior_var_offset
+
1
];
output
[
offset
+
2
]
/=
prior_box_var_data
[
prior_var_offset
+
2
];
output
[
offset
+
3
]
/=
prior_box_var_data
[
prior_var_offset
+
3
];
}
else
if
(
!
(
variance
.
empty
()))
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
output
[
offset
+
k
]
/=
static_cast
<
T
>
(
variance
[
k
]);
}
}
}
}
...
...
@@ -89,58 +98,71 @@ class BoxCoderKernel : public framework::OpKernel<T> {
void
DecodeCenterSize
(
const
framework
::
Tensor
*
target_box
,
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
*
prior_box_var
,
const
bool
normalized
,
T
*
output
)
const
{
const
bool
normalized
,
const
int
axis
,
const
std
::
vector
<
float
>
variance
,
T
*
output
)
const
{
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
prior_box
->
dims
()[
0
];
int64_t
len
=
prior_box
->
dims
()[
1
];
int64_t
col
=
target_box
->
dims
()[
1
];
int64_t
len
=
target_box
->
dims
()[
2
];
auto
*
target_box_data
=
target_box
->
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
->
data
<
T
>
();
const
T
*
prior_box_var_data
=
nullptr
;
if
(
prior_box_var
)
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
int
prior_box_offset
=
0
;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
T
prior_box_width
=
prior_box_data
[
j
*
len
+
2
]
-
prior_box_data
[
j
*
len
]
+
(
normalized
==
false
);
T
prior_box_height
=
prior_box_data
[
j
*
len
+
3
]
-
prior_box_data
[
j
*
len
+
1
]
+
if
(
axis
==
0
)
{
prior_box_offset
=
j
*
len
;
}
else
if
(
axis
==
1
)
{
prior_box_offset
=
i
*
len
;
}
T
prior_box_width
=
prior_box_data
[
prior_box_offset
+
2
]
-
prior_box_data
[
prior_box_offset
]
+
(
normalized
==
false
);
T
prior_box_height
=
prior_box_data
[
prior_box_offset
+
3
]
-
prior_box_data
[
prior_box_offset
+
1
]
+
(
normalized
==
false
);
T
prior_box_center_x
=
(
prior_box_data
[
j
*
len
+
2
]
+
prior_box_data
[
j
*
len
])
/
2
;
prior_box_data
[
prior_box_offset
]
+
prior_box_width
/
2
;
T
prior_box_center_y
=
(
prior_box_data
[
j
*
len
+
3
]
+
prior_box_data
[
j
*
len
+
1
])
/
2
;
prior_box_data
[
prior_box_offset
+
1
]
+
prior_box_height
/
2
;
T
target_box_center_x
=
0
,
target_box_center_y
=
0
;
T
target_box_width
=
0
,
target_box_height
=
0
;
T
box_var_x
=
T
(
1
),
box_var_y
=
T
(
1
);
T
box_var_w
=
T
(
1
),
box_var_h
=
T
(
1
);
if
(
prior_box_var
)
{
target_box_center_x
=
prior_box_var_data
[
j
*
len
]
*
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
prior_box_var_data
[
j
*
len
+
1
]
*
target_box_data
[
offset
+
1
]
*
prior_box_height
+
prior_box_center_y
;
target_box_width
=
std
::
exp
(
prior_box_var_data
[
j
*
len
+
2
]
*
target_box_data
[
offset
+
2
])
*
prior_box_width
;
target_box_height
=
std
::
exp
(
prior_box_var_data
[
j
*
len
+
3
]
*
target_box_data
[
offset
+
3
])
*
prior_box_height
;
}
else
{
target_box_center_x
=
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
target_box_data
[
offset
+
1
]
*
prior_box_height
+
prior_box_center_y
;
target_box_width
=
std
::
exp
(
target_box_data
[
offset
+
2
])
*
prior_box_width
;
target_box_height
=
std
::
exp
(
target_box_data
[
offset
+
3
])
*
prior_box_height
;
int
prior_var_offset
=
0
;
if
(
prior_box_var
->
dims
().
size
()
==
2
)
{
if
(
axis
==
0
)
prior_var_offset
=
j
*
len
;
else
if
(
axis
==
1
)
prior_var_offset
=
i
*
len
;
}
box_var_x
=
prior_box_var_data
[
prior_var_offset
];
box_var_y
=
prior_box_var_data
[
prior_var_offset
+
1
];
box_var_w
=
prior_box_var_data
[
prior_var_offset
+
2
];
box_var_h
=
prior_box_var_data
[
prior_var_offset
+
3
];
}
else
if
(
!
(
variance
.
empty
()))
{
box_var_x
=
static_cast
<
T
>
(
variance
[
0
]);
box_var_y
=
static_cast
<
T
>
(
variance
[
1
]);
box_var_w
=
static_cast
<
T
>
(
variance
[
2
]);
box_var_h
=
static_cast
<
T
>
(
variance
[
3
]);
}
target_box_center_x
=
box_var_x
*
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
box_var_y
*
target_box_data
[
offset
+
1
]
*
prior_box_height
+
prior_box_center_y
;
target_box_width
=
std
::
exp
(
box_var_w
*
target_box_data
[
offset
+
2
])
*
prior_box_width
;
target_box_height
=
std
::
exp
(
box_var_h
*
target_box_data
[
offset
+
3
])
*
prior_box_height
;
output
[
offset
]
=
target_box_center_x
-
target_box_width
/
2
;
output
[
offset
+
1
]
=
target_box_center_y
-
target_box_height
/
2
;
...
...
@@ -157,26 +179,40 @@ class BoxCoderKernel : public framework::OpKernel<T> {
auto
*
prior_box_var
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBoxVar"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
output_box
=
context
.
Output
<
framework
::
Tensor
>
(
"OutputBox"
);
std
::
vector
<
float
>
variance
=
context
.
Attr
<
std
::
vector
<
float
>>
(
"variance"
);
const
int
axis
=
context
.
Attr
<
int
>
(
"axis"
);
if
(
target_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1UL
,
"Only support 1 level of LoD."
);
}
if
(
prior_box_var
)
{
PADDLE_ENFORCE
(
variance
.
empty
(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time."
);
}
if
(
!
(
variance
.
empty
()))
{
PADDLE_ENFORCE
(
static_cast
<
int
>
(
variance
.
size
())
==
4
,
"Size of attribute 'variance' should be 4"
);
}
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
auto
row
=
target_box
->
dims
()[
0
];
auto
col
=
prior_box
->
dims
()[
0
];
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
col
=
target_box
->
dims
()[
1
];
}
auto
len
=
prior_box
->
dims
()[
1
];
output_box
->
mutable_data
<
T
>
({
row
,
col
,
len
},
context
.
GetPlace
());
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
T
*
output
=
output_box
->
data
<
T
>
();
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
output
);
variance
,
output
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
DecodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
output
);
DecodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
axis
,
variance
,
output
);
}
}
};
...
...
paddle/fluid/operators/slice_op.cc
浏览文件 @
72ee3c62
...
...
@@ -54,6 +54,9 @@ class SliceOp : public framework::OperatorWithKernel {
out_dims
[
axes
[
i
]]
=
end
-
start
;
}
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
if
(
axes
[
0
]
!=
0
)
{
ctx
->
ShareLoD
(
"Input"
,
/*->*/
"Out"
);
}
}
protected:
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
72ee3c62
...
...
@@ -346,19 +346,107 @@ def box_coder(prior_box,
target_box
,
code_type
=
"encode_center_size"
,
box_normalized
=
True
,
name
=
None
):
name
=
None
,
axis
=
0
):
"""
${comment}
**Box Coder Layer**
Encode/Decode the target bounding box with the priorbox information.
The Encoding schema described below:
.. math::
ox = (tx - px) / pw / pxv
oy = (ty - py) / ph / pyv
ow = \log(\abs(tw / pw)) / pwv
oh = \log(\abs(th / ph)) / phv
The Decoding schema described below:
.. math::
ox = (pw * pxv * tx + px) - tw / 2
oy = (ph * pyv * ty + py) - th / 2
ow = \exp(pwv * tw) * pw + tw / 2
oh = \exp(phv * th) * ph + th / 2
where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates,
width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote
the priorbox's (anchor) center coordinates, width and height. `pxv`,
`pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`,
`ow`, `oh` denote the encoded/decoded coordinates, width and height.
During Box Decoding, two modes for broadcast are supported. Say target
box has shape [N, M, 4], and the shape of prior box can be [N, 4] or
[M, 4]. Then prior box will broadcast to target box along the
assigned axis.
Args:
prior_box(${prior_box_type}): ${prior_box_comment}
prior_box_var(${prior_box_var_type}): ${prior_box_var_comment}
target_box(${target_box_type}): ${target_box_comment}
code_type(${code_type_type}): ${code_type_comment}
box_normalized(${box_normalized_type}): ${box_normalized_comment}
prior_box(Variable): A 2-D Tensor with shape [M, 4] that holds M
boxes. Each box is represented as
[xmin, ymin, xmax, ymax]. [xmin, ymin] is the
top-left coordinate of the anchor box; if the
input is an image feature map, it is close to
the origin of the coordinate system. [xmax, ymax]
is the bottom-right coordinate of the anchor box.
prior_box_var(Variable|list): prior_box_var supports two types of input.
One is a Variable with shape [M, 4] that holds M groups of variances.
The other is a list consisting of 4 elements
shared by all boxes.
target_box(Variable): This input can be a 2-D LoDTensor with shape
[N, 4] when code_type is 'encode_center_size'.
This input also can be a 3-D Tensor with shape
[N, M, 4] when code_type is 'decode_center_size'.
Each box is represented as
[xmin, ymin, xmax, ymax]. This tensor can
contain LoD information to represent a batch
of inputs.
code_type(string): The code type used with the target box. It can be
encode_center_size or decode_center_size
box_normalized(bool): Whether to treat the prior box as a normalized box.
Set true by default.
name(string): The name of box coder.
axis(int): Which axis in PriorBox to broadcast for box decode,
for example, if axis is 0 and TargetBox has shape
[N, M, 4] and PriorBox has shape [M, 4], then PriorBox
will broadcast to [N, M, 4] for decoding. It is only valid
when code type is decode_center_size. Set 0 by default.
Returns:
output_box(${output_box_type}): ${output_box_comment}
output_box(Variable): When code_type is 'encode_center_size', the
output tensor of box_coder_op with shape
[N, M, 4] representing the result of N target
boxes encoded with M Prior boxes and variances.
When code_type is 'decode_center_size',
N represents the batch size and M represents
the number of decoded boxes.
Examples:
.. code-block:: python
prior_box = fluid.layers.data(name='prior_box',
shape=[512, 4],
dtype='float32',
append_batch_size=False)
target_box = fluid.layers.data(name='target_box',
shape=[512,81,4],
dtype='float32',
append_batch_size=False)
output = fluid.layers.box_coder(prior_box=prior_box,
prior_box_var=[0.1,0.1,0.2,0.2],
target_box=target_box,
code_type="decode_center_size",
box_normalized=False,
axis=1)
"""
helper
=
LayerHelper
(
"box_coder"
,
**
locals
())
...
...
@@ -369,15 +457,22 @@ def box_coder(prior_box,
output_box
=
helper
.
create_variable
(
name
=
name
,
dtype
=
prior_box
.
dtype
,
persistable
=
False
)
inputs
=
{
"PriorBox"
:
prior_box
,
"TargetBox"
:
target_box
}
attrs
=
{
"code_type"
:
code_type
,
"box_normalized"
:
box_normalized
,
"axis"
:
axis
}
if
isinstance
(
prior_box_var
,
Variable
):
inputs
[
'PriorBoxVar'
]
=
prior_box_var
elif
isinstance
(
prior_box_var
,
list
):
attrs
[
'variance'
]
=
prior_box_var
else
:
raise
TypeError
(
"Input variance of box_coder must be Variable or lisz"
)
helper
.
append_op
(
type
=
"box_coder"
,
inputs
=
{
"PriorBox"
:
prior_box
,
"PriorBoxVar"
:
prior_box_var
,
"TargetBox"
:
target_box
},
attrs
=
{
"code_type"
:
code_type
,
"box_normalized"
:
box_normalized
},
inputs
=
inputs
,
attrs
=
attrs
,
outputs
=
{
"OutputBox"
:
output_box
})
return
output_box
...
...
python/paddle/fluid/tests/test_detection.py
浏览文件 @
72ee3c62
...
...
@@ -50,6 +50,19 @@ class TestDetection(unittest.TestCase):
self
.
assertEqual
(
out
.
shape
[
-
1
],
6
)
print
(
str
(
program
))
def test_box_coder_api(self):
    """Build a program containing a box_coder layer and check it yields an output.

    The prior-box variance is supplied as a plain list of four floats
    instead of a Variable, and the target box input is declared with one
    level of LoD. The assembled program is printed at the end.
    """
    program = Program()
    with program_guard(program):
        # Variance given as a list, so it travels as an attribute-style
        # argument rather than as an input Variable.
        variances = [0.1, 0.2, 0.1, 0.2]
        priors = layers.data(name='x', shape=[4], dtype='float32')
        targets = layers.data(
            name='z', shape=[4], dtype='float32', lod_level=1)
        encoded = layers.box_coder(
            prior_box=priors,
            prior_box_var=variances,
            target_box=targets,
            code_type='encode_center_size')
        self.assertIsNotNone(encoded)
    print(str(program))
def
test_detection_api
(
self
):
program
=
Program
()
with
program_guard
(
program
):
...
...
python/paddle/fluid/tests/unittests/test_box_coder_op.py
浏览文件 @
72ee3c62
...
...
@@ -21,80 +21,80 @@ import math
from
op_test
import
OpTest
def box_coder(target_box, prior_box, prior_box_var, output_box, code_type,
              box_normalized):
    """NumPy reference for center-size box encoding/decoding (in place).

    Args:
        target_box: 2-D ``(N, 4)`` array when encoding, 3-D ``(N, M, 4)``
            array when decoding (it is indexed that way below).
        prior_box: 2-D ``(M, 4)`` array; columns are used as
            ``x1, y1, x2, y2``.
        prior_box_var: 2-D ``(M, 4)`` array of per-prior variances.
        output_box: 3-D array that receives the result; it is written in
            place and nothing is returned.
        code_type: ``"EncodeCenterSize"`` or ``"DecodeCenterSize"``; any
            other value leaves ``output_box`` untouched.
        box_normalized: when false, widths/heights get ``+ 1`` and decoded
            max corners get ``- 1``.
    """
    num_priors = prior_box.shape[0]
    prior_box_x = (
        (prior_box[:, 2] + prior_box[:, 0]) / 2).reshape(1, num_priors)
    prior_box_y = (
        (prior_box[:, 3] + prior_box[:, 1]) / 2).reshape(1, num_priors)
    prior_box_width = (prior_box[:, 2] - prior_box[:, 0]).reshape(1, num_priors)
    prior_box_height = (prior_box[:, 3] - prior_box[:, 1]).reshape(1,
                                                                   num_priors)
    # Add a leading axis so the variances broadcast over the target rows.
    prior_box_var = prior_box_var.reshape(1, prior_box_var.shape[0],
                                          prior_box_var.shape[1])
    if not box_normalized:
        prior_box_height = prior_box_height + 1
        prior_box_width = prior_box_width + 1

    if code_type == "EncodeCenterSize":
        num_targets = target_box.shape[0]
        # Column vectors, so (N, 1) op (1, M) broadcasts to (N, M).
        target_box_x = (
            (target_box[:, 2] + target_box[:, 0]) / 2).reshape(num_targets, 1)
        target_box_y = (
            (target_box[:, 3] + target_box[:, 1]) / 2).reshape(num_targets, 1)
        target_box_width = (target_box[:, 2] - target_box[:, 0]).reshape(
            num_targets, 1)
        target_box_height = (target_box[:, 3] - target_box[:, 1]).reshape(
            num_targets, 1)
        if not box_normalized:
            target_box_height = target_box_height + 1
            target_box_width = target_box_width + 1

        output_box[:, :, 0] = (target_box_x - prior_box_x) / \
            prior_box_width / prior_box_var[:, :, 0]
        output_box[:, :, 1] = (target_box_y - prior_box_y) / \
            prior_box_height / prior_box_var[:, :, 1]
        output_box[:, :, 2] = np.log(
            np.fabs(target_box_width / prior_box_width)) / \
            prior_box_var[:, :, 2]
        output_box[:, :, 3] = np.log(
            np.fabs(target_box_height / prior_box_height)) / \
            prior_box_var[:, :, 3]

    elif code_type == "DecodeCenterSize":
        target_box_x = prior_box_var[:, :, 0] * target_box[:, :, 0] * \
            prior_box_width + prior_box_x
        target_box_y = prior_box_var[:, :, 1] * target_box[:, :, 1] * \
            prior_box_height + prior_box_y
        target_box_width = np.exp(
            prior_box_var[:, :, 2] * target_box[:, :, 2]) * prior_box_width
        target_box_height = np.exp(
            prior_box_var[:, :, 3] * target_box[:, :, 3]) * prior_box_height

        output_box[:, :, 0] = target_box_x - target_box_width / 2
        output_box[:, :, 1] = target_box_y - target_box_height / 2
        output_box[:, :, 2] = target_box_x + target_box_width / 2
        output_box[:, :, 3] = target_box_y + target_box_height / 2
        if not box_normalized:
            output_box[:, :, 2] = output_box[:, :, 2] - 1
            output_box[:, :, 3] = output_box[:, :, 3] - 1
def
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
lod
,
code_type
,
box_normalized
):
n
=
target_box
.
shape
[
0
]
m
=
prior_box
.
shape
[
0
]
def box_decoder(t_box, p_box, pb_v, output_box, norm, axis=0):
    """Decode center-size deltas against prior boxes (NumPy reference).

    The result is written into ``output_box`` in place; nothing is returned.

    Args:
        t_box: 3-D array of encoded deltas, indexed as ``t_box[:, :, k]``.
        p_box: 2-D ``(M, 4)`` array of prior boxes; columns are used as
            ``x1, y1, x2, y2``.
        pb_v: variances, either 1-D (four values shared by every prior),
            2-D ``(M, 4)`` (one row per prior) or already 3-D and
            broadcastable against ``t_box``.
        output_box: 3-D array receiving the decoded corner boxes.
        norm: truthy when boxes are normalized. Otherwise widths/heights
            get ``+ 1`` and the decoded max corners get ``- 1``.
        axis: 0 lines the priors up with dimension 1 of ``t_box``; any
            other value lines them up with dimension 0.
    """
    # One consistent offset instead of mixing `norm == False` and `not norm`.
    offset = 0 if norm else 1

    pb_w = p_box[:, 2] - p_box[:, 0] + offset
    pb_h = p_box[:, 3] - p_box[:, 1] + offset
    pb_x = pb_w * 0.5 + p_box[:, 0]
    pb_y = pb_h * 0.5 + p_box[:, 1]

    num_priors = p_box.shape[0]
    prior_shape = (1, num_priors) if axis == 0 else (num_priors, 1)
    pb_w = pb_w.reshape(prior_shape)
    pb_h = pb_h.reshape(prior_shape)
    pb_x = pb_x.reshape(prior_shape)
    pb_y = pb_y.reshape(prior_shape)

    if pb_v.ndim == 2:
        # Per-prior variances must follow the same axis as the priors;
        # a fixed (1, M, 4) layout is only correct for axis == 0.
        pb_v = pb_v.reshape(prior_shape + (pb_v.shape[1], ))

    # `[..., k]` yields a scalar for 1-D variances and a 2-D slice for 3-D.
    tb_x = pb_v[..., 0] * t_box[:, :, 0] * pb_w + pb_x
    tb_y = pb_v[..., 1] * t_box[:, :, 1] * pb_h + pb_y
    tb_w = np.exp(pb_v[..., 2] * t_box[:, :, 2]) * pb_w
    tb_h = np.exp(pb_v[..., 3] * t_box[:, :, 3]) * pb_h

    output_box[:, :, 0] = tb_x - tb_w / 2
    output_box[:, :, 1] = tb_y - tb_h / 2
    output_box[:, :, 2] = tb_x + tb_w / 2 - offset
    output_box[:, :, 3] = tb_y + tb_h / 2 - offset
def box_encoder(t_box, p_box, pb_v, output_box, norm):
    """Encode target boxes as center-size deltas w.r.t. prior boxes.

    The result is written into ``output_box`` in place; nothing is returned.

    Args:
        t_box: 2-D ``(N, 4)`` array of target boxes; columns are used as
            ``x1, y1, x2, y2``.
        p_box: 2-D ``(M, 4)`` array of prior boxes, same column layout.
        pb_v: variances, either 1-D (four values shared by every prior),
            2-D ``(M, 4)`` (one row per prior) or already 3-D.
        output_box: 3-D array receiving the ``(N, M, 4)`` encoded deltas.
        norm: truthy when boxes are normalized; otherwise widths/heights
            get ``+ 1``.
    """
    # One consistent offset instead of mixing `norm == False` and `not norm`.
    offset = 0 if norm else 1

    pb_w = p_box[:, 2] - p_box[:, 0] + offset
    pb_h = p_box[:, 3] - p_box[:, 1] + offset
    pb_x = pb_w * 0.5 + p_box[:, 0]
    pb_y = pb_h * 0.5 + p_box[:, 1]

    # Priors become row vectors and targets column vectors, so every
    # (N, 1) op (1, M) expression below broadcasts to (N, M).
    prior_shape = (1, p_box.shape[0])
    pb_w = pb_w.reshape(prior_shape)
    pb_h = pb_h.reshape(prior_shape)
    pb_x = pb_x.reshape(prior_shape)
    pb_y = pb_y.reshape(prior_shape)

    if pb_v.ndim == 2:
        pb_v = pb_v.reshape(1, pb_v.shape[0], pb_v.shape[1])

    target_shape = (t_box.shape[0], 1)
    tb_x = ((t_box[:, 2] + t_box[:, 0]) / 2).reshape(target_shape)
    tb_y = ((t_box[:, 3] + t_box[:, 1]) / 2).reshape(target_shape)
    tb_w = (t_box[:, 2] - t_box[:, 0]).reshape(target_shape) + offset
    tb_h = (t_box[:, 3] - t_box[:, 1]).reshape(target_shape) + offset

    # `[..., k]` yields a scalar for 1-D variances and a 2-D slice for 3-D.
    output_box[:, :, 0] = (tb_x - pb_x) / pb_w / pb_v[..., 0]
    output_box[:, :, 1] = (tb_y - pb_y) / pb_h / pb_v[..., 1]
    output_box[:, :, 2] = np.log(np.fabs(tb_w / pb_w)) / pb_v[..., 2]
    output_box[:, :, 3] = np.log(np.fabs(tb_h / pb_h)) / pb_v[..., 3]
def batch_box_coder(p_box, pb_v, t_box, lod, code_type, norm, axis=0):
    """Run the NumPy reference encoder/decoder and return the output boxes.

    Args:
        p_box: 2-D prior boxes, passed through to the encoder/decoder.
        pb_v: variances, passed through to the encoder/decoder.
        t_box: target boxes (2-D when encoding, 3-D when decoding).
        lod: sequence of segment lengths. When encoding, consecutive row
            ranges of ``t_box`` of these lengths are encoded one by one.
        code_type: ``"EncodeCenterSize"`` or ``"DecodeCenterSize"``; any
            other value leaves the output zero-filled.
        norm: whether boxes are normalized.
        axis: forwarded to ``box_decoder`` (unused when encoding).

    Returns:
        A float32 array of shape ``(n, m, 4)``. ``n`` is
        ``t_box.shape[0]``; ``m`` is ``t_box.shape[1]`` when decoding and
        ``p_box.shape[0]`` otherwise.
    """
    n = t_box.shape[0]
    m = t_box.shape[1] if code_type == "DecodeCenterSize" else p_box.shape[0]
    output_box = np.zeros((n, m, 4), dtype=np.float32)

    start = 0
    for seg_len in lod:
        stop = start + seg_len
        if code_type == "EncodeCenterSize":
            box_encoder(t_box[start:stop, :], p_box, pb_v,
                        output_box[start:stop, :, :], norm)
        elif code_type == "DecodeCenterSize":
            # Decoding works on the whole batch; the segment bounds are
            # not used here.
            box_decoder(t_box, p_box, pb_v, output_box, norm, axis)
        start = stop
    return output_box
...
...
@@ -106,9 +106,35 @@ class TestBoxCoderOp(OpTest):
def setUp(self):
    """Build inputs, attrs and expected outputs for a decode test.

    Uses 81 priors with a per-prior ``(81, 4)`` variance tensor and a
    ``(20, 81, 4)`` target tensor. The expected output comes from the
    NumPy reference ``batch_box_coder``.
    """
    self.op_type = "box_coder"
    lod = [[1, 1, 1, 1, 1]]
    prior_box = np.random.random((81, 4)).astype('float32')
    prior_box_var = np.random.random((81, 4)).astype('float32')
    target_box = np.random.random((20, 81, 4)).astype('float32')
    code_type = "DecodeCenterSize"
    box_normalized = False
    output_box = batch_box_coder(prior_box, prior_box_var, target_box,
                                 lod[0], code_type, box_normalized)

    self.inputs = {
        'PriorBox': prior_box,
        'PriorBoxVar': prior_box_var,
        'TargetBox': target_box,
    }
    self.attrs = {
        'code_type': 'decode_center_size',
        'box_normalized': False
    }
    self.outputs = {'OutputBox': output_box}
class
TestBoxCoderOpWithOneRankVar
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
prior_box
=
np
.
random
.
random
((
81
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
20
,
81
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
box_normalized
=
False
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
...
...
@@ -133,9 +159,9 @@ class TestBoxCoderOpWithoutBoxVar(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
lod
=
[[
0
,
1
,
2
,
3
,
4
,
5
]]
prior_box
=
np
.
random
.
random
((
10
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
ones
((
10
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
5
,
10
,
4
)).
astype
(
'float32'
)
prior_box
=
np
.
random
.
random
((
81
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
ones
((
81
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
20
,
81
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
box_normalized
=
False
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
...
...
@@ -158,10 +184,10 @@ class TestBoxCoderOpWithLoD(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
lod
=
[[
4
,
8
,
8
]]
prior_box
=
np
.
random
.
random
((
1
0
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
1
0
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
2
0
,
4
)).
astype
(
'float32'
)
lod
=
[[
10
,
20
,
20
]]
prior_box
=
np
.
random
.
random
((
2
0
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
2
0
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
5
0
,
4
)).
astype
(
'float32'
)
code_type
=
"EncodeCenterSize"
box_normalized
=
True
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
...
...
@@ -176,5 +202,63 @@ class TestBoxCoderOpWithLoD(OpTest):
self
.
outputs
=
{
'OutputBox'
:
output_box
}
class TestBoxCoderOpWithAxis(OpTest):
    """Decode test with ``axis=1`` and a 1-D variance fed as ``PriorBoxVar``."""

    def test_check_output(self):
        self.check_output()

    def setUp(self):
        """Build inputs, attrs and the NumPy-reference expected output.

        30 priors are paired with a ``(30, 81, 4)`` target tensor, so the
        priors line up with dimension 0 of the targets (``axis=1``).
        """
        self.op_type = "box_coder"
        lod = [[1, 1, 1, 1, 1]]
        prior_box = np.random.random((30, 4)).astype('float32')
        prior_box_var = np.random.random((4)).astype('float32')
        target_box = np.random.random((30, 81, 4)).astype('float32')
        code_type = "DecodeCenterSize"
        box_normalized = False
        axis = 1
        output_box = batch_box_coder(prior_box, prior_box_var, target_box,
                                     lod[0], code_type, box_normalized, axis)

        self.inputs = {
            'PriorBox': prior_box,
            'PriorBoxVar': prior_box_var,
            'TargetBox': target_box,
        }
        self.attrs = {
            'code_type': 'decode_center_size',
            'box_normalized': False,
            'axis': axis
        }
        self.outputs = {'OutputBox': output_box}
class TestBoxCoderOpWithVariance(OpTest):
    """Decode test with ``axis=1`` and the variance given as an attribute.

    No ``PriorBoxVar`` input is fed; the four variance values are passed
    through the ``variance`` attribute instead.
    """

    def test_check_output(self):
        self.check_output()

    def setUp(self):
        """Build inputs, attrs and the NumPy-reference expected output."""
        self.op_type = "box_coder"
        lod = [[1, 1, 1, 1, 1]]
        prior_box = np.random.random((30, 4)).astype('float32')
        prior_box_var = np.random.random((4)).astype('float32')
        target_box = np.random.random((30, 81, 4)).astype('float32')
        code_type = "DecodeCenterSize"
        box_normalized = False
        axis = 1
        output_box = batch_box_coder(prior_box, prior_box_var, target_box,
                                     lod[0], code_type, box_normalized, axis)

        self.inputs = {
            'PriorBox': prior_box,
            'TargetBox': target_box,
        }
        self.attrs = {
            'code_type': 'decode_center_size',
            'box_normalized': False,
            # Builtin `float` gives the same dtype as the old `np.float`
            # alias, which newer NumPy releases no longer provide.
            'variance': prior_box_var.astype(float).flatten(),
            'axis': axis
        }
        self.outputs = {'OutputBox': output_box}
# Run the test cases when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录