Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d92b2f2d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d92b2f2d
编写于
7月 30, 2022
作者:
Z
zhiboniu
提交者:
GitHub
7月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Phi prior box (#44431)
* phi_prior_box * add float[] support * phi_prior_box_optest * update
上级
46be6854
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
667 addition
and
101 deletion
+667
-101
paddle/fluid/operators/detection/prior_box_op.cc
paddle/fluid/operators/detection/prior_box_op.cc
+8
-79
paddle/fluid/operators/detection/prior_box_op.cu
paddle/fluid/operators/detection/prior_box_op.cu
+0
-5
paddle/phi/api/yaml/generator/api_base.py
paddle/phi/api/yaml/generator/api_base.py
+1
-1
paddle/phi/api/yaml/legacy_api.yaml
paddle/phi/api/yaml/legacy_api.yaml
+8
-0
paddle/phi/infermeta/binary.cc
paddle/phi/infermeta/binary.cc
+104
-0
paddle/phi/infermeta/binary.h
paddle/phi/infermeta/binary.h
+15
-0
paddle/phi/kernels/cpu/prior_box_kernel.cc
paddle/phi/kernels/cpu/prior_box_kernel.cc
+173
-0
paddle/phi/kernels/gpu/prior_box_kernel.cu
paddle/phi/kernels/gpu/prior_box_kernel.cu
+201
-0
paddle/phi/kernels/prior_box_kernel.h
paddle/phi/kernels/prior_box_kernel.h
+62
-0
paddle/phi/ops/compat/prior_box_sig.cc
paddle/phi/ops/compat/prior_box_sig.cc
+37
-0
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+24
-13
python/paddle/fluid/tests/unittests/test_prior_box_op.py
python/paddle/fluid/tests/unittests/test_prior_box_op.py
+34
-3
未找到文件。
paddle/fluid/operators/detection/prior_box_op.cc
浏览文件 @
d92b2f2d
...
...
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/prior_box_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/binary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
...
...
@@ -28,79 +29,6 @@ class PriorBoxOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Input"
),
"Input"
,
"Input"
,
"PriorBoxOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Image"
),
"Input"
,
"Image"
,
"PriorBoxOp"
);
auto
image_dims
=
ctx
->
GetInputDim
(
"Image"
);
auto
input_dims
=
ctx
->
GetInputDim
(
"Input"
);
PADDLE_ENFORCE_EQ
(
image_dims
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"The Input(Image) of Op(PriorBoxOp) should be a 4-D Tensor "
"and data format is NCHW. But received Image's dimensions = %d, "
"shape = [%s]."
,
image_dims
.
size
(),
image_dims
));
PADDLE_ENFORCE_EQ
(
input_dims
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"The Input(Input) of Op(PriorBoxOp) should be a 4-D Tensor "
"and data format is NCHW. But received Input's dimensions = %d, "
"shape = [%s]."
,
input_dims
.
size
(),
input_dims
));
auto
min_sizes
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
float
>>
(
"min_sizes"
);
auto
max_sizes
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
float
>>
(
"max_sizes"
);
auto
variances
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
float
>>
(
"variances"
);
auto
aspect_ratios
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
float
>>
(
"aspect_ratios"
);
bool
flip
=
ctx
->
Attrs
().
Get
<
bool
>
(
"flip"
);
std
::
vector
<
float
>
aspect_ratios_vec
;
ExpandAspectRatios
(
aspect_ratios
,
flip
,
&
aspect_ratios_vec
);
size_t
num_priors
=
aspect_ratios_vec
.
size
()
*
min_sizes
.
size
();
if
(
max_sizes
.
size
()
>
0
)
{
PADDLE_ENFORCE_EQ
(
max_sizes
.
size
(),
min_sizes
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The length of min_size and "
"max_size must be equal. But received: min_size's length is %d, "
"max_size's length is %d."
,
min_sizes
.
size
(),
max_sizes
.
size
()));
num_priors
+=
max_sizes
.
size
();
for
(
size_t
i
=
0
;
i
<
max_sizes
.
size
();
++
i
)
{
PADDLE_ENFORCE_GT
(
max_sizes
[
i
],
min_sizes
[
i
],
platform
::
errors
::
InvalidArgument
(
"max_size[%d] must be greater "
"than min_size[%d]. But received: max_size[%d] is %f, "
"min_size[%d] is %f."
,
i
,
i
,
i
,
max_sizes
[
i
],
i
,
min_sizes
[
i
]));
}
}
std
::
vector
<
int64_t
>
dim_vec
(
4
);
dim_vec
[
0
]
=
input_dims
[
2
];
dim_vec
[
1
]
=
input_dims
[
3
];
dim_vec
[
2
]
=
num_priors
;
dim_vec
[
3
]
=
4
;
ctx
->
SetOutputDim
(
"Boxes"
,
phi
::
make_ddim
(
dim_vec
));
ctx
->
SetOutputDim
(
"Variances"
,
phi
::
make_ddim
(
dim_vec
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -274,17 +202,18 @@ https://arxiv.org/abs/1512.02325.
}
// namespace operators
}
// namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR
(
prior_box
,
PriorBoxInferShapeFunctor
,
PD_INFER_META
(
phi
::
PriorBoxInferMeta
));
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
prior_box
,
ops
::
PriorBoxOp
,
ops
::
PriorBoxOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
prior_box
,
ops
::
PriorBoxOpKernel
<
float
,
float
>
,
ops
::
PriorBoxOpKernel
<
double
,
double
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
PriorBoxInferShapeFunctor
);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
prior_box
,
MKLDNN
,
...
...
paddle/fluid/operators/detection/prior_box_op.cu
浏览文件 @
d92b2f2d
...
...
@@ -194,8 +194,3 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel<T> {
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
prior_box
,
ops
::
PriorBoxOpCUDAKernel
<
float
>
,
ops
::
PriorBoxOpCUDAKernel
<
double
>
);
paddle/phi/api/yaml/generator/api_base.py
浏览文件 @
d92b2f2d
...
...
@@ -141,7 +141,7 @@ class BaseAPI(object):
'DataLayout'
:
'DataLayout'
,
'DataType'
:
'DataType'
,
'int64_t[]'
:
'const std::vector<int64_t>&'
,
'int[]'
:
'const std::vector<int>&'
'int[]'
:
'const std::vector<int>&'
,
}
optional_types_trans
=
{
'Tensor'
:
'const paddle::optional<Tensor>&'
,
...
...
paddle/phi/api/yaml/legacy_api.yaml
浏览文件 @
d92b2f2d
...
...
@@ -1791,6 +1791,14 @@
func
:
prelu
backward
:
prelu_grad
-
api
:
prior_box
args
:
(Tensor input, Tensor image, float[] min_sizes, float[] aspect_ratios, float[] variances, float[] max_sizes = {}, bool flip=true, bool clip=true, float step_w=0.0, float step_h=0.0, float offset=0.5, bool min_max_aspect_ratios_order=false)
output
:
Tensor(out), Tensor(var)
infer_meta
:
func
:
PriorBoxInferMeta
kernel
:
func
:
prior_box
-
api
:
psroi_pool
args
:
(Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, int output_channels, float spatial_scale)
output
:
Tensor
...
...
paddle/phi/infermeta/binary.cc
浏览文件 @
d92b2f2d
...
...
@@ -1809,6 +1809,110 @@ void PReluInferMeta(const MetaTensor& x,
out
->
share_lod
(
x
);
}
inline
void
ExpandAspectRatios
(
const
std
::
vector
<
float
>&
input_aspect_ratior
,
bool
flip
,
std
::
vector
<
float
>*
output_aspect_ratior
)
{
constexpr
float
epsilon
=
1e-6
;
output_aspect_ratior
->
clear
();
output_aspect_ratior
->
push_back
(
1.0
f
);
for
(
size_t
i
=
0
;
i
<
input_aspect_ratior
.
size
();
++
i
)
{
float
ar
=
input_aspect_ratior
[
i
];
bool
already_exist
=
false
;
for
(
size_t
j
=
0
;
j
<
output_aspect_ratior
->
size
();
++
j
)
{
if
(
fabs
(
ar
-
output_aspect_ratior
->
at
(
j
))
<
epsilon
)
{
already_exist
=
true
;
break
;
}
}
if
(
!
already_exist
)
{
output_aspect_ratior
->
push_back
(
ar
);
if
(
flip
)
{
output_aspect_ratior
->
push_back
(
1.0
f
/
ar
);
}
}
}
}
void
PriorBoxInferMeta
(
const
MetaTensor
&
input
,
const
MetaTensor
&
image
,
const
std
::
vector
<
float
>&
min_sizes
,
const
std
::
vector
<
float
>&
aspect_ratios
,
const
std
::
vector
<
float
>&
variances
,
const
std
::
vector
<
float
>&
max_sizes
,
bool
flip
,
bool
clip
,
float
step_w
,
float
step_h
,
float
offset
,
bool
min_max_aspect_ratios_order
,
MetaTensor
*
out
,
MetaTensor
*
var
)
{
auto
image_dims
=
image
.
dims
();
auto
input_dims
=
input
.
dims
();
PADDLE_ENFORCE_EQ
(
image_dims
.
size
(),
4
,
phi
::
errors
::
InvalidArgument
(
"The Input(Image) of Op(PriorBoxOp) should be a 4-D Tensor "
"and data format is NCHW. But received Image's dimensions = %d, "
"shape = [%s]."
,
image_dims
.
size
(),
image_dims
));
PADDLE_ENFORCE_EQ
(
input_dims
.
size
(),
4
,
phi
::
errors
::
InvalidArgument
(
"The Input(Input) of Op(PriorBoxOp) should be a 4-D Tensor "
"and data format is NCHW. But received Input's dimensions = %d, "
"shape = [%s]."
,
input_dims
.
size
(),
input_dims
));
std
::
vector
<
float
>
aspect_ratios_vec
;
ExpandAspectRatios
(
aspect_ratios
,
flip
,
&
aspect_ratios_vec
);
size_t
num_priors
=
aspect_ratios_vec
.
size
()
*
min_sizes
.
size
();
if
(
max_sizes
.
size
()
>
0
)
{
PADDLE_ENFORCE_EQ
(
max_sizes
.
size
(),
min_sizes
.
size
(),
phi
::
errors
::
InvalidArgument
(
"The length of min_size and "
"max_size must be equal. But received: min_size's length is %d, "
"max_size's length is %d."
,
min_sizes
.
size
(),
max_sizes
.
size
()));
num_priors
+=
max_sizes
.
size
();
for
(
size_t
i
=
0
;
i
<
max_sizes
.
size
();
++
i
)
{
PADDLE_ENFORCE_GT
(
max_sizes
[
i
],
min_sizes
[
i
],
phi
::
errors
::
InvalidArgument
(
"max_size[%d] must be greater "
"than min_size[%d]. But received: max_size[%d] is %f, "
"min_size[%d] is %f."
,
i
,
i
,
i
,
max_sizes
[
i
],
i
,
min_sizes
[
i
]));
}
}
std
::
vector
<
int64_t
>
dim_vec
(
4
);
dim_vec
[
0
]
=
input_dims
[
2
];
dim_vec
[
1
]
=
input_dims
[
3
];
dim_vec
[
2
]
=
num_priors
;
dim_vec
[
3
]
=
4
;
out
->
set_dtype
(
input
.
dtype
());
var
->
set_dtype
(
input
.
dtype
());
out
->
set_dims
(
phi
::
make_ddim
(
dim_vec
));
var
->
set_dims
(
phi
::
make_ddim
(
dim_vec
));
}
void
SearchsortedInferMeta
(
const
MetaTensor
&
sorted_sequence
,
const
MetaTensor
&
value
,
bool
out_int32
,
...
...
paddle/phi/infermeta/binary.h
浏览文件 @
d92b2f2d
...
...
@@ -256,6 +256,21 @@ void PReluInferMeta(const MetaTensor& x,
MetaTensor
*
out
,
MetaConfig
config
=
MetaConfig
());
void
PriorBoxInferMeta
(
const
MetaTensor
&
input
,
const
MetaTensor
&
image
,
const
std
::
vector
<
float
>&
min_sizes
,
const
std
::
vector
<
float
>&
aspect_ratios
,
const
std
::
vector
<
float
>&
variances
,
const
std
::
vector
<
float
>&
max_sizes
,
bool
flip
,
bool
clip
,
float
step_w
,
float
step_h
,
float
offset
,
bool
min_max_aspect_ratios_order
,
MetaTensor
*
out
,
MetaTensor
*
var
);
void
SearchsortedInferMeta
(
const
MetaTensor
&
sorted_sequence
,
const
MetaTensor
&
value
,
bool
out_int32
,
...
...
paddle/phi/kernels/cpu/prior_box_kernel.cc
0 → 100644
浏览文件 @
d92b2f2d
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/prior_box_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
PriorBoxKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
image
,
const
std
::
vector
<
float
>&
min_sizes
,
const
std
::
vector
<
float
>&
aspect_ratios
,
const
std
::
vector
<
float
>&
variances
,
const
std
::
vector
<
float
>&
max_sizes
,
bool
flip
,
bool
clip
,
float
step_w
,
float
step_h
,
float
offset
,
bool
min_max_aspect_ratios_order
,
DenseTensor
*
out
,
DenseTensor
*
var
)
{
std
::
vector
<
float
>
new_aspect_ratios
;
ExpandAspectRatios
(
aspect_ratios
,
flip
,
&
new_aspect_ratios
);
T
new_step_w
=
static_cast
<
T
>
(
step_w
);
T
new_step_h
=
static_cast
<
T
>
(
step_h
);
T
new_offset
=
static_cast
<
T
>
(
offset
);
auto
img_width
=
image
.
dims
()[
3
];
auto
img_height
=
image
.
dims
()[
2
];
auto
feature_width
=
input
.
dims
()[
3
];
auto
feature_height
=
input
.
dims
()[
2
];
T
step_width
,
step_height
;
if
(
new_step_w
==
0
||
new_step_h
==
0
)
{
step_width
=
static_cast
<
T
>
(
img_width
)
/
feature_width
;
step_height
=
static_cast
<
T
>
(
img_height
)
/
feature_height
;
}
else
{
step_width
=
new_step_w
;
step_height
=
new_step_h
;
}
int
num_priors
=
new_aspect_ratios
.
size
()
*
min_sizes
.
size
();
if
(
max_sizes
.
size
()
>
0
)
{
num_priors
+=
max_sizes
.
size
();
}
ctx
.
template
Alloc
<
T
>(
out
);
ctx
.
template
Alloc
<
T
>(
var
);
T
*
b_t
=
out
->
data
<
T
>
();
for
(
int
h
=
0
;
h
<
feature_height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
feature_width
;
++
w
)
{
T
center_x
=
(
w
+
new_offset
)
*
step_width
;
T
center_y
=
(
h
+
new_offset
)
*
step_height
;
T
box_width
,
box_height
;
for
(
size_t
s
=
0
;
s
<
min_sizes
.
size
();
++
s
)
{
auto
min_size
=
min_sizes
[
s
];
if
(
min_max_aspect_ratios_order
)
{
box_width
=
box_height
=
min_size
/
2.
;
b_t
[
0
]
=
(
center_x
-
box_width
)
/
img_width
;
b_t
[
1
]
=
(
center_y
-
box_height
)
/
img_height
;
b_t
[
2
]
=
(
center_x
+
box_width
)
/
img_width
;
b_t
[
3
]
=
(
center_y
+
box_height
)
/
img_height
;
b_t
+=
4
;
if
(
max_sizes
.
size
()
>
0
)
{
auto
max_size
=
max_sizes
[
s
];
// square prior with size sqrt(minSize * maxSize)
box_width
=
box_height
=
sqrt
(
min_size
*
max_size
)
/
2.
;
b_t
[
0
]
=
(
center_x
-
box_width
)
/
img_width
;
b_t
[
1
]
=
(
center_y
-
box_height
)
/
img_height
;
b_t
[
2
]
=
(
center_x
+
box_width
)
/
img_width
;
b_t
[
3
]
=
(
center_y
+
box_height
)
/
img_height
;
b_t
+=
4
;
}
// priors with different aspect ratios
for
(
size_t
r
=
0
;
r
<
new_aspect_ratios
.
size
();
++
r
)
{
float
ar
=
new_aspect_ratios
[
r
];
if
(
fabs
(
ar
-
1.
)
<
1e-6
)
{
continue
;
}
box_width
=
min_size
*
sqrt
(
ar
)
/
2.
;
box_height
=
min_size
/
sqrt
(
ar
)
/
2.
;
b_t
[
0
]
=
(
center_x
-
box_width
)
/
img_width
;
b_t
[
1
]
=
(
center_y
-
box_height
)
/
img_height
;
b_t
[
2
]
=
(
center_x
+
box_width
)
/
img_width
;
b_t
[
3
]
=
(
center_y
+
box_height
)
/
img_height
;
b_t
+=
4
;
}
}
else
{
// priors with different aspect ratios
for
(
size_t
r
=
0
;
r
<
new_aspect_ratios
.
size
();
++
r
)
{
float
ar
=
new_aspect_ratios
[
r
];
box_width
=
min_size
*
sqrt
(
ar
)
/
2.
;
box_height
=
min_size
/
sqrt
(
ar
)
/
2.
;
b_t
[
0
]
=
(
center_x
-
box_width
)
/
img_width
;
b_t
[
1
]
=
(
center_y
-
box_height
)
/
img_height
;
b_t
[
2
]
=
(
center_x
+
box_width
)
/
img_width
;
b_t
[
3
]
=
(
center_y
+
box_height
)
/
img_height
;
b_t
+=
4
;
}
if
(
max_sizes
.
size
()
>
0
)
{
auto
max_size
=
max_sizes
[
s
];
// square prior with size sqrt(minSize * maxSize)
box_width
=
box_height
=
sqrt
(
min_size
*
max_size
)
/
2.
;
b_t
[
0
]
=
(
center_x
-
box_width
)
/
img_width
;
b_t
[
1
]
=
(
center_y
-
box_height
)
/
img_height
;
b_t
[
2
]
=
(
center_x
+
box_width
)
/
img_width
;
b_t
[
3
]
=
(
center_y
+
box_height
)
/
img_height
;
b_t
+=
4
;
}
}
}
}
}
if
(
clip
)
{
T
*
dt
=
out
->
data
<
T
>
();
std
::
transform
(
dt
,
dt
+
out
->
numel
(),
dt
,
[](
T
v
)
->
T
{
return
std
::
min
<
T
>
(
std
::
max
<
T
>
(
v
,
0.
),
1.
);
});
}
DenseTensor
var_t
;
var_t
.
Resize
(
phi
::
make_ddim
({
1
,
static_cast
<
int
>
(
variances
.
size
())}));
ctx
.
template
Alloc
<
T
>(
&
var_t
);
auto
var_et
=
EigenTensor
<
T
,
2
>::
From
(
var_t
);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for
(
size_t
i
=
0
;
i
<
variances
.
size
();
++
i
)
{
var_et
(
0
,
i
)
=
variances
[
i
];
}
int
box_num
=
feature_height
*
feature_width
*
num_priors
;
auto
var_dim
=
var
->
dims
();
var
->
Resize
({
box_num
,
static_cast
<
int
>
(
variances
.
size
())});
auto
e_vars
=
EigenMatrix
<
T
,
Eigen
::
RowMajor
>::
From
(
*
var
);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
for
(
int
i
=
0
;
i
<
box_num
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
variances
.
size
();
++
j
)
{
e_vars
(
i
,
j
)
=
variances
[
j
];
}
}
var
->
Resize
(
var_dim
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
prior_box
,
CPU
,
ALL_LAYOUT
,
phi
::
PriorBoxKernel
,
float
,
double
)
{}
paddle/phi/kernels/gpu/prior_box_kernel.cu
0 → 100644
浏览文件 @
d92b2f2d
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/prior_box_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
>
__device__
inline
T
clip
(
T
in
)
{
return
min
(
max
(
in
,
0.
),
1.
);
}
template
<
typename
T
>
__global__
void
GenPriorBox
(
T
*
out
,
const
T
*
aspect_ratios
,
const
int
height
,
const
int
width
,
const
int
im_height
,
const
int
im_width
,
const
int
as_num
,
const
T
offset
,
const
T
step_width
,
const
T
step_height
,
const
T
*
min_sizes
,
const
T
*
max_sizes
,
const
int
min_num
,
bool
is_clip
,
bool
min_max_aspect_ratios_order
)
{
int
num_priors
=
max_sizes
?
as_num
*
min_num
+
min_num
:
as_num
*
min_num
;
int
box_num
=
height
*
width
*
num_priors
;
CUDA_KERNEL_LOOP
(
i
,
box_num
)
{
int
h
=
i
/
(
num_priors
*
width
);
int
w
=
(
i
/
num_priors
)
%
width
;
int
p
=
i
%
num_priors
;
int
m
=
max_sizes
?
p
/
(
as_num
+
1
)
:
p
/
as_num
;
T
cx
=
(
w
+
offset
)
*
step_width
;
T
cy
=
(
h
+
offset
)
*
step_height
;
T
bw
,
bh
;
T
min_size
=
min_sizes
[
m
];
if
(
max_sizes
)
{
int
s
=
p
%
(
as_num
+
1
);
if
(
!
min_max_aspect_ratios_order
)
{
if
(
s
<
as_num
)
{
T
ar
=
aspect_ratios
[
s
];
bw
=
min_size
*
sqrt
(
ar
)
/
2.
;
bh
=
min_size
/
sqrt
(
ar
)
/
2.
;
}
else
{
T
max_size
=
max_sizes
[
m
];
bw
=
sqrt
(
min_size
*
max_size
)
/
2.
;
bh
=
bw
;
}
}
else
{
if
(
s
==
0
)
{
bw
=
bh
=
min_size
/
2.
;
}
else
if
(
s
==
1
)
{
T
max_size
=
max_sizes
[
m
];
bw
=
sqrt
(
min_size
*
max_size
)
/
2.
;
bh
=
bw
;
}
else
{
T
ar
=
aspect_ratios
[
s
-
1
];
bw
=
min_size
*
sqrt
(
ar
)
/
2.
;
bh
=
min_size
/
sqrt
(
ar
)
/
2.
;
}
}
}
else
{
int
s
=
p
%
as_num
;
T
ar
=
aspect_ratios
[
s
];
bw
=
min_size
*
sqrt
(
ar
)
/
2.
;
bh
=
min_size
/
sqrt
(
ar
)
/
2.
;
}
T
xmin
=
(
cx
-
bw
)
/
im_width
;
T
ymin
=
(
cy
-
bh
)
/
im_height
;
T
xmax
=
(
cx
+
bw
)
/
im_width
;
T
ymax
=
(
cy
+
bh
)
/
im_height
;
out
[
i
*
4
]
=
is_clip
?
clip
<
T
>
(
xmin
)
:
xmin
;
out
[
i
*
4
+
1
]
=
is_clip
?
clip
<
T
>
(
ymin
)
:
ymin
;
out
[
i
*
4
+
2
]
=
is_clip
?
clip
<
T
>
(
xmax
)
:
xmax
;
out
[
i
*
4
+
3
]
=
is_clip
?
clip
<
T
>
(
ymax
)
:
ymax
;
}
}
template
<
typename
T
>
__global__
void
SetVariance
(
T
*
out
,
const
T
*
var
,
const
int
vnum
,
const
int
num
)
{
CUDA_KERNEL_LOOP
(
i
,
num
)
{
out
[
i
]
=
var
[
i
%
vnum
];
}
}
template
<
typename
T
,
typename
Context
>
void
PriorBoxKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
image
,
const
std
::
vector
<
float
>&
min_sizes
,
const
std
::
vector
<
float
>&
aspect_ratios
,
const
std
::
vector
<
float
>&
variances
,
const
std
::
vector
<
float
>&
max_sizes
,
bool
flip
,
bool
clip
,
float
step_w
,
float
step_h
,
float
offset
,
bool
min_max_aspect_ratios_order
,
DenseTensor
*
out
,
DenseTensor
*
var
)
{
std
::
vector
<
float
>
new_aspect_ratios
;
ExpandAspectRatios
(
aspect_ratios
,
flip
,
&
new_aspect_ratios
);
T
new_step_w
=
static_cast
<
T
>
(
step_w
);
T
new_step_h
=
static_cast
<
T
>
(
step_h
);
T
new_offset
=
static_cast
<
T
>
(
offset
);
auto
im_width
=
image
.
dims
()[
3
];
auto
im_height
=
image
.
dims
()[
2
];
auto
width
=
input
.
dims
()[
3
];
auto
height
=
input
.
dims
()[
2
];
T
step_width
,
step_height
;
if
(
new_step_w
==
0
||
new_step_h
==
0
)
{
step_width
=
static_cast
<
T
>
(
im_width
)
/
width
;
step_height
=
static_cast
<
T
>
(
im_height
)
/
height
;
}
else
{
step_width
=
new_step_w
;
step_height
=
new_step_h
;
}
int
num_priors
=
new_aspect_ratios
.
size
()
*
min_sizes
.
size
();
if
(
max_sizes
.
size
()
>
0
)
{
num_priors
+=
max_sizes
.
size
();
}
int
min_num
=
static_cast
<
int
>
(
min_sizes
.
size
());
int
box_num
=
width
*
height
*
num_priors
;
int
block
=
512
;
int
grid
=
(
box_num
+
block
-
1
)
/
block
;
auto
stream
=
ctx
.
stream
();
ctx
.
template
Alloc
<
T
>(
out
);
ctx
.
template
Alloc
<
T
>(
var
);
DenseTensor
r
;
paddle
::
framework
::
TensorFromVector
(
new_aspect_ratios
,
ctx
,
&
r
);
DenseTensor
min
;
paddle
::
framework
::
TensorFromVector
(
min_sizes
,
ctx
,
&
min
);
T
*
max_data
=
nullptr
;
DenseTensor
max
;
if
(
max_sizes
.
size
()
>
0
)
{
paddle
::
framework
::
TensorFromVector
(
max_sizes
,
ctx
,
&
max
);
max_data
=
max
.
data
<
T
>
();
}
GenPriorBox
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
out
->
data
<
T
>
(),
r
.
data
<
T
>
(),
height
,
width
,
im_height
,
im_width
,
new_aspect_ratios
.
size
(),
new_offset
,
step_width
,
step_height
,
min
.
data
<
T
>
(),
max_data
,
min_num
,
clip
,
min_max_aspect_ratios_order
);
DenseTensor
v
;
paddle
::
framework
::
TensorFromVector
(
variances
,
ctx
,
&
v
);
grid
=
(
box_num
*
4
+
block
-
1
)
/
block
;
SetVariance
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
var
->
data
<
T
>
(),
v
.
data
<
T
>
(),
variances
.
size
(),
box_num
*
4
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
prior_box
,
GPU
,
ALL_LAYOUT
,
phi
::
PriorBoxKernel
,
float
,
double
)
{}
paddle/phi/kernels/prior_box_kernel.h
0 → 100644
浏览文件 @
d92b2f2d
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
PriorBoxKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
image
,
const
std
::
vector
<
float
>&
min_sizes
,
const
std
::
vector
<
float
>&
aspect_ratios
,
const
std
::
vector
<
float
>&
variances
,
const
std
::
vector
<
float
>&
max_sizes
,
bool
flip
,
bool
clip
,
float
step_w
,
float
step_h
,
float
offset
,
bool
min_max_aspect_ratios_order
,
DenseTensor
*
out
,
DenseTensor
*
var
);
inline
void
ExpandAspectRatios
(
const
std
::
vector
<
float
>&
input_aspect_ratior
,
bool
flip
,
std
::
vector
<
float
>*
output_aspect_ratior
)
{
constexpr
float
epsilon
=
1e-6
;
output_aspect_ratior
->
clear
();
output_aspect_ratior
->
push_back
(
1.0
f
);
for
(
size_t
i
=
0
;
i
<
input_aspect_ratior
.
size
();
++
i
)
{
float
ar
=
input_aspect_ratior
[
i
];
bool
already_exist
=
false
;
for
(
size_t
j
=
0
;
j
<
output_aspect_ratior
->
size
();
++
j
)
{
if
(
fabs
(
ar
-
output_aspect_ratior
->
at
(
j
))
<
epsilon
)
{
already_exist
=
true
;
break
;
}
}
if
(
!
already_exist
)
{
output_aspect_ratior
->
push_back
(
ar
);
if
(
flip
)
{
output_aspect_ratior
->
push_back
(
1.0
f
/
ar
);
}
}
}
}
}
// namespace phi
paddle/phi/ops/compat/prior_box_sig.cc
0 → 100644
浏览文件 @
d92b2f2d
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
PriorBoxOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"prior_box"
,
{
"Input"
,
"Image"
},
{
"min_sizes"
,
"aspect_ratios"
,
"variances"
,
"max_sizes"
,
"flip"
,
"clip"
,
"step_w"
,
"step_h"
,
"offset"
,
"min_max_aspect_ratios_order"
},
{
"Boxes"
,
"Variances"
});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
prior_box
,
phi
::
PriorBoxOpArgumentMapping
);
python/paddle/fluid/layers/detection.py
浏览文件 @
d92b2f2d
...
...
@@ -22,7 +22,7 @@ import paddle
from
.layer_function_generator
import
generate_layer_fn
from
.layer_function_generator
import
autodoc
,
templatedoc
from
..layer_helper
import
LayerHelper
from
..framework
import
Variable
,
_non_static_mode
,
static_only
from
..framework
import
Variable
,
_non_static_mode
,
static_only
,
in_dygraph_mode
from
..
import
core
from
.loss
import
softmax_with_cross_entropy
from
.
import
tensor
...
...
@@ -1794,18 +1794,20 @@ def ssd_loss(location,
return
loss
def
prior_box
(
input
,
image
,
min_sizes
,
max_sizes
=
None
,
aspect_ratios
=
[
1.
],
variance
=
[
0.1
,
0.1
,
0.2
,
0.2
],
flip
=
False
,
clip
=
False
,
steps
=
[
0.0
,
0.0
],
offset
=
0.5
,
name
=
None
,
min_max_aspect_ratios_order
=
False
):
def
prior_box
(
input
,
image
,
min_sizes
,
max_sizes
=
None
,
aspect_ratios
=
[
1.
],
variance
=
[
0.1
,
0.1
,
0.2
,
0.2
],
flip
=
False
,
clip
=
False
,
steps
=
[
0.0
,
0.0
],
offset
=
0.5
,
name
=
None
,
min_max_aspect_ratios_order
=
False
,
):
"""
This op generates prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
...
...
@@ -1905,6 +1907,15 @@ def prior_box(input,
# [6L, 9L, 1L, 4L]
"""
if
in_dygraph_mode
():
step_w
,
step_h
=
steps
if
max_sizes
==
None
:
max_sizes
=
[]
return
_C_ops
.
final_state_prior_box
(
input
,
image
,
min_sizes
,
aspect_ratios
,
variance
,
max_sizes
,
flip
,
clip
,
step_w
,
step_h
,
offset
,
min_max_aspect_ratios_order
)
helper
=
LayerHelper
(
"prior_box"
,
**
locals
())
dtype
=
helper
.
input_dtype
()
check_variable_and_dtype
(
input
,
'input'
,
...
...
python/paddle/fluid/tests/unittests/test_prior_box_op.py
浏览文件 @
d92b2f2d
...
...
@@ -19,6 +19,35 @@ import numpy as np
import
sys
import
math
from
op_test
import
OpTest
import
paddle
def
python_prior_box
(
input
,
image
,
min_sizes
,
aspect_ratios
=
[
1.
],
variances
=
[
0.1
,
0.1
,
0.2
,
0.2
],
max_sizes
=
None
,
flip
=
False
,
clip
=
False
,
step_w
=
0
,
step_h
=
0
,
offset
=
0.5
,
min_max_aspect_ratios_order
=
False
,
name
=
None
):
return
paddle
.
fluid
.
layers
.
detection
.
prior_box
(
input
,
image
,
min_sizes
=
min_sizes
,
max_sizes
=
max_sizes
,
aspect_ratios
=
aspect_ratios
,
variance
=
variances
,
flip
=
flip
,
clip
=
clip
,
steps
=
[
step_w
,
step_h
],
offset
=
offset
,
name
=
name
,
min_max_aspect_ratios_order
=
min_max_aspect_ratios_order
)
class
TestPriorBoxOp
(
OpTest
):
...
...
@@ -35,10 +64,10 @@ class TestPriorBoxOp(OpTest):
'variances'
:
self
.
variances
,
'flip'
:
self
.
flip
,
'clip'
:
self
.
clip
,
'min_max_aspect_ratios_order'
:
self
.
min_max_aspect_ratios_order
,
'step_w'
:
self
.
step_w
,
'step_h'
:
self
.
step_h
,
'offset'
:
self
.
offset
'offset'
:
self
.
offset
,
'min_max_aspect_ratios_order'
:
self
.
min_max_aspect_ratios_order
,
}
if
len
(
self
.
max_sizes
)
>
0
:
self
.
attrs
[
'max_sizes'
]
=
self
.
max_sizes
...
...
@@ -46,10 +75,11 @@ class TestPriorBoxOp(OpTest):
self
.
outputs
=
{
'Boxes'
:
self
.
out_boxes
,
'Variances'
:
self
.
out_var
}
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
check_eager
=
True
)
def
setUp
(
self
):
self
.
op_type
=
"prior_box"
self
.
python_api
=
python_prior_box
self
.
set_data
()
def
set_max_sizes
(
self
):
...
...
@@ -191,4 +221,5 @@ class TestPriorBoxOpWithSpecifiedOutOrder(TestPriorBoxOp):
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录