Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
a39240c3
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a39240c3
编写于
1月 25, 2019
作者:
J
jerrywgz
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add attr variance for box coder, test=develop
上级
6928f831
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
236 addition
and
53 deletion
+236
-53
paddle/fluid/operators/detection/box_coder_op.cc
paddle/fluid/operators/detection/box_coder_op.cc
+7
-0
paddle/fluid/operators/detection/box_coder_op.cu
paddle/fluid/operators/detection/box_coder_op.cu
+43
-16
paddle/fluid/operators/detection/box_coder_op.h
paddle/fluid/operators/detection/box_coder_op.h
+34
-4
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+108
-18
python/paddle/fluid/tests/test_detection.py
python/paddle/fluid/tests/test_detection.py
+1
-1
python/paddle/fluid/tests/unittests/test_box_coder_op.py
python/paddle/fluid/tests/unittests/test_box_coder_op.py
+43
-14
未找到文件。
paddle/fluid/operators/detection/box_coder_op.cc
浏览文件 @
a39240c3
...
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
...
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include <vector>
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -134,6 +135,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -134,6 +135,12 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"when code type is decode_center_size"
)
"when code type is decode_center_size"
)
.
SetDefault
(
0
)
.
SetDefault
(
0
)
.
InEnum
({
0
,
1
});
.
InEnum
({
0
,
1
});
AddAttr
<
std
::
vector
<
float
>>
(
"variance"
,
"(vector<float>, default {}),"
"variance of prior box with shape [4]. PriorBoxVar and variance can"
"not be provided at the same time."
)
.
SetDefault
(
std
::
vector
<
float
>
{});
AddOutput
(
"OutputBox"
,
AddOutput
(
"OutputBox"
,
"(LoDTensor or Tensor) "
"(LoDTensor or Tensor) "
"When code_type is 'encode_center_size', the output tensor of "
"When code_type is 'encode_center_size', the output tensor of "
...
...
paddle/fluid/operators/detection/box_coder_op.cu
浏览文件 @
a39240c3
...
@@ -9,6 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -9,6 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
@@ -16,12 +18,11 @@ namespace paddle {
...
@@ -16,12 +18,11 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
>
__global__
void
EncodeCenterSizeKernel
(
const
T
*
prior_box_data
,
__global__
void
EncodeCenterSizeKernel
(
const
T
*
prior_box_var_data
,
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
const
int
row
,
const
T
*
target_box_data
,
const
int
row
,
const
int
col
,
const
int
len
,
const
int
col
,
const
int
len
,
const
bool
normalized
,
const
T
prior_box_var_size
,
const
float
*
variance
,
const
bool
normalized
,
const
int
var_size
,
T
*
output
)
{
const
T
prior_box_var_size
,
T
*
output
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
row
*
col
)
{
if
(
idx
<
row
*
col
)
{
const
int
row_idx
=
idx
/
col
;
const
int
row_idx
=
idx
/
col
;
...
@@ -62,18 +63,20 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
...
@@ -62,18 +63,20 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
output
[
idx
*
len
+
1
]
/=
prior_box_var_data
[
prior_var_offset
+
1
];
output
[
idx
*
len
+
1
]
/=
prior_box_var_data
[
prior_var_offset
+
1
];
output
[
idx
*
len
+
2
]
/=
prior_box_var_data
[
prior_var_offset
+
2
];
output
[
idx
*
len
+
2
]
/=
prior_box_var_data
[
prior_var_offset
+
2
];
output
[
idx
*
len
+
3
]
/=
prior_box_var_data
[
prior_var_offset
+
3
];
output
[
idx
*
len
+
3
]
/=
prior_box_var_data
[
prior_var_offset
+
3
];
}
else
if
(
var_size
==
4
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
output
[
idx
*
len
+
k
]
/=
static_cast
<
T
>
(
variance
[
k
]);
}
}
}
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
DecodeCenterSizeKernel
(
const
T
*
prior_box_data
,
__global__
void
DecodeCenterSizeKernel
(
const
T
*
prior_box_var_data
,
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
const
int
row
,
const
T
*
target_box_data
,
const
int
row
,
const
int
col
,
const
int
len
,
const
int
col
,
const
int
len
,
const
bool
normalized
,
const
T
prior_box_var_size
,
const
float
*
variance
,
const
bool
normalized
,
const
int
var_size
,
const
int
axis
,
T
*
output
)
{
const
T
prior_box_var_size
,
const
int
axis
,
T
*
output
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
prior_box_offset
=
0
;
int
prior_box_offset
=
0
;
if
(
idx
<
row
*
col
)
{
if
(
idx
<
row
*
col
)
{
...
@@ -110,6 +113,20 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
...
@@ -110,6 +113,20 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
target_box_data
[
idx
*
len
+
1
]
*
target_box_data
[
idx
*
len
+
1
]
*
prior_box_height
+
prior_box_height
+
prior_box_center_y
;
prior_box_center_y
;
}
else
if
(
var_size
==
4
)
{
target_box_width
=
exp
(
static_cast
<
T
>
(
variance
[
2
])
*
target_box_data
[
idx
*
len
+
2
])
*
prior_box_width
;
target_box_height
=
exp
(
static_cast
<
T
>
(
variance
[
3
])
*
target_box_data
[
idx
*
len
+
3
])
*
prior_box_height
;
target_box_center_x
=
static_cast
<
T
>
(
variance
[
0
])
*
target_box_data
[
idx
*
len
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
static_cast
<
T
>
(
variance
[
1
])
*
target_box_data
[
idx
*
len
+
1
]
*
prior_box_height
+
prior_box_center_y
;
}
else
{
}
else
{
target_box_width
=
exp
(
target_box_data
[
idx
*
len
+
2
])
*
prior_box_width
;
target_box_width
=
exp
(
target_box_data
[
idx
*
len
+
2
])
*
prior_box_width
;
target_box_height
=
target_box_height
=
...
@@ -139,20 +156,30 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
...
@@ -139,20 +156,30 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
auto
*
prior_box_var
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBoxVar"
);
auto
*
prior_box_var
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBoxVar"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
output_box
=
context
.
Output
<
framework
::
Tensor
>
(
"OutputBox"
);
auto
*
output_box
=
context
.
Output
<
framework
::
Tensor
>
(
"OutputBox"
);
std
::
vector
<
float
>
variance
=
context
.
Attr
<
std
::
vector
<
float
>>
(
"variance"
);
const
T
*
prior_box_data
=
prior_box
->
data
<
T
>
();
const
T
*
prior_box_data
=
prior_box
->
data
<
T
>
();
const
T
*
target_box_data
=
target_box
->
data
<
T
>
();
const
T
*
target_box_data
=
target_box
->
data
<
T
>
();
const
T
*
prior_box_var_data
=
nullptr
;
const
T
*
prior_box_var_data
=
nullptr
;
auto
prior_box_var_size
=
0
;
auto
prior_box_var_size
=
0
;
if
(
prior_box_var
)
{
if
(
prior_box_var
)
{
PADDLE_ENFORCE
(
variance
.
empty
(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time."
);
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
prior_box_var_size
=
prior_box_var
->
dims
().
size
();
prior_box_var_size
=
prior_box_var
->
dims
().
size
();
}
}
if
(
!
(
variance
.
empty
()))
{
PADDLE_ENFORCE
(
static_cast
<
int
>
(
variance
.
size
())
==
4
,
"Size of attribute 'variance' should be 4"
);
}
if
(
target_box
->
lod
().
size
())
{
if
(
target_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1
,
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1
,
"Only support 1 level of LoD."
);
"Only support 1 level of LoD."
);
}
}
const
int
var_size
=
static_cast
<
T
>
(
variance
.
size
());
thrust
::
device_vector
<
float
>
dev_variance
(
variance
.
begin
(),
variance
.
end
());
const
float
*
dev_var_data
=
thrust
::
raw_pointer_cast
(
dev_variance
.
data
());
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
int
axis
=
context
.
Attr
<
int
>
(
"axis"
);
int
axis
=
context
.
Attr
<
int
>
(
"axis"
);
...
@@ -173,11 +200,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
...
@@ -173,11 +200,11 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSizeKernel
<
T
><<<
grid
,
block
,
0
,
device_ctx
.
stream
()
>>>
(
EncodeCenterSizeKernel
<
T
><<<
grid
,
block
,
0
,
device_ctx
.
stream
()
>>>
(
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
len
,
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
len
,
normalized
,
prior_box_var_size
,
output
);
normalized
,
prior_box_var_size
,
dev_var_data
,
var_size
,
output
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
DecodeCenterSizeKernel
<
T
><<<
grid
,
block
,
0
,
device_ctx
.
stream
()
>>>
(
DecodeCenterSizeKernel
<
T
><<<
grid
,
block
,
0
,
device_ctx
.
stream
()
>>>
(
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
len
,
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
len
,
normalized
,
prior_box_var_size
,
axis
,
output
);
normalized
,
prior_box_var_size
,
dev_var_data
,
var_size
,
axis
,
output
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/detection/box_coder_op.h
浏览文件 @
a39240c3
...
@@ -11,6 +11,7 @@ limitations under the License. */
...
@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
...
@@ -34,7 +35,8 @@ class BoxCoderKernel : public framework::OpKernel<T> {
...
@@ -34,7 +35,8 @@ class BoxCoderKernel : public framework::OpKernel<T> {
void
EncodeCenterSize
(
const
framework
::
Tensor
*
target_box
,
void
EncodeCenterSize
(
const
framework
::
Tensor
*
target_box
,
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
*
prior_box_var
,
const
framework
::
Tensor
*
prior_box_var
,
const
bool
normalized
,
T
*
output
)
const
{
const
bool
normalized
,
const
std
::
vector
<
float
>
variance
,
T
*
output
)
const
{
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
prior_box
->
dims
()[
0
];
int64_t
col
=
prior_box
->
dims
()[
0
];
int64_t
len
=
prior_box
->
dims
()[
1
];
int64_t
len
=
prior_box
->
dims
()[
1
];
...
@@ -85,6 +87,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
...
@@ -85,6 +87,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
output
[
offset
+
1
]
/=
prior_box_var_data
[
prior_var_offset
+
1
];
output
[
offset
+
1
]
/=
prior_box_var_data
[
prior_var_offset
+
1
];
output
[
offset
+
2
]
/=
prior_box_var_data
[
prior_var_offset
+
2
];
output
[
offset
+
2
]
/=
prior_box_var_data
[
prior_var_offset
+
2
];
output
[
offset
+
3
]
/=
prior_box_var_data
[
prior_var_offset
+
3
];
output
[
offset
+
3
]
/=
prior_box_var_data
[
prior_var_offset
+
3
];
}
else
if
(
!
(
variance
.
empty
()))
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
output
[
offset
+
k
]
/=
static_cast
<
T
>
(
variance
[
k
]);
}
}
}
}
}
}
}
...
@@ -93,7 +99,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
...
@@ -93,7 +99,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
*
prior_box_var
,
const
framework
::
Tensor
*
prior_box_var
,
const
bool
normalized
,
const
int
axis
,
const
bool
normalized
,
const
int
axis
,
T
*
output
)
const
{
const
std
::
vector
<
float
>
variance
,
T
*
output
)
const
{
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
target_box
->
dims
()[
1
];
int64_t
col
=
target_box
->
dims
()[
1
];
int64_t
len
=
target_box
->
dims
()[
2
];
int64_t
len
=
target_box
->
dims
()[
2
];
...
@@ -149,6 +155,20 @@ class BoxCoderKernel : public framework::OpKernel<T> {
...
@@ -149,6 +155,20 @@ class BoxCoderKernel : public framework::OpKernel<T> {
std
::
exp
(
prior_box_var_data
[
prior_var_offset
+
3
]
*
std
::
exp
(
prior_box_var_data
[
prior_var_offset
+
3
]
*
target_box_data
[
offset
+
3
])
*
target_box_data
[
offset
+
3
])
*
prior_box_height
;
prior_box_height
;
}
else
if
(
!
(
variance
.
empty
()))
{
target_box_center_x
=
static_cast
<
T
>
(
variance
[
0
])
*
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
static_cast
<
T
>
(
variance
[
1
])
*
target_box_data
[
offset
+
1
]
*
prior_box_height
+
prior_box_center_y
;
target_box_width
=
std
::
exp
(
static_cast
<
T
>
(
variance
[
2
])
*
target_box_data
[
offset
+
2
])
*
prior_box_width
;
target_box_height
=
std
::
exp
(
static_cast
<
T
>
(
variance
[
3
])
*
target_box_data
[
offset
+
3
])
*
prior_box_height
;
}
else
{
}
else
{
target_box_center_x
=
target_box_center_x
=
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
...
@@ -175,11 +195,21 @@ class BoxCoderKernel : public framework::OpKernel<T> {
...
@@ -175,11 +195,21 @@ class BoxCoderKernel : public framework::OpKernel<T> {
auto
*
prior_box_var
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBoxVar"
);
auto
*
prior_box_var
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBoxVar"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
output_box
=
context
.
Output
<
framework
::
Tensor
>
(
"OutputBox"
);
auto
*
output_box
=
context
.
Output
<
framework
::
Tensor
>
(
"OutputBox"
);
std
::
vector
<
float
>
variance
=
context
.
Attr
<
std
::
vector
<
float
>>
(
"variance"
);
const
int
axis
=
context
.
Attr
<
int
>
(
"axis"
);
const
int
axis
=
context
.
Attr
<
int
>
(
"axis"
);
if
(
target_box
->
lod
().
size
())
{
if
(
target_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1UL
,
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1UL
,
"Only support 1 level of LoD."
);
"Only support 1 level of LoD."
);
}
}
if
(
prior_box_var
)
{
PADDLE_ENFORCE
(
variance
.
empty
(),
"Input 'PriorBoxVar' and attribute 'variance' should not"
"be used at the same time."
);
}
if
(
!
(
variance
.
empty
()))
{
PADDLE_ENFORCE
(
static_cast
<
int
>
(
variance
.
size
())
==
4
,
"Size of attribute 'variance' should be 4"
);
}
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
...
@@ -195,10 +225,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
...
@@ -195,10 +225,10 @@ class BoxCoderKernel : public framework::OpKernel<T> {
T
*
output
=
output_box
->
data
<
T
>
();
T
*
output
=
output_box
->
data
<
T
>
();
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
EncodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
output
);
variance
,
output
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
DecodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
axis
,
DecodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
axis
,
output
);
variance
,
output
);
}
}
}
}
};
};
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
a39240c3
...
@@ -346,18 +346,104 @@ def box_coder(prior_box,
...
@@ -346,18 +346,104 @@ def box_coder(prior_box,
name
=
None
,
name
=
None
,
axis
=
0
):
axis
=
0
):
"""
"""
${comment}
**Box Coder Layer**
Encode/Decode the target bounding box with the priorbox information.
The Encoding schema described below:
.. math::
ox = (tx - px) / pw / pxv
oy = (ty - py) / ph / pyv
ow = \log(
\a
bs(tw / pw)) / pwv
oh = \log(
\a
bs(th / ph)) / phv
The Decoding schema described below:
.. math::
ox = (pw * pxv * tx * + px) - tw / 2
oy = (ph * pyv * ty * + py) - th / 2
ow = \exp(pwv * tw) * pw + tw / 2
oh = \exp(phv * th) * ph + th / 2
where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates,
width and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote
the priorbox's (anchor) center coordinates, width and height. `pxv`,
`pyv`, `pwv`, `phv` denote the variance of the priorbox and `ox`, `oy`,
`ow`, `oh` denote the encoded/decoded coordinates, width and height.
During Box Decoding, two modes for broadcast are supported. Say target
box has shape [N, M, 4], and the shape of prior box can be [N, 4] or
[M, 4]. Then prior box will broadcast to target box along the
assigned axis.
Args:
Args:
prior_box(${prior_box_type}): ${prior_box_comment}
prior_box(Variable): Box list prior_box is a 2-D Tensor with shape
prior_box_var(${prior_box_var_type}): ${prior_box_var_comment}
[M, 4] holds M boxes, each box is represented as
target_box(${target_box_type}): ${target_box_comment}
[xmin, ymin, xmax, ymax], [xmin, ymin] is the
code_type(${code_type_type}): ${code_type_comment}
left top coordinate of the anchor box, if the
box_normalized(${box_normalized_type}): ${box_normalized_comment}
input is image feature map, they are close to
axis(${axis_type}): ${axis_comment}
the origin of the coordinate system. [xmax, ymax]
is the right bottom coordinate of the anchor box.
prior_box_var(Variable|list): prior_box_var supports two types of input.
One is variable with shape [M, 4] holds M group.
The other one is list consist of 4 elements
shared by all boxes.
target_box(Variable): This input can be a 2-D LoDTensor with shape
[N, 4] when code_type is 'encode_center_size'.
This input also can be a 3-D Tensor with shape
[N, M, 4] when code_type is 'decode_center_size'.
Each box is represented as
[xmin, ymin, xmax, ymax]. This tensor can
contain LoD information to represent a batch
of inputs.
code_type(string): The code type used with the target box. It can be
encode_center_size or decode_center_size
box_normalized(int): Whether treat the priorbox as a noramlized box.
Set true by default.
name(string): The name of box coder.
axis(int): Which axis in PriorBox to broadcast for box decode,
for example, if axis is 0 and TargetBox has shape
[N, M, 4] and PriorBox has shape [M, 4], then PriorBox
will broadcast to [N, M, 4] for decoding. It is only valid
when code type is decode_center_size. Set 0 by default.
Returns:
Returns:
output_box(${output_box_type}): ${output_box_comment}
output_box(Variable): When code_type is 'encode_center_size', the
output tensor of box_coder_op with shape
[N, M, 4] representing the result of N target
boxes encoded with M Prior boxes and variances.
When code_type is 'decode_center_size',
N represents the batch size and M represents
the number of deocded boxes.
Examples:
.. code-block:: python
prior_box = fluid.layers.data(name='prior_box',
shape=[512, 4],
dtype='float32',
append_batch_size=False)
target_box = fluid.layers.data(name='target_box',
shape=[512,81,4],
dtype='float32',
append_batch_size=False)
output = fluid.layers.box_coder(prior_box=prior_box,
prior_box_var=[0.1,0.1,0.2,0.2],
target_box=target_box,
code_type="decode_center_size",
box_normalized=False,
axis=1)
"""
"""
helper
=
LayerHelper
(
"box_coder"
,
**
locals
())
helper
=
LayerHelper
(
"box_coder"
,
**
locals
())
...
@@ -368,18 +454,22 @@ def box_coder(prior_box,
...
@@ -368,18 +454,22 @@ def box_coder(prior_box,
output_box
=
helper
.
create_variable
(
output_box
=
helper
.
create_variable
(
name
=
name
,
dtype
=
prior_box
.
dtype
,
persistable
=
False
)
name
=
name
,
dtype
=
prior_box
.
dtype
,
persistable
=
False
)
helper
.
append_op
(
inputs
=
{
"PriorBox"
:
prior_box
,
"TargetBox"
:
target_box
}
type
=
"box_coder"
,
attrs
=
{
inputs
=
{
"PriorBox"
:
prior_box
,
"PriorBoxVar"
:
prior_box_var
,
"TargetBox"
:
target_box
},
attrs
=
{
"code_type"
:
code_type
,
"code_type"
:
code_type
,
"box_normalized"
:
box_normalized
,
"box_normalized"
:
box_normalized
,
"axis"
:
axis
"axis"
:
axis
},
}
if
isinstance
(
prior_box_var
,
Variable
):
inputs
[
'PriorBoxVar'
]
=
prior_box_var
elif
isinstance
(
prior_box_var
,
list
):
attrs
[
'variance'
]
=
prior_box_var
else
:
raise
TypeError
(
"Input variance of box_coder must be Variable or lisz"
)
helper
.
append_op
(
type
=
"box_coder"
,
inputs
=
inputs
,
attrs
=
attrs
,
outputs
=
{
"OutputBox"
:
output_box
})
outputs
=
{
"OutputBox"
:
output_box
})
return
output_box
return
output_box
...
...
python/paddle/fluid/tests/test_detection.py
浏览文件 @
a39240c3
...
@@ -59,7 +59,7 @@ class TestDetection(unittest.TestCase):
...
@@ -59,7 +59,7 @@ class TestDetection(unittest.TestCase):
iou
=
layers
.
iou_similarity
(
x
=
x
,
y
=
y
)
iou
=
layers
.
iou_similarity
(
x
=
x
,
y
=
y
)
bcoder
=
layers
.
box_coder
(
bcoder
=
layers
.
box_coder
(
prior_box
=
x
,
prior_box
=
x
,
prior_box_var
=
y
,
prior_box_var
=
[
0.2
,
0.3
,
0.3
,
0.2
]
,
target_box
=
z
,
target_box
=
z
,
code_type
=
'encode_center_size'
)
code_type
=
'encode_center_size'
)
self
.
assertIsNotNone
(
iou
)
self
.
assertIsNotNone
(
iou
)
...
...
python/paddle/fluid/tests/unittests/test_box_coder_op.py
浏览文件 @
a39240c3
...
@@ -106,9 +106,9 @@ class TestBoxCoderOp(OpTest):
...
@@ -106,9 +106,9 @@ class TestBoxCoderOp(OpTest):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
self
.
op_type
=
"box_coder"
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
prior_box
=
np
.
random
.
random
((
10
,
4
)).
astype
(
'float32'
)
prior_box
=
np
.
random
.
random
((
81
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
10
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
81
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
5
,
10
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
20
,
81
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
code_type
=
"DecodeCenterSize"
box_normalized
=
False
box_normalized
=
False
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
...
@@ -132,9 +132,9 @@ class TestBoxCoderOpWithOneRankVar(OpTest):
...
@@ -132,9 +132,9 @@ class TestBoxCoderOpWithOneRankVar(OpTest):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
self
.
op_type
=
"box_coder"
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
prior_box
=
np
.
random
.
random
((
6
,
4
)).
astype
(
'float32'
)
prior_box
=
np
.
random
.
random
((
81
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
3
,
6
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
20
,
81
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
code_type
=
"DecodeCenterSize"
box_normalized
=
False
box_normalized
=
False
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
...
@@ -159,9 +159,9 @@ class TestBoxCoderOpWithoutBoxVar(OpTest):
...
@@ -159,9 +159,9 @@ class TestBoxCoderOpWithoutBoxVar(OpTest):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
self
.
op_type
=
"box_coder"
lod
=
[[
0
,
1
,
2
,
3
,
4
,
5
]]
lod
=
[[
0
,
1
,
2
,
3
,
4
,
5
]]
prior_box
=
np
.
random
.
random
((
10
,
4
)).
astype
(
'float32'
)
prior_box
=
np
.
random
.
random
((
81
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
ones
((
10
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
ones
((
81
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
5
,
10
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
20
,
81
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
code_type
=
"DecodeCenterSize"
box_normalized
=
False
box_normalized
=
False
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
...
@@ -184,10 +184,10 @@ class TestBoxCoderOpWithLoD(OpTest):
...
@@ -184,10 +184,10 @@ class TestBoxCoderOpWithLoD(OpTest):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
self
.
op_type
=
"box_coder"
lod
=
[[
4
,
8
,
8
]]
lod
=
[[
10
,
20
,
20
]]
prior_box
=
np
.
random
.
random
((
1
0
,
4
)).
astype
(
'float32'
)
prior_box
=
np
.
random
.
random
((
2
0
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
1
0
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
2
0
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
2
0
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
5
0
,
4
)).
astype
(
'float32'
)
code_type
=
"EncodeCenterSize"
code_type
=
"EncodeCenterSize"
box_normalized
=
True
box_normalized
=
True
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
...
@@ -209,9 +209,9 @@ class TestBoxCoderOpWithAxis(OpTest):
...
@@ -209,9 +209,9 @@ class TestBoxCoderOpWithAxis(OpTest):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
self
.
op_type
=
"box_coder"
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
prior_box
=
np
.
random
.
random
((
5
,
4
)).
astype
(
'float32'
)
prior_box
=
np
.
random
.
random
((
30
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
5
,
6
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
30
,
81
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
code_type
=
"DecodeCenterSize"
box_normalized
=
False
box_normalized
=
False
axis
=
1
axis
=
1
...
@@ -231,5 +231,34 @@ class TestBoxCoderOpWithAxis(OpTest):
...
@@ -231,5 +231,34 @@ class TestBoxCoderOpWithAxis(OpTest):
self
.
outputs
=
{
'OutputBox'
:
output_box
}
self
.
outputs
=
{
'OutputBox'
:
output_box
}
class
TestBoxCoderOpWithVariance
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
prior_box
=
np
.
random
.
random
((
30
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
30
,
81
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
box_normalized
=
False
axis
=
1
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
lod
[
0
],
code_type
,
box_normalized
,
axis
)
self
.
inputs
=
{
'PriorBox'
:
prior_box
,
'TargetBox'
:
target_box
,
}
self
.
attrs
=
{
'code_type'
:
'decode_center_size'
,
'box_normalized'
:
False
,
'variance'
:
prior_box_var
.
astype
(
np
.
float
).
flatten
(),
'axis'
:
axis
}
self
.
outputs
=
{
'OutputBox'
:
output_box
}
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录