Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
98f8fa4c
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
98f8fa4c
编写于
7月 26, 2022
作者:
L
lyq
提交者:
GitHub
7月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] Migrate box coder to phi. (#44550)
上级
ab198b45
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
851 addition
and
431 deletion
+851
-431
paddle/fluid/operators/detection/CMakeLists.txt
paddle/fluid/operators/detection/CMakeLists.txt
+2
-3
paddle/fluid/operators/detection/box_coder_op.cc
paddle/fluid/operators/detection/box_coder_op.cc
+12
-125
paddle/fluid/operators/detection/box_coder_op.h
paddle/fluid/operators/detection/box_coder_op.h
+0
-295
paddle/fluid/operators/detection/box_coder_op_npu.cc
paddle/fluid/operators/detection/box_coder_op_npu.cc
+7
-3
paddle/phi/api/yaml/legacy_api.yaml
paddle/phi/api/yaml/legacy_api.yaml
+10
-0
paddle/phi/infermeta/ternary.cc
paddle/phi/infermeta/ternary.cc
+110
-0
paddle/phi/infermeta/ternary.h
paddle/phi/infermeta/ternary.h
+10
-0
paddle/phi/kernels/box_coder_kernel.h
paddle/phi/kernels/box_coder_kernel.h
+34
-0
paddle/phi/kernels/cpu/box_coder.cc
paddle/phi/kernels/cpu/box_coder.cc
+281
-0
paddle/phi/kernels/gpu/box_coder.cu
paddle/phi/kernels/gpu/box_coder.cu
+246
-0
paddle/phi/kernels/impl/box_coder.h
paddle/phi/kernels/impl/box_coder.h
+43
-0
paddle/phi/ops/compat/box_coder_sig.cc
paddle/phi/ops/compat/box_coder_sig.cc
+28
-0
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+17
-0
python/paddle/fluid/tests/unittests/test_box_coder_op.py
python/paddle/fluid/tests/unittests/test_box_coder_op.py
+49
-4
python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py
...n/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py
+2
-1
未找到文件。
paddle/fluid/operators/detection/CMakeLists.txt
浏览文件 @
98f8fa4c
...
...
@@ -29,12 +29,11 @@ function(detection_library TARGET_NAME)
endfunction
()
if
(
WITH_ASCEND_CL
)
detection_library
(
box_coder_op SRCS box_coder_op.cc box_coder_op.cu
box_coder_op_npu.cc
)
detection_library
(
box_coder_op SRCS box_coder_op.cc box_coder_op_npu.cc
)
detection_library
(
density_prior_box_op SRCS density_prior_box_op.cc
density_prior_box_op.cu density_prior_box_op_npu.cc
)
else
()
detection_library
(
box_coder_op SRCS box_coder_op.cc
box_coder_op.cu
)
detection_library
(
box_coder_op SRCS box_coder_op.cc
)
detection_library
(
density_prior_box_op SRCS density_prior_box_op.cc
density_prior_box_op.cu
)
endif
()
...
...
paddle/fluid/operators/detection/box_coder_op.cc
浏览文件 @
98f8fa4c
...
...
@@ -9,135 +9,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace
paddle
{
namespace
operators
{
class
BoxCoderOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"PriorBox"
),
true
,
platform
::
errors
::
NotFound
(
"Input(PriorBox) of BoxCoder operator is not found."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"TargetBox"
),
true
,
platform
::
errors
::
NotFound
(
"Input(TargetBox) of BoxCoder operator is not found."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"OutputBox"
),
true
,
platform
::
errors
::
NotFound
(
"Output(OutputBox) of BoxCoder operator is not found."
));
auto
prior_box_dims
=
ctx
->
GetInputDim
(
"PriorBox"
);
auto
target_box_dims
=
ctx
->
GetInputDim
(
"TargetBox"
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
prior_box_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The rank of Input PriorBox in BoxCoder operator "
"must be 2. But received rank = %d"
,
prior_box_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
prior_box_dims
[
1
],
4
,
platform
::
errors
::
InvalidArgument
(
"The second dimension of PriorBox in BoxCoder "
"operator must be 4. But received dimension = %d"
,
prior_box_dims
[
1
]));
if
(
ctx
->
HasInput
(
"PriorBoxVar"
))
{
auto
prior_box_var_dims
=
ctx
->
GetInputDim
(
"PriorBoxVar"
);
PADDLE_ENFORCE_EQ
(
prior_box_var_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The rank of Input(PriorBoxVar) in BoxCoder operator"
" should be 2. But received rank = %d"
,
prior_box_var_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
prior_box_dims
,
prior_box_var_dims
,
platform
::
errors
::
InvalidArgument
(
"The dimension of Input(PriorBoxVar) should be equal to"
"the dimension of Input(PriorBox) in BoxCoder operator "
"when the rank is 2."
));
}
}
auto
code_type
=
GetBoxCodeType
(
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"code_type"
));
int
axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
);
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The rank of Input TargetBox in BoxCoder operator "
"must be 2. But received rank is %d"
,
target_box_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
target_box_dims
[
1
],
4
,
platform
::
errors
::
InvalidArgument
(
"The second dimension of TargetBox in BoxCoder "
"operator is 4. But received dimension is %d"
,
target_box_dims
[
1
]));
ctx
->
SetOutputDim
(
"OutputBox"
,
phi
::
make_ddim
({
target_box_dims
[
0
],
prior_box_dims
[
0
],
4
}));
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
.
size
(),
3
,
platform
::
errors
::
InvalidArgument
(
"The rank of Input TargetBox in BoxCoder "
"operator must be 3. But received rank is %d"
,
target_box_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
axis
==
0
||
axis
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"axis in BoxCoder operator must be 0 or 1."
"But received axis = %d"
,
axis
));
if
(
ctx
->
IsRuntime
())
{
if
(
axis
==
0
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
[
1
],
prior_box_dims
[
0
],
platform
::
errors
::
InvalidArgument
(
"When axis is 0, The second "
"dimension of TargetBox in BoxCoder "
"should be equal to the first dimension of PriorBox."
));
}
else
if
(
axis
==
1
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
[
0
],
prior_box_dims
[
0
],
platform
::
errors
::
InvalidArgument
(
"When axis is 1, The first "
"dimension of TargetBox in BoxCoder "
"should be equal to the first dimension of PriorBox."
));
}
PADDLE_ENFORCE_EQ
(
target_box_dims
[
2
],
prior_box_dims
[
1
],
platform
::
errors
::
InvalidArgument
(
"The third dimension of TargetBox"
" in BoxCoder should be equal to the "
"second dimension of PriorBox."
));
}
ctx
->
ShareDim
(
"TargetBox"
,
/*->*/
"OutputBox"
);
}
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
&&
axis
==
1
)
{
ctx
->
ShareLoD
(
"PriorBox"
,
/*->*/
"OutputBox"
);
}
else
{
ctx
->
ShareLoD
(
"TargetBox"
,
/*->*/
"OutputBox"
);
}
}
};
class
BoxCoderOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -245,12 +129,15 @@ box will broadcast to target box along the assigned axis.
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
DECLARE_INFER_SHAPE_FUNCTOR
(
box_coder
,
BoxCoderInferShapeFunctor
,
PD_INFER_META
(
phi
::
BoxCoderInferMeta
));
REGISTER_OPERATOR
(
box_coder
,
ops
::
BoxCoderOp
,
ops
::
BoxCoderOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
box_coder
,
ops
::
BoxCoderKernel
<
phi
::
CPUContext
,
float
>
,
ops
::
BoxCoderKernel
<
phi
::
CPUContext
,
double
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
BoxCoderInferShapeFunctor
);
paddle/fluid/operators/detection/box_coder_op.h
已删除
100644 → 0
浏览文件 @
ab198b45
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
operators
{
enum
class
BoxCodeType
{
kEncodeCenterSize
=
0
,
kDecodeCenterSize
=
1
};
inline
BoxCodeType
GetBoxCodeType
(
const
std
::
string
&
type
)
{
PADDLE_ENFORCE_EQ
(
(
type
==
"encode_center_size"
)
||
(
type
==
"decode_center_size"
),
true
,
platform
::
errors
::
InvalidArgument
(
"The 'code_type' attribute in BoxCoder"
" must be 'encode_center_size' or 'decode_center_size'. "
"But received 'code_type' is %s"
,
type
));
if
(
type
==
"encode_center_size"
)
{
return
BoxCodeType
::
kEncodeCenterSize
;
}
else
{
return
BoxCodeType
::
kDecodeCenterSize
;
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
BoxCoderKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
EncodeCenterSize
(
const
framework
::
Tensor
*
target_box
,
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
*
prior_box_var
,
const
bool
normalized
,
const
std
::
vector
<
float
>
variance
,
T
*
output
)
const
{
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
prior_box
->
dims
()[
0
];
int64_t
len
=
prior_box
->
dims
()[
1
];
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
auto
*
target_box_data
=
target_box
->
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
->
data
<
T
>
();
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
T
prior_box_width
=
prior_box_data
[
j
*
len
+
2
]
-
prior_box_data
[
j
*
len
]
+
(
normalized
==
false
);
T
prior_box_height
=
prior_box_data
[
j
*
len
+
3
]
-
prior_box_data
[
j
*
len
+
1
]
+
(
normalized
==
false
);
T
prior_box_center_x
=
prior_box_data
[
j
*
len
]
+
prior_box_width
/
2
;
T
prior_box_center_y
=
prior_box_data
[
j
*
len
+
1
]
+
prior_box_height
/
2
;
T
target_box_center_x
=
(
target_box_data
[
i
*
len
+
2
]
+
target_box_data
[
i
*
len
])
/
2
;
T
target_box_center_y
=
(
target_box_data
[
i
*
len
+
3
]
+
target_box_data
[
i
*
len
+
1
])
/
2
;
T
target_box_width
=
target_box_data
[
i
*
len
+
2
]
-
target_box_data
[
i
*
len
]
+
(
normalized
==
false
);
T
target_box_height
=
target_box_data
[
i
*
len
+
3
]
-
target_box_data
[
i
*
len
+
1
]
+
(
normalized
==
false
);
output
[
offset
]
=
(
target_box_center_x
-
prior_box_center_x
)
/
prior_box_width
;
output
[
offset
+
1
]
=
(
target_box_center_y
-
prior_box_center_y
)
/
prior_box_height
;
output
[
offset
+
2
]
=
std
::
log
(
std
::
fabs
(
target_box_width
/
prior_box_width
));
output
[
offset
+
3
]
=
std
::
log
(
std
::
fabs
(
target_box_height
/
prior_box_height
));
}
}
if
(
prior_box_var
)
{
const
T
*
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
int
prior_var_offset
=
j
*
len
;
output
[
offset
+
k
]
/=
prior_box_var_data
[
prior_var_offset
+
k
];
}
}
}
}
else
if
(
!
(
variance
.
empty
()))
{
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
output
[
offset
+
k
]
/=
static_cast
<
T
>
(
variance
[
k
]);
}
}
}
}
}
template
<
int
axis
,
int
var_size
>
void
DecodeCenterSize
(
const
framework
::
Tensor
*
target_box
,
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
*
prior_box_var
,
const
bool
normalized
,
std
::
vector
<
float
>
variance
,
T
*
output
)
const
{
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
target_box
->
dims
()[
1
];
int64_t
len
=
target_box
->
dims
()[
2
];
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
auto
*
target_box_data
=
target_box
->
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
->
data
<
T
>
();
T
var_data
[
4
]
=
{
1.
,
1.
,
1.
,
1.
};
T
*
var_ptr
=
var_data
;
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
int
prior_box_offset
=
axis
==
0
?
j
*
len
:
i
*
len
;
T
prior_box_width
=
prior_box_data
[
prior_box_offset
+
2
]
-
prior_box_data
[
prior_box_offset
]
+
(
normalized
==
false
);
T
prior_box_height
=
prior_box_data
[
prior_box_offset
+
3
]
-
prior_box_data
[
prior_box_offset
+
1
]
+
(
normalized
==
false
);
T
prior_box_center_x
=
prior_box_data
[
prior_box_offset
]
+
prior_box_width
/
2
;
T
prior_box_center_y
=
prior_box_data
[
prior_box_offset
+
1
]
+
prior_box_height
/
2
;
T
target_box_center_x
=
0
,
target_box_center_y
=
0
;
T
target_box_width
=
0
,
target_box_height
=
0
;
int
prior_var_offset
=
axis
==
0
?
j
*
len
:
i
*
len
;
if
(
var_size
==
2
)
{
std
::
memcpy
(
var_ptr
,
prior_box_var
->
data
<
T
>
()
+
prior_var_offset
,
4
*
sizeof
(
T
));
}
else
if
(
var_size
==
1
)
{
var_ptr
=
reinterpret_cast
<
T
*>
(
variance
.
data
());
}
T
box_var_x
=
*
var_ptr
;
T
box_var_y
=
*
(
var_ptr
+
1
);
T
box_var_w
=
*
(
var_ptr
+
2
);
T
box_var_h
=
*
(
var_ptr
+
3
);
target_box_center_x
=
box_var_x
*
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
box_var_y
*
target_box_data
[
offset
+
1
]
*
prior_box_height
+
prior_box_center_y
;
target_box_width
=
std
::
exp
(
box_var_w
*
target_box_data
[
offset
+
2
])
*
prior_box_width
;
target_box_height
=
std
::
exp
(
box_var_h
*
target_box_data
[
offset
+
3
])
*
prior_box_height
;
output
[
offset
]
=
target_box_center_x
-
target_box_width
/
2
;
output
[
offset
+
1
]
=
target_box_center_y
-
target_box_height
/
2
;
output
[
offset
+
2
]
=
target_box_center_x
+
target_box_width
/
2
-
(
normalized
==
false
);
output
[
offset
+
3
]
=
target_box_center_y
+
target_box_height
/
2
-
(
normalized
==
false
);
}
}
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
prior_box
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBox"
);
auto
*
prior_box_var
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBoxVar"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
output_box
=
context
.
Output
<
framework
::
Tensor
>
(
"OutputBox"
);
std
::
vector
<
float
>
variance
=
context
.
Attr
<
std
::
vector
<
float
>>
(
"variance"
);
const
int
axis
=
context
.
Attr
<
int
>
(
"axis"
);
if
(
target_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"Input(TargetBox) of BoxCoder operator "
"supports LoD with only one level. But received "
"level = %d"
,
target_box
->
lod
().
size
()));
}
if
(
prior_box_var
)
{
PADDLE_ENFORCE_EQ
(
variance
.
empty
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Input 'PriorBoxVar' and attribute 'variance' "
"of BoxCoder operator should not be used at the "
"same time."
));
}
if
(
!
(
variance
.
empty
()))
{
PADDLE_ENFORCE_EQ
(
static_cast
<
int
>
(
variance
.
size
()),
4
,
platform
::
errors
::
InvalidArgument
(
"Size of attribute 'variance' of BoxCoder "
"operator should be 4. But received "
"size = %d"
,
variance
.
size
()));
}
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
auto
row
=
target_box
->
dims
()[
0
];
auto
col
=
prior_box
->
dims
()[
0
];
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
col
=
target_box
->
dims
()[
1
];
}
auto
len
=
prior_box
->
dims
()[
1
];
output_box
->
mutable_data
<
T
>
({
row
,
col
,
len
},
context
.
GetPlace
());
T
*
output
=
output_box
->
data
<
T
>
();
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
if
(
prior_box_var
)
{
if
(
axis
==
0
)
{
DecodeCenterSize
<
0
,
2
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
else
{
DecodeCenterSize
<
1
,
2
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
}
else
if
(
!
(
variance
.
empty
()))
{
if
(
axis
==
0
)
{
DecodeCenterSize
<
0
,
1
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
else
{
DecodeCenterSize
<
1
,
1
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
}
else
{
if
(
axis
==
0
)
{
DecodeCenterSize
<
0
,
0
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
else
{
DecodeCenterSize
<
1
,
0
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detection/box_coder_op_npu.cc
浏览文件 @
98f8fa4c
...
...
@@ -9,8 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/impl/box_coder.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -407,10 +410,11 @@ class BoxCoderNPUKernel : public framework::OpKernel<T> {
" supports LoD with one level."
));
}
auto
code_type
=
GetBoxCodeType
(
ctx
.
Attr
<
std
::
string
>
(
"code_type"
));
auto
code_type
=
phi
::
funcs
::
GetBoxCodeType
(
ctx
.
Attr
<
std
::
string
>
(
"code_type"
));
bool
normalized
=
ctx
.
Attr
<
bool
>
(
"box_normalized"
);
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
if
(
code_type
==
phi
::
funcs
::
BoxCodeType
::
kEncodeCenterSize
)
{
BoxCoderEnc
<
T
>
(
ctx
,
target_box
,
prior_box
,
...
...
paddle/phi/api/yaml/legacy_api.yaml
浏览文件 @
98f8fa4c
...
...
@@ -326,6 +326,16 @@
kernel
:
func
:
bitwise_xor
# box_coder
-
api
:
box_coder
args
:
(Tensor prior_box, Tensor prior_box_var, Tensor target_box, str code_type, bool box_normalized, int axis, float[] variance)
output
:
Tensor(output_box)
infer_meta
:
func
:
BoxCoderInferMeta
kernel
:
func
:
box_coder
optional
:
prior_box_var
# brelu
-
api
:
brelu
args
:
(Tensor x, float t_min, float t_max)
...
...
paddle/phi/infermeta/ternary.cc
浏览文件 @
98f8fa4c
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/impl/box_coder.h"
namespace
phi
{
...
...
@@ -145,6 +146,115 @@ void AddmmInferMeta(const MetaTensor& input,
out
->
set_dtype
(
input
.
dtype
());
}
void
BoxCoderInferMeta
(
const
MetaTensor
&
prior_box
,
const
MetaTensor
&
prior_box_var
,
const
MetaTensor
&
target_box
,
const
std
::
string
&
code_type
,
bool
box_normalized
,
int
axis
,
const
std
::
vector
<
float
>&
variance
,
MetaTensor
*
output_box
,
MetaConfig
config
)
{
auto
prior_box_dims
=
prior_box
.
dims
();
auto
target_box_dims
=
target_box
.
dims
();
if
(
config
.
is_runtime
)
{
PADDLE_ENFORCE_EQ
(
prior_box_dims
.
size
(),
2
,
phi
::
errors
::
InvalidArgument
(
"The rank of Input PriorBox in BoxCoder operator "
"must be 2. But received rank = %d"
,
prior_box_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
prior_box_dims
[
1
],
4
,
phi
::
errors
::
InvalidArgument
(
"The second dimension of PriorBox in BoxCoder "
"operator must be 4. But received dimension = %d"
,
prior_box_dims
[
1
]));
if
(
prior_box_var
)
{
auto
prior_box_var_dims
=
prior_box_var
.
dims
();
PADDLE_ENFORCE_EQ
(
prior_box_var_dims
.
size
(),
2
,
phi
::
errors
::
InvalidArgument
(
"The rank of Input(PriorBoxVar) in BoxCoder operator"
" should be 2. But received rank = %d"
,
prior_box_var_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
prior_box_dims
,
prior_box_var_dims
,
phi
::
errors
::
InvalidArgument
(
"The dimension of Input(PriorBoxVar) should be equal to"
"the dimension of Input(PriorBox) in BoxCoder operator "
"when the rank is 2."
));
}
}
auto
box_code_type
=
phi
::
funcs
::
GetBoxCodeType
(
code_type
);
if
(
box_code_type
==
phi
::
funcs
::
BoxCodeType
::
kEncodeCenterSize
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
.
size
(),
2
,
phi
::
errors
::
InvalidArgument
(
"The rank of Input TargetBox in BoxCoder operator "
"must be 2. But received rank is %d"
,
target_box_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
target_box_dims
[
1
],
4
,
phi
::
errors
::
InvalidArgument
(
"The second dimension of TargetBox in BoxCoder "
"operator is 4. But received dimension is %d"
,
target_box_dims
[
1
]));
output_box
->
set_dims
({
target_box_dims
[
0
],
prior_box_dims
[
0
],
4
});
}
else
if
(
box_code_type
==
phi
::
funcs
::
BoxCodeType
::
kDecodeCenterSize
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
.
size
(),
3
,
phi
::
errors
::
InvalidArgument
(
"The rank of Input TargetBox in BoxCoder "
"operator must be 3. But received rank is %d"
,
target_box_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
axis
==
0
||
axis
==
1
,
true
,
phi
::
errors
::
InvalidArgument
(
"axis in BoxCoder operator must be 0 or 1."
"But received axis = %d"
,
axis
));
if
(
config
.
is_runtime
)
{
if
(
axis
==
0
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
[
1
],
prior_box_dims
[
0
],
phi
::
errors
::
InvalidArgument
(
"When axis is 0, The second "
"dimension of TargetBox in BoxCoder "
"should be equal to the first dimension of PriorBox."
));
}
else
if
(
axis
==
1
)
{
PADDLE_ENFORCE_EQ
(
target_box_dims
[
0
],
prior_box_dims
[
0
],
phi
::
errors
::
InvalidArgument
(
"When axis is 1, The first "
"dimension of TargetBox in BoxCoder "
"should be equal to the first dimension of PriorBox."
));
}
PADDLE_ENFORCE_EQ
(
target_box_dims
[
2
],
prior_box_dims
[
1
],
phi
::
errors
::
InvalidArgument
(
"The third dimension of TargetBox"
" in BoxCoder should be equal to the "
"second dimension of PriorBox."
));
}
output_box
->
share_dims
(
target_box
);
}
if
(
box_code_type
==
phi
::
funcs
::
BoxCodeType
::
kDecodeCenterSize
&&
axis
==
1
)
{
output_box
->
share_lod
(
prior_box
);
}
else
{
output_box
->
share_lod
(
target_box
);
}
output_box
->
set_dtype
(
target_box
.
dtype
());
}
void
ArangeInferMeta
(
const
MetaTensor
&
start
,
const
MetaTensor
&
end
,
const
MetaTensor
&
step
,
...
...
paddle/phi/infermeta/ternary.h
浏览文件 @
98f8fa4c
...
...
@@ -52,6 +52,16 @@ void ArangeInferMeta(const MetaTensor& start,
const
MetaTensor
&
step
,
MetaTensor
*
out
);
void
BoxCoderInferMeta
(
const
MetaTensor
&
prior_box
,
const
MetaTensor
&
prior_box_var
,
const
MetaTensor
&
target_box
,
const
std
::
string
&
code_type
,
bool
box_normalized
,
int
axis
,
const
std
::
vector
<
float
>&
variance
,
MetaTensor
*
output_box
,
MetaConfig
config
=
MetaConfig
());
void
InstanceNormInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
scale
,
const
MetaTensor
&
bias
,
...
...
paddle/phi/kernels/box_coder_kernel.h
0 → 100644
浏览文件 @
98f8fa4c
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
BoxCoderKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
prior_box
,
const
paddle
::
optional
<
DenseTensor
>&
prior_box_var
,
const
DenseTensor
&
target_box
,
const
std
::
string
&
code_type
,
bool
box_normalized
,
int
axis
,
const
std
::
vector
<
float
>&
variance
,
DenseTensor
*
output_box
);
}
// namespace phi
paddle/phi/kernels/cpu/box_coder.cc
0 → 100644
浏览文件 @
98f8fa4c
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/box_coder_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/box_coder.h"
namespace
phi
{
template
<
typename
T
>
void
EncodeCenterSize
(
const
DenseTensor
*
target_box
,
const
DenseTensor
*
prior_box
,
const
DenseTensor
*
prior_box_var
,
const
bool
normalized
,
const
std
::
vector
<
float
>
variance
,
T
*
output
)
{
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
prior_box
->
dims
()[
0
];
int64_t
len
=
prior_box
->
dims
()[
1
];
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
auto
*
target_box_data
=
target_box
->
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
->
data
<
T
>
();
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
T
prior_box_width
=
prior_box_data
[
j
*
len
+
2
]
-
prior_box_data
[
j
*
len
]
+
(
normalized
==
false
);
T
prior_box_height
=
prior_box_data
[
j
*
len
+
3
]
-
prior_box_data
[
j
*
len
+
1
]
+
(
normalized
==
false
);
T
prior_box_center_x
=
prior_box_data
[
j
*
len
]
+
prior_box_width
/
2
;
T
prior_box_center_y
=
prior_box_data
[
j
*
len
+
1
]
+
prior_box_height
/
2
;
T
target_box_center_x
=
(
target_box_data
[
i
*
len
+
2
]
+
target_box_data
[
i
*
len
])
/
2
;
T
target_box_center_y
=
(
target_box_data
[
i
*
len
+
3
]
+
target_box_data
[
i
*
len
+
1
])
/
2
;
T
target_box_width
=
target_box_data
[
i
*
len
+
2
]
-
target_box_data
[
i
*
len
]
+
(
normalized
==
false
);
T
target_box_height
=
target_box_data
[
i
*
len
+
3
]
-
target_box_data
[
i
*
len
+
1
]
+
(
normalized
==
false
);
output
[
offset
]
=
(
target_box_center_x
-
prior_box_center_x
)
/
prior_box_width
;
output
[
offset
+
1
]
=
(
target_box_center_y
-
prior_box_center_y
)
/
prior_box_height
;
output
[
offset
+
2
]
=
std
::
log
(
std
::
fabs
(
target_box_width
/
prior_box_width
));
output
[
offset
+
3
]
=
std
::
log
(
std
::
fabs
(
target_box_height
/
prior_box_height
));
}
}
if
(
prior_box_var
)
{
const
T
*
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
int
prior_var_offset
=
j
*
len
;
output
[
offset
+
k
]
/=
prior_box_var_data
[
prior_var_offset
+
k
];
}
}
}
}
else
if
(
!
(
variance
.
empty
()))
{
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
output
[
offset
+
k
]
/=
static_cast
<
T
>
(
variance
[
k
]);
}
}
}
}
}
template
<
typename
T
,
int
axis
,
int
var_size
>
void
DecodeCenterSize
(
const
DenseTensor
*
target_box
,
const
DenseTensor
*
prior_box
,
const
DenseTensor
*
prior_box_var
,
const
bool
normalized
,
std
::
vector
<
float
>
variance
,
T
*
output
)
{
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
target_box
->
dims
()[
1
];
int64_t
len
=
target_box
->
dims
()[
2
];
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
auto
*
target_box_data
=
target_box
->
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
->
data
<
T
>
();
T
var_data
[
4
]
=
{
1.
,
1.
,
1.
,
1.
};
T
*
var_ptr
=
var_data
;
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
int
prior_box_offset
=
axis
==
0
?
j
*
len
:
i
*
len
;
T
prior_box_width
=
prior_box_data
[
prior_box_offset
+
2
]
-
prior_box_data
[
prior_box_offset
]
+
(
normalized
==
false
);
T
prior_box_height
=
prior_box_data
[
prior_box_offset
+
3
]
-
prior_box_data
[
prior_box_offset
+
1
]
+
(
normalized
==
false
);
T
prior_box_center_x
=
prior_box_data
[
prior_box_offset
]
+
prior_box_width
/
2
;
T
prior_box_center_y
=
prior_box_data
[
prior_box_offset
+
1
]
+
prior_box_height
/
2
;
T
target_box_center_x
=
0
,
target_box_center_y
=
0
;
T
target_box_width
=
0
,
target_box_height
=
0
;
int
prior_var_offset
=
axis
==
0
?
j
*
len
:
i
*
len
;
if
(
var_size
==
2
)
{
std
::
memcpy
(
var_ptr
,
prior_box_var
->
data
<
T
>
()
+
prior_var_offset
,
4
*
sizeof
(
T
));
}
else
if
(
var_size
==
1
)
{
var_ptr
=
reinterpret_cast
<
T
*>
(
variance
.
data
());
}
T
box_var_x
=
*
var_ptr
;
T
box_var_y
=
*
(
var_ptr
+
1
);
T
box_var_w
=
*
(
var_ptr
+
2
);
T
box_var_h
=
*
(
var_ptr
+
3
);
target_box_center_x
=
box_var_x
*
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
target_box_center_y
=
box_var_y
*
target_box_data
[
offset
+
1
]
*
prior_box_height
+
prior_box_center_y
;
target_box_width
=
std
::
exp
(
box_var_w
*
target_box_data
[
offset
+
2
])
*
prior_box_width
;
target_box_height
=
std
::
exp
(
box_var_h
*
target_box_data
[
offset
+
3
])
*
prior_box_height
;
output
[
offset
]
=
target_box_center_x
-
target_box_width
/
2
;
output
[
offset
+
1
]
=
target_box_center_y
-
target_box_height
/
2
;
output
[
offset
+
2
]
=
target_box_center_x
+
target_box_width
/
2
-
(
normalized
==
false
);
output
[
offset
+
3
]
=
target_box_center_y
+
target_box_height
/
2
-
(
normalized
==
false
);
}
}
}
template
<
typename
T
,
typename
Context
>
void
BoxCoderKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
prior_box
,
const
paddle
::
optional
<
DenseTensor
>
&
prior_box_var
,
const
DenseTensor
&
target_box
,
const
std
::
string
&
code_type_str
,
bool
normalized
,
int
axis
,
const
std
::
vector
<
float
>
&
variance
,
DenseTensor
*
output_box
)
{
if
(
target_box
.
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
.
lod
().
size
(),
1UL
,
phi
::
errors
::
InvalidArgument
(
"Input(TargetBox) of BoxCoder operator "
"supports LoD with only one level. But received "
"level = %d"
,
target_box
.
lod
().
size
()));
}
if
(
prior_box_var
)
{
PADDLE_ENFORCE_EQ
(
variance
.
empty
(),
true
,
phi
::
errors
::
InvalidArgument
(
"Input 'PriorBoxVar' and attribute 'variance' "
"of BoxCoder operator should not be used at the "
"same time."
));
}
if
(
!
(
variance
.
empty
()))
{
PADDLE_ENFORCE_EQ
(
static_cast
<
int
>
(
variance
.
size
()),
4
,
phi
::
errors
::
InvalidArgument
(
"Size of attribute 'variance' of BoxCoder "
"operator should be 4. But received "
"size = %d"
,
variance
.
size
()));
}
auto
code_type
=
phi
::
funcs
::
GetBoxCodeType
(
code_type_str
);
auto
row
=
target_box
.
dims
()[
0
];
auto
col
=
prior_box
.
dims
()[
0
];
if
(
code_type
==
phi
::
funcs
::
BoxCodeType
::
kDecodeCenterSize
)
{
col
=
target_box
.
dims
()[
1
];
}
auto
len
=
prior_box
.
dims
()[
1
];
output_box
->
Resize
({
row
,
col
,
len
});
dev_ctx
.
template
Alloc
<
T
>(
output_box
);
T
*
output
=
output_box
->
data
<
T
>
();
if
(
code_type
==
phi
::
funcs
::
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSize
<
T
>
(
&
target_box
,
&
prior_box
,
prior_box_var
.
get_ptr
(),
normalized
,
variance
,
output
);
}
else
if
(
code_type
==
phi
::
funcs
::
BoxCodeType
::
kDecodeCenterSize
)
{
if
(
prior_box_var
)
{
if
(
axis
==
0
)
{
DecodeCenterSize
<
T
,
0
,
2
>
(
&
target_box
,
&
prior_box
,
prior_box_var
.
get_ptr
(),
normalized
,
variance
,
output
);
}
else
{
DecodeCenterSize
<
T
,
1
,
2
>
(
&
target_box
,
&
prior_box
,
prior_box_var
.
get_ptr
(),
normalized
,
variance
,
output
);
}
}
else
if
(
!
(
variance
.
empty
()))
{
if
(
axis
==
0
)
{
DecodeCenterSize
<
T
,
0
,
1
>
(
&
target_box
,
&
prior_box
,
prior_box_var
.
get_ptr
(),
normalized
,
variance
,
output
);
}
else
{
DecodeCenterSize
<
T
,
1
,
1
>
(
&
target_box
,
&
prior_box
,
prior_box_var
.
get_ptr
(),
normalized
,
variance
,
output
);
}
}
else
{
if
(
axis
==
0
)
{
DecodeCenterSize
<
T
,
0
,
0
>
(
&
target_box
,
&
prior_box
,
prior_box_var
.
get_ptr
(),
normalized
,
variance
,
output
);
}
else
{
DecodeCenterSize
<
T
,
1
,
0
>
(
&
target_box
,
&
prior_box
,
prior_box_var
.
get_ptr
(),
normalized
,
variance
,
output
);
}
}
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
box_coder
,
CPU
,
ALL_LAYOUT
,
phi
::
BoxCoderKernel
,
float
,
double
)
{}
paddle/
fluid/operators/detection/box_coder_op
.cu
→
paddle/
phi/kernels/gpu/box_coder
.cu
浏览文件 @
98f8fa4c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/box_coder_kernel.h"
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/box_coder.h"
namespace
paddle
{
namespace
operators
{
namespace
phi
{
template
<
typename
T
>
__global__
void
EncodeCenterSizeKernel
(
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
__global__
void
EncodeCenterSizeKernel
(
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
const
int
row
,
const
int
col
,
const
int
len
,
const
bool
normalized
,
const
T
prior_box_var_size
,
const
float
*
variance
,
const
float
*
variance
,
const
int
var_size
,
T
*
output
)
{
T
*
output
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
row
*
col
)
{
const
int
row_idx
=
idx
/
col
;
...
...
@@ -77,18 +83,18 @@ __global__ void EncodeCenterSizeKernel(const T* prior_box_data,
}
template
<
typename
T
>
__global__
void
DecodeCenterSizeKernel
(
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
__global__
void
DecodeCenterSizeKernel
(
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
const
int
row
,
const
int
col
,
const
int
len
,
const
bool
normalized
,
const
T
prior_box_var_size
,
const
float
*
variance
,
const
float
*
variance
,
const
int
var_size
,
const
int
axis
,
T
*
output
)
{
T
*
output
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
prior_box_offset
=
0
;
if
(
idx
<
row
*
col
)
{
...
...
@@ -141,108 +147,100 @@ __global__ void DecodeCenterSizeKernel(const T* prior_box_data,
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
BoxCoderCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
prior_box
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBox"
);
auto
*
prior_box_var
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBoxVar"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
output_box
=
context
.
Output
<
framework
::
Tensor
>
(
"OutputBox"
);
std
::
vector
<
float
>
variance
=
context
.
Attr
<
std
::
vector
<
float
>>
(
"variance"
);
const
T
*
prior_box_data
=
prior_box
->
data
<
T
>
();
const
T
*
target_box_data
=
target_box
->
data
<
T
>
();
const
T
*
prior_box_var_data
=
nullptr
;
auto
prior_box_var_size
=
0
;
if
(
prior_box_var
)
{
PADDLE_ENFORCE_EQ
(
variance
.
empty
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Input 'PriorBoxVar' and attribute 'variance'"
" of BoxCoder operator should not be used at the "
"same time."
));
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
prior_box_var_size
=
prior_box_var
->
dims
().
size
();
}
if
(
!
(
variance
.
empty
()))
{
PADDLE_ENFORCE_EQ
(
static_cast
<
int
>
(
variance
.
size
()),
4
,
platform
::
errors
::
InvalidArgument
(
"Size of attribute 'variance' in BoxCoder operator"
" should be 4. But received size is %d"
,
variance
.
size
()));
}
if
(
target_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Input 'TargetBox' of BoxCoder operator only"
" supports LoD with one level."
));
}
const
int
var_size
=
static_cast
<
int
>
(
variance
.
size
());
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
bool
normalized
=
context
.
Attr
<
bool
>
(
"box_normalized"
);
int
axis
=
context
.
Attr
<
int
>
(
"axis"
);
template
<
typename
T
,
typename
Context
>
void
BoxCoderKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
prior_box
,
const
paddle
::
optional
<
DenseTensor
>
&
prior_box_var
,
const
DenseTensor
&
target_box
,
const
std
::
string
&
code_type_str
,
bool
normalized
,
int
axis
,
const
std
::
vector
<
float
>
&
variance
,
DenseTensor
*
output_box
)
{
const
T
*
prior_box_data
=
prior_box
.
template
data
<
T
>();
const
T
*
target_box_data
=
target_box
.
template
data
<
T
>();
const
T
*
prior_box_var_data
=
nullptr
;
auto
prior_box_var_size
=
0
;
if
(
prior_box_var
)
{
PADDLE_ENFORCE_EQ
(
variance
.
empty
(),
true
,
phi
::
errors
::
InvalidArgument
(
"Input 'PriorBoxVar' and attribute 'variance'"
" of BoxCoder operator should not be used at the "
"same time."
));
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
prior_box_var_size
=
prior_box_var
->
dims
().
size
();
}
if
(
!
(
variance
.
empty
()))
{
PADDLE_ENFORCE_EQ
(
static_cast
<
int
>
(
variance
.
size
()),
4
,
phi
::
errors
::
InvalidArgument
(
"Size of attribute 'variance' in BoxCoder operator"
" should be 4. But received size is %d"
,
variance
.
size
()));
}
auto
row
=
target_box
->
dims
()[
0
];
auto
col
=
prior_box
->
dims
()[
0
];
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
col
=
target_box
->
dims
()[
1
];
}
auto
len
=
prior_box
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
row
*
col
+
block
-
1
)
/
block
;
auto
&
device_ctx
=
context
.
cuda_device_context
();
int
bytes
=
var_size
*
sizeof
(
float
);
auto
dev_var
=
memory
::
Alloc
(
device_ctx
,
bytes
);
float
*
dev_var_data
=
reinterpret_cast
<
float
*>
(
dev_var
->
ptr
());
auto
cplace
=
platform
::
CPUPlace
();
const
auto
gplace
=
context
.
GetPlace
();
memory
::
Copy
(
gplace
,
dev_var_data
,
cplace
,
&
variance
[
0
],
bytes
,
device_ctx
.
stream
());
output_box
->
mutable_data
<
T
>
({
row
,
col
,
len
},
context
.
GetPlace
());
T
*
output
=
output_box
->
data
<
T
>
();
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSizeKernel
<
T
>
<<<
grid
,
block
,
0
,
device_ctx
.
stream
()
>>>
(
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
len
,
normalized
,
prior_box_var_size
,
dev_var_data
,
var_size
,
output
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
DecodeCenterSizeKernel
<
T
>
<<<
grid
,
block
,
0
,
device_ctx
.
stream
()
>>>
(
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
len
,
normalized
,
prior_box_var_size
,
dev_var_data
,
var_size
,
axis
,
output
);
}
if
(
target_box
.
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
.
lod
().
size
(),
1
,
phi
::
errors
::
InvalidArgument
(
"Input 'TargetBox' of BoxCoder operator only"
" supports LoD with one level."
));
}
};
const
int
var_size
=
static_cast
<
int
>
(
variance
.
size
());
auto
code_type
=
phi
::
funcs
::
GetBoxCodeType
(
code_type_str
);
auto
row
=
target_box
.
dims
()[
0
];
auto
col
=
prior_box
.
dims
()[
0
];
if
(
code_type
==
phi
::
funcs
::
BoxCodeType
::
kDecodeCenterSize
)
{
col
=
target_box
.
dims
()[
1
];
}
auto
len
=
prior_box
.
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
row
*
col
+
block
-
1
)
/
block
;
int
bytes
=
var_size
*
sizeof
(
float
);
auto
dev_var
=
paddle
::
memory
::
Alloc
(
dev_ctx
,
bytes
);
float
*
dev_var_data
=
reinterpret_cast
<
float
*>
(
dev_var
->
ptr
());
auto
cplace
=
phi
::
CPUPlace
();
const
auto
gplace
=
dev_ctx
.
GetPlace
();
paddle
::
memory
::
Copy
(
gplace
,
dev_var_data
,
cplace
,
&
variance
[
0
],
bytes
,
dev_ctx
.
stream
());
output_box
->
Resize
({
row
,
col
,
len
});
dev_ctx
.
template
Alloc
<
T
>(
output_box
);
T
*
output
=
output_box
->
data
<
T
>
();
if
(
code_type
==
phi
::
funcs
::
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSizeKernel
<
T
>
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
len
,
normalized
,
prior_box_var_size
,
dev_var_data
,
var_size
,
output
);
}
else
if
(
code_type
==
phi
::
funcs
::
BoxCodeType
::
kDecodeCenterSize
)
{
DecodeCenterSizeKernel
<
T
>
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
len
,
normalized
,
prior_box_var_size
,
dev_var_data
,
var_size
,
axis
,
output
);
}
}
}
// namespace operators
}
// namespace paddle
}
// namespace phi
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
box_coder
,
ops
::
BoxCoderCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
BoxCoderCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
PD_REGISTER_KERNEL
(
box_coder
,
GPU
,
ALL_LAYOUT
,
phi
::
BoxCoderKernel
,
float
,
double
)
{}
paddle/phi/kernels/impl/box_coder.h
0 → 100644
浏览文件 @
98f8fa4c
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/enforce.h"
namespace
phi
{
namespace
funcs
{
enum
class
BoxCodeType
{
kEncodeCenterSize
=
0
,
kDecodeCenterSize
=
1
};
inline
BoxCodeType
GetBoxCodeType
(
const
std
::
string
&
type
)
{
PADDLE_ENFORCE_EQ
(
(
type
==
"encode_center_size"
)
||
(
type
==
"decode_center_size"
),
true
,
phi
::
errors
::
InvalidArgument
(
"The 'code_type' attribute in BoxCoder"
" must be 'encode_center_size' or 'decode_center_size'. "
"But received 'code_type' is %s"
,
type
));
if
(
type
==
"encode_center_size"
)
{
return
BoxCodeType
::
kEncodeCenterSize
;
}
else
{
return
BoxCodeType
::
kDecodeCenterSize
;
}
}
}
// namespace funcs
}
// namespace phi
paddle/phi/ops/compat/box_coder_sig.cc
0 → 100644
浏览文件 @
98f8fa4c
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
BoxCoderOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"box_coder"
,
{
"PriorBox"
,
"PriorBoxVar"
,
"TargetBox"
},
{
"code_type"
,
"box_normalized"
,
"axis"
,
"variance"
},
{
"OutputBox"
});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
box_coder
,
phi
::
BoxCoderOpArgumentMapping
);
python/paddle/fluid/layers/detection.py
浏览文件 @
98f8fa4c
...
...
@@ -37,6 +37,7 @@ from functools import reduce
from
..data_feeder
import
convert_dtype
,
check_variable_and_dtype
,
check_type
,
check_dtype
from
paddle.utils
import
deprecated
from
paddle
import
_C_ops
from
..framework
import
in_dygraph_mode
__all__
=
[
'prior_box'
,
...
...
@@ -948,6 +949,22 @@ def box_coder(prior_box,
'box_coder'
)
check_variable_and_dtype
(
target_box
,
'target_box'
,
[
'float32'
,
'float64'
],
'box_coder'
)
if
in_dygraph_mode
():
if
isinstance
(
prior_box_var
,
Variable
):
box_coder_op
=
_C_ops
.
final_state_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
code_type
,
box_normalized
,
axis
,
[])
elif
isinstance
(
prior_box_var
,
list
):
box_coder_op
=
_C_ops
.
final_state_box_coder
(
prior_box
,
None
,
target_box
,
code_type
,
box_normalized
,
axis
,
prior_box_var
)
else
:
raise
TypeError
(
"Input variance of box_coder must be Variable or lisz"
)
return
box_coder_op
helper
=
LayerHelper
(
"box_coder"
,
**
locals
())
output_box
=
helper
.
create_variable_for_type_inference
(
...
...
python/paddle/fluid/tests/unittests/test_box_coder_op.py
浏览文件 @
98f8fa4c
...
...
@@ -19,6 +19,8 @@ import numpy as np
import
sys
import
math
from
op_test
import
OpTest
import
paddle
import
paddle.fluid.core
as
core
def
box_decoder
(
t_box
,
p_box
,
pb_v
,
output_box
,
norm
,
axis
=
0
):
...
...
@@ -105,10 +107,11 @@ def batch_box_coder(p_box, pb_v, t_box, lod, code_type, norm, axis=0):
class
TestBoxCoderOp
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
check_eager
=
True
)
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
self
.
python_api
=
paddle
.
fluid
.
layers
.
box_coder
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
prior_box
=
np
.
random
.
random
((
81
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
81
,
4
)).
astype
(
'float32'
)
...
...
@@ -132,9 +135,10 @@ class TestBoxCoderOp(OpTest):
class
TestBoxCoderOpWithoutBoxVar
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
check_eager
=
True
)
def
setUp
(
self
):
self
.
python_api
=
paddle
.
fluid
.
layers
.
box_coder
self
.
op_type
=
"box_coder"
lod
=
[[
0
,
1
,
2
,
3
,
4
,
5
]]
prior_box
=
np
.
random
.
random
((
81
,
4
)).
astype
(
'float32'
)
...
...
@@ -147,6 +151,7 @@ class TestBoxCoderOpWithoutBoxVar(OpTest):
self
.
inputs
=
{
'PriorBox'
:
prior_box
,
'PriorBoxVar'
:
prior_box_var
,
'TargetBox'
:
target_box
,
}
self
.
attrs
=
{
...
...
@@ -159,9 +164,10 @@ class TestBoxCoderOpWithoutBoxVar(OpTest):
class
TestBoxCoderOpWithLoD
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
check_eager
=
True
)
def
setUp
(
self
):
self
.
python_api
=
paddle
.
fluid
.
layers
.
box_coder
self
.
op_type
=
"box_coder"
lod
=
[[
10
,
20
,
20
]]
prior_box
=
np
.
random
.
random
((
20
,
4
)).
astype
(
'float32'
)
...
...
@@ -184,9 +190,10 @@ class TestBoxCoderOpWithLoD(OpTest):
class
TestBoxCoderOpWithAxis
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
check_eager
=
True
)
def
setUp
(
self
):
self
.
python_api
=
paddle
.
fluid
.
layers
.
box_coder
self
.
op_type
=
"box_coder"
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
prior_box
=
np
.
random
.
random
((
30
,
4
)).
astype
(
'float32'
)
...
...
@@ -241,5 +248,43 @@ class TestBoxCoderOpWithVariance(OpTest):
self
.
outputs
=
{
'OutputBox'
:
output_box
}
class
TestBoxCoderOpWithVarianceDygraphAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
self
.
prior_box
=
np
.
random
.
random
((
30
,
4
)).
astype
(
'float32'
)
self
.
prior_box_var
=
np
.
random
.
random
((
4
)).
astype
(
'float32'
)
self
.
target_box
=
np
.
random
.
random
((
30
,
81
,
4
)).
astype
(
'float32'
)
self
.
code_type
=
"DecodeCenterSize"
self
.
box_normalized
=
False
self
.
axis
=
1
self
.
output_ref
=
batch_box_coder
(
self
.
prior_box
,
self
.
prior_box_var
,
self
.
target_box
,
self
.
lod
[
0
],
self
.
code_type
,
self
.
box_normalized
,
self
.
axis
)
self
.
place
=
[
paddle
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
self
.
place
.
append
(
paddle
.
CUDAPlace
(
0
))
def
test_dygraph_api
(
self
):
def
run
(
place
):
paddle
.
disable_static
(
place
)
output_box
=
paddle
.
fluid
.
layers
.
box_coder
(
paddle
.
to_tensor
(
self
.
prior_box
),
self
.
prior_box_var
.
tolist
(),
paddle
.
to_tensor
(
self
.
target_box
),
"decode_center_size"
,
self
.
box_normalized
,
axis
=
self
.
axis
)
self
.
assertEqual
(
np
.
allclose
(
np
.
sum
(
self
.
output_ref
),
np
.
sum
(
output_box
.
numpy
())),
True
)
paddle
.
enable_static
()
for
place
in
self
.
place
:
run
(
place
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py
浏览文件 @
98f8fa4c
...
...
@@ -50,7 +50,8 @@ class TestL2LossOp(OpTest):
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
self
.
max_relative_error
)
max_relative_error
=
self
.
max_relative_error
,
check_eager
=
True
)
class
TestL2LossDeterministic
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录