Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
72eccb23
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
72eccb23
编写于
1月 28, 2018
作者:
G
gaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add box coder op
上级
430fdc52
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
531 addition
and
0 deletion
+531
-0
paddle/operators/box_coder_op.cc
paddle/operators/box_coder_op.cc
+106
-0
paddle/operators/box_coder_op.cu
paddle/operators/box_coder_op.cu
+145
-0
paddle/operators/box_coder_op.h
paddle/operators/box_coder_op.h
+163
-0
python/paddle/v2/fluid/tests/test_box_coder_op.py
python/paddle/v2/fluid/tests/test_box_coder_op.py
+117
-0
未找到文件。
paddle/operators/box_coder_op.cc
0 → 100644
浏览文件 @
72eccb23
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/box_coder_op.h"
namespace
paddle
{
namespace
operators
{
class
BoxCoderOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"PriorBox"
),
"Input(PriorBox) of BoxCoderOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"PriorBoxVar"
),
"Input(PriorBoxVar) of BoxCoderOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"PriorBox"
),
"Input(TargetBox) of BoxCoderOp should not be null."
);
auto
prior_box_dims
=
ctx
->
GetInputDim
(
"PriorBox"
);
auto
prior_box_var_dims
=
ctx
->
GetInputDim
(
"PriorBoxVar"
);
auto
target_box_dims
=
ctx
->
GetInputDim
(
"TargetBox"
);
PADDLE_ENFORCE_EQ
(
prior_box_dims
.
size
(),
2UL
,
"The shape of PriorBox is [N, 4]"
);
PADDLE_ENFORCE_EQ
(
prior_box_dims
[
1
],
4UL
,
"The shape of PriorBox is [N, 4]"
);
PADDLE_ENFORCE_EQ
(
prior_box_var_dims
.
size
(),
2UL
,
"The shape of PriorBoxVar is [N, 4]"
);
PADDLE_ENFORCE_EQ
(
prior_box_var_dims
[
1
],
4UL
,
"The shape of PriorBoxVar is [N, 4]"
);
PADDLE_ENFORCE_EQ
(
target_box_dims
.
size
(),
2UL
,
"The shape of TargetBox is [M, 4]"
);
PADDLE_ENFORCE_EQ
(
target_box_dims
[
1
],
4UL
,
"The shape of TargetBox is [M, 4]"
);
GetBoxCodeType
(
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"code_type"
));
ctx
->
SetOutputDim
(
"OutputBox"
,
framework
::
make_ddim
({
target_box_dims
[
0
],
target_box_dims
[
1
]}));
}
};
class
BoxCoderOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
BoxCoderOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"PriorBox"
,
"(Tensor, default Tensor<float>) "
"Box list PriorBox is a 2-D Tensor with shape [M, 4] holds N boxes, "
"each box is represented as [xmin, ymin, xmax, ymax], "
"[xmin, ymin] is the left top coordinate of the anchor box, "
"if the input is image feature map, they are close to the origin "
"of the coordinate system. [xmax, ymax] is the right bottom "
"coordinate of the anchor box."
);
AddInput
(
"PriorBoxVar"
,
"(Tensor, default Tensor<float>) "
"PriorBoxVar is a 2-D Tensor with shape [M, 4] holds N group "
"of variance."
);
AddInput
(
"TargetBox"
,
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
"[N, 4], each box is represented as [xmin, ymin, xmax, ymax], "
"[xmin, ymin] is the left top coordinate of the box if the input "
"is image feature map, they are close to the origin of the coordinate "
"system. [xmax, ymax] is the right bottom coordinate of the box. "
"This tensor can contain LoD information to represent a batch "
"of inputs. One instance of this batch can contain different "
"numbers of entities."
);
AddAttr
<
std
::
string
>
(
"code_type"
,
"(string, default encode_center_size) "
"the code type used with the target box"
)
.
SetDefault
(
"encode_center_size"
)
.
InEnum
({
"encode_center_size"
,
"decode_center_size"
});
AddOutput
(
"OutputBox"
,
"(Tensor, default Tensor<float>)"
"(Tensor) The output of box_coder_op, a tensor with shape [N, M, 4] "
"representing the result of N target boxes encoded/decoded with "
"M Prior boxes and variances."
);
AddComment
(
R"DOC(
Bounding Box Coder Operator.
Encode/Decode the priorbox information with the target bounding box.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
box_coder
,
ops
::
BoxCoderOp
,
ops
::
BoxCoderOpMaker
);
REGISTER_OP_CPU_KERNEL
(
box_coder
,
ops
::
BoxCoderKernel
<
float
>
,
ops
::
BoxCoderKernel
<
double
>
);
paddle/operators/box_coder_op.cu
0 → 100644
浏览文件 @
72eccb23
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/box_coder_op.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
template
<
typename
T
>
__global__
void
EncodeCenterSizeKernel
(
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
int
row
,
int
col
,
T
*
output
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
row
*
col
)
{
const
int
row_idx
=
idx
/
col
;
const
int
col_idx
=
idx
%
col
;
T
prior_box_width
=
prior_box_data
[
col_idx
*
4
+
2
]
-
prior_box_data
[
col_idx
*
4
];
T
prior_box_height
=
prior_box_data
[
col_idx
*
4
+
3
]
-
prior_box_data
[
col_idx
*
4
+
1
];
T
prior_box_center_x
=
(
prior_box_data
[
col_idx
*
4
+
2
]
+
prior_box_data
[
col_idx
*
4
])
/
2
;
T
prior_box_center_y
=
(
prior_box_data
[
col_idx
*
4
+
3
]
+
prior_box_data
[
col_idx
*
4
+
1
])
/
2
;
T
target_box_center_x
=
(
target_box_data
[
row_idx
*
4
+
2
]
+
target_box_data
[
row_idx
*
4
])
/
2
;
T
target_box_center_y
=
(
target_box_data
[
row_idx
*
4
+
3
]
+
target_box_data
[
row_idx
*
4
+
1
])
/
2
;
T
target_box_width
=
target_box_data
[
row_idx
*
4
+
2
]
-
target_box_data
[
row_idx
*
4
];
T
target_box_height
=
target_box_data
[
row_idx
*
4
+
3
]
-
target_box_data
[
row_idx
*
4
+
1
];
output
[
idx
*
4
]
=
(
target_box_center_x
-
prior_box_center_x
)
/
prior_box_width
/
prior_box_var_data
[
col_idx
*
4
];
output
[
idx
*
4
+
1
]
=
(
target_box_center_y
-
prior_box_center_y
)
/
prior_box_height
/
prior_box_var_data
[
col_idx
*
4
+
1
];
output
[
idx
*
4
+
2
]
=
log
(
fabs
(
target_box_width
/
prior_box_width
))
/
prior_box_var_data
[
col_idx
*
4
+
2
];
output
[
idx
*
4
+
3
]
=
log
(
fabs
(
target_box_height
/
prior_box_height
))
/
prior_box_var_data
[
col_idx
*
4
+
3
];
}
}
template
<
typename
T
>
__global__
void
DecodeCenterSizeKernel
(
const
T
*
prior_box_data
,
const
T
*
prior_box_var_data
,
const
T
*
target_box_data
,
int
row
,
int
col
,
T
*
output
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
row
*
col
)
{
const
int
row_idx
=
idx
/
col
;
const
int
col_idx
=
idx
%
col
;
T
prior_box_width
=
prior_box_data
[
col_idx
*
4
+
2
]
-
prior_box_data
[
col_idx
*
4
];
T
prior_box_height
=
prior_box_data
[
col_idx
*
4
+
3
]
-
prior_box_data
[
col_idx
*
4
+
1
];
T
prior_box_center_x
=
(
prior_box_data
[
col_idx
*
4
+
2
]
+
prior_box_data
[
col_idx
*
4
])
/
2
;
T
prior_box_center_y
=
(
prior_box_data
[
col_idx
*
4
+
3
]
+
prior_box_data
[
col_idx
*
4
+
1
])
/
2
;
T
target_box_width
=
exp
(
prior_box_var_data
[
col_idx
*
4
+
2
]
*
target_box_data
[
row_idx
*
4
+
2
])
*
prior_box_width
;
T
target_box_height
=
exp
(
prior_box_var_data
[
col_idx
*
4
+
3
]
*
target_box_data
[
row_idx
*
4
+
3
])
*
prior_box_height
;
T
target_box_center_x
=
prior_box_var_data
[
col_idx
*
4
]
*
target_box_data
[
row_idx
*
4
]
*
prior_box_width
+
prior_box_center_x
;
T
target_box_center_y
=
prior_box_var_data
[
col_idx
*
4
+
1
]
*
target_box_data
[
row_idx
*
4
+
1
]
*
prior_box_height
+
prior_box_center_y
;
output
[
idx
*
4
]
=
target_box_center_x
-
target_box_width
/
2
;
output
[
idx
*
4
+
1
]
=
target_box_center_y
-
target_box_height
/
2
;
output
[
idx
*
4
+
2
]
=
target_box_center_x
+
target_box_width
/
2
;
output
[
idx
*
4
+
3
]
=
target_box_center_y
+
target_box_height
/
2
;
}
}
template
<
typename
T
>
class
BoxCoderCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
context
.
GetPlace
()),
"This kernel only runs on GPU device."
);
auto
*
prior_box
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBox"
);
auto
*
prior_box_var
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBoxVar"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
output_box
=
context
.
Output
<
Tensor
>
(
"OutputBox"
);
if
(
target_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1UL
,
"Only support 1 level of LoD."
);
}
auto
row
=
target_box
->
dims
()[
0
];
auto
col
=
prior_box
->
dims
()[
0
];
int
block
=
512
;
int
grid
=
(
row
*
col
+
block
-
1
)
/
block
;
auto
&
device_ctx
=
context
.
cuda_device_context
();
const
T
*
prior_box_data
=
prior_box
->
data
<
T
>
();
const
T
*
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
const
T
*
target_box_data
=
target_box
->
data
<
T
>
();
output_box
->
mutable_data
<
T
>
({
row
,
col
,
4
},
context
.
GetPlace
());
T
*
output
=
output_box
->
data
<
T
>
();
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSizeKernel
<
T
><<<
grid
,
block
,
0
,
device_ctx
.
stream
()
>>>
(
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
output
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
DecodeCenterSizeKernel
<
T
><<<
grid
,
block
,
0
,
device_ctx
.
stream
()
>>>
(
prior_box_data
,
prior_box_var_data
,
target_box_data
,
row
,
col
,
output
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
box_coder
,
ops
::
BoxCoderCUDAKernel
<
float
>
,
ops
::
BoxCoderCUDAKernel
<
double
>
);
paddle/operators/box_coder_op.h
0 → 100644
浏览文件 @
72eccb23
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
enum
class
BoxCodeType
{
kEncodeCenterSize
=
0
,
kDecodeCenterSize
=
1
};
inline
BoxCodeType
GetBoxCodeType
(
const
std
::
string
&
type
)
{
if
(
type
==
"encode_center_size"
)
{
return
BoxCodeType
::
kEncodeCenterSize
;
}
else
if
(
type
==
"decode_center_size"
)
{
return
BoxCodeType
::
kDecodeCenterSize
;
}
PADDLE_THROW
(
"Not support type %s."
,
type
);
}
template
<
typename
T
>
class
BoxCoderKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
EncodeCenterSize
(
const
Tensor
&
target_box
,
const
Tensor
&
prior_box
,
const
Tensor
&
prior_box_var
,
T
*
output
)
const
{
PADDLE_ENFORCE_EQ
(
target_box
.
dims
().
size
(),
2
,
"The rank of target_box must be 2."
);
PADDLE_ENFORCE_EQ
(
prior_box
.
dims
().
size
(),
2
,
"The rank of prior_box must be 2."
);
PADDLE_ENFORCE_EQ
(
prior_box_var
.
dims
().
size
(),
2
,
"The rank of prior_box_var must be 2."
);
PADDLE_ENFORCE_EQ
(
prior_box
.
dims
()[
0
],
prior_box_var
.
dims
()[
0
],
"The dims of prior_box must equal to prior_box_var."
);
int64_t
row
=
target_box
.
dims
()[
0
];
int64_t
col
=
prior_box
.
dims
()[
0
];
auto
*
target_box_data
=
target_box
.
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
.
data
<
T
>
();
auto
*
prior_box_var_data
=
prior_box_var
.
data
<
T
>
();
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
T
prior_box_width
=
prior_box_data
[
j
*
4
+
2
]
-
prior_box_data
[
j
*
4
];
T
prior_box_height
=
prior_box_data
[
j
*
4
+
3
]
-
prior_box_data
[
j
*
4
+
1
];
T
prior_box_center_x
=
(
prior_box_data
[
j
*
4
+
2
]
+
prior_box_data
[
j
*
4
])
/
2
;
T
prior_box_center_y
=
(
prior_box_data
[
j
*
4
+
3
]
+
prior_box_data
[
j
*
4
+
1
])
/
2
;
T
target_box_center_x
=
(
target_box_data
[
i
*
4
+
2
]
+
target_box_data
[
i
*
4
])
/
2
;
T
target_box_center_y
=
(
target_box_data
[
i
*
4
+
3
]
+
target_box_data
[
i
*
4
+
1
])
/
2
;
T
target_box_width
=
target_box_data
[
i
*
4
+
2
]
-
target_box_data
[
i
*
4
];
T
target_box_height
=
target_box_data
[
i
*
4
+
3
]
-
target_box_data
[
i
*
4
+
1
];
size_t
offset
=
i
*
col
*
4
+
j
*
4
;
output
[
offset
]
=
(
target_box_center_x
-
prior_box_center_x
)
/
prior_box_width
/
prior_box_var_data
[
j
*
4
];
output
[
offset
+
1
]
=
(
target_box_center_y
-
prior_box_center_y
)
/
prior_box_height
/
prior_box_var_data
[
j
*
4
+
1
];
output
[
offset
+
2
]
=
std
::
log
(
std
::
fabs
(
target_box_width
/
prior_box_width
))
/
prior_box_var_data
[
j
*
4
+
2
];
output
[
offset
+
3
]
=
std
::
log
(
std
::
fabs
(
target_box_height
/
prior_box_height
))
/
prior_box_var_data
[
j
*
4
+
3
];
}
}
}
void
DecodeCenterSize
(
const
Tensor
&
target_box
,
const
Tensor
&
prior_box
,
const
Tensor
&
prior_box_var
,
T
*
output
)
const
{
PADDLE_ENFORCE_EQ
(
target_box
.
dims
().
size
(),
2
,
"The rank of target_box must be 2."
);
PADDLE_ENFORCE_EQ
(
prior_box
.
dims
().
size
(),
2
,
"The rank of prior_box must be 2."
);
PADDLE_ENFORCE_EQ
(
prior_box_var
.
dims
().
size
(),
2
,
"The rank of prior_box_var must be 2."
);
PADDLE_ENFORCE_EQ
(
prior_box
.
dims
()[
0
],
prior_box_var
.
dims
()[
0
],
"The dims of prior_box must equal to prior_box_var."
);
int64_t
row
=
target_box
.
dims
()[
0
];
int64_t
col
=
prior_box
.
dims
()[
0
];
auto
*
target_box_data
=
target_box
.
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
.
data
<
T
>
();
auto
*
prior_box_var_data
=
prior_box_var
.
data
<
T
>
();
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
T
prior_box_width
=
prior_box_data
[
j
*
4
+
2
]
-
prior_box_data
[
j
*
4
];
T
prior_box_height
=
prior_box_data
[
j
*
4
+
3
]
-
prior_box_data
[
j
*
4
+
1
];
T
prior_box_center_x
=
(
prior_box_data
[
j
*
4
+
2
]
+
prior_box_data
[
j
*
4
])
/
2
;
T
prior_box_center_y
=
(
prior_box_data
[
j
*
4
+
3
]
+
prior_box_data
[
j
*
4
+
1
])
/
2
;
T
target_box_center_x
=
prior_box_var_data
[
j
*
4
]
*
target_box_data
[
i
*
4
]
*
prior_box_width
+
prior_box_center_x
;
T
target_box_center_y
=
prior_box_var_data
[
j
*
4
+
1
]
*
target_box_data
[
i
*
4
+
1
]
*
prior_box_height
+
prior_box_center_y
;
T
target_box_width
=
std
::
exp
(
prior_box_var_data
[
j
*
4
+
2
]
*
target_box_data
[
i
*
4
+
2
])
*
prior_box_width
;
T
target_box_height
=
std
::
exp
(
prior_box_var_data
[
j
*
4
+
3
]
*
target_box_data
[
i
*
4
+
3
])
*
prior_box_height
;
size_t
offset
=
i
*
col
*
4
+
j
*
4
;
output
[
offset
]
=
target_box_center_x
-
target_box_width
/
2
;
output
[
offset
+
1
]
=
target_box_center_y
-
target_box_height
/
2
;
output
[
offset
+
2
]
=
target_box_center_x
+
target_box_width
/
2
;
output
[
offset
+
3
]
=
target_box_center_y
+
target_box_height
/
2
;
}
}
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
prior_box
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBox"
);
auto
*
prior_box_var
=
context
.
Input
<
framework
::
Tensor
>
(
"PriorBoxVar"
);
auto
*
target_box
=
context
.
Input
<
framework
::
LoDTensor
>
(
"TargetBox"
);
auto
*
output_box
=
context
.
Output
<
Tensor
>
(
"OutputBox"
);
if
(
target_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
target_box
->
lod
().
size
(),
1UL
,
"Only support 1 level of LoD."
);
}
auto
row
=
target_box
->
dims
()[
0
];
auto
col
=
prior_box
->
dims
()[
0
];
output_box
->
mutable_data
<
T
>
({
row
,
col
,
4
},
context
.
GetPlace
());
auto
code_type
=
GetBoxCodeType
(
context
.
Attr
<
std
::
string
>
(
"code_type"
));
T
*
output
=
output_box
->
data
<
T
>
();
if
(
code_type
==
BoxCodeType
::
kEncodeCenterSize
)
{
EncodeCenterSize
(
*
target_box
,
*
prior_box
,
*
prior_box_var
,
output
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
DecodeCenterSize
(
*
target_box
,
*
prior_box
,
*
prior_box_var
,
output
);
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/v2/fluid/tests/test_box_coder_op.py
0 → 100644
浏览文件 @
72eccb23
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
sys
import
math
from
op_test
import
OpTest
def
box_coder
(
target_box
,
prior_box
,
prior_box_var
,
output_box
,
code_type
):
prior_box_x
=
(
prior_box
[:,
2
]
+
prior_box
[:,
0
])
/
2
prior_box_y
=
(
prior_box
[:,
3
]
+
prior_box
[:,
1
])
/
2
prior_box_width
=
(
prior_box
[:,
2
]
-
prior_box
[:,
0
])
prior_box_height
=
(
prior_box
[:,
3
]
-
prior_box
[:,
1
])
if
(
code_type
==
"EncodeCenterSize"
):
target_box_x
=
(
target_box
[:,
2
]
+
target_box
[:,
0
])
/
2
target_box_y
=
(
target_box
[:,
3
]
+
target_box
[:,
1
])
/
2
target_box_width
=
(
target_box
[:,
2
]
-
target_box
[:,
0
])
target_box_height
=
(
target_box
[:,
3
]
-
target_box
[:,
1
])
for
i
in
range
(
target_box
.
shape
[
0
]):
output_box
[
i
,:,
0
]
=
(
target_box_x
[
i
]
-
prior_box_x
)
/
prior_box_width
/
\
prior_box_var
[:,
0
]
output_box
[
i
,:,
1
]
=
(
target_box_y
[
i
]
-
prior_box_y
)
/
prior_box_height
/
\
prior_box_var
[:,
1
]
output_box
[
i
,:,
2
]
=
np
.
log
(
np
.
fabs
(
target_box_width
[
i
]
/
prior_box_width
))
/
\
prior_box_var
[:,
2
]
output_box
[
i
,:,
3
]
=
np
.
log
(
np
.
fabs
(
target_box_height
[
i
]
/
prior_box_height
))
/
\
prior_box_var
[:,
3
]
elif
(
code_type
==
"DecodeCenterSize"
):
for
i
in
range
(
target_box
.
shape
[
0
]):
target_box_x
=
prior_box_var
[:,
0
]
*
target_box
[
i
][
0
]
*
\
prior_box_width
[:]
+
prior_box_x
[:]
target_box_y
=
prior_box_var
[:,
1
]
*
target_box
[
i
][
1
]
*
\
prior_box_height
[:]
+
prior_box_y
[:]
target_box_width
=
np
.
exp
(
prior_box_var
[:,
2
]
*
target_box
[
i
][
2
])
*
\
prior_box_width
[:]
target_box_height
=
np
.
exp
(
prior_box_var
[:,
3
]
*
target_box
[
i
][
3
])
*
\
prior_box_height
[:]
output_box
[
i
,
:,
0
]
=
target_box_x
-
target_box_width
/
2
output_box
[
i
,
:,
1
]
=
target_box_y
-
target_box_height
/
2
output_box
[
i
,
:,
2
]
=
target_box_x
+
target_box_width
/
2
output_box
[
i
,
:,
3
]
=
target_box_y
+
target_box_height
/
2
def
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
lod
,
code_type
):
n
=
target_box
.
shape
[
0
]
m
=
prior_box
.
shape
[
0
]
output_box
=
np
.
zeros
((
n
,
m
,
4
),
dtype
=
np
.
float32
)
for
i
in
range
(
len
(
lod
)
-
1
):
box_coder
(
target_box
[
lod
[
i
]:
lod
[
i
+
1
],
:],
prior_box
,
prior_box_var
,
output_box
[
lod
[
i
]:
lod
[
i
+
1
],
:,
:],
code_type
)
return
output_box
class
TestBoxCoderOp
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
lod
=
[[
0
,
20
]]
prior_box
=
np
.
random
.
random
((
10
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
10
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
20
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
lod
[
0
],
code_type
)
self
.
inputs
=
{
'PriorBox'
:
prior_box
,
'PriorBoxVar'
:
prior_box_var
,
'TargetBox'
:
target_box
,
}
self
.
attrs
=
{
'code_type'
:
'decode_center_size'
}
self
.
outputs
=
{
'OutputBox'
:
output_box
}
class
TestBoxCoderOpWithLoD
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
lod
=
[[
0
,
4
,
12
,
20
]]
prior_box
=
np
.
random
.
random
((
10
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
10
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
20
,
4
)).
astype
(
'float32'
)
code_type
=
"EncodeCenterSize"
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
lod
[
0
],
code_type
)
self
.
inputs
=
{
'PriorBox'
:
prior_box
,
'PriorBoxVar'
:
prior_box_var
,
'TargetBox'
:
(
target_box
,
lod
),
}
self
.
attrs
=
{
'code_type'
:
'encode_center_size'
}
self
.
outputs
=
{
'OutputBox'
:
output_box
}
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录