Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ee7d8421
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ee7d8421
编写于
2月 07, 2018
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update doc and follow comments.
上级
09b78c72
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
83 addition
and
52 deletion
+83
-52
paddle/operators/target_assign_op.cc
paddle/operators/target_assign_op.cc
+44
-14
paddle/operators/target_assign_op.cu
paddle/operators/target_assign_op.cu
+13
-13
paddle/operators/target_assign_op.h
paddle/operators/target_assign_op.h
+26
-21
python/paddle/v2/fluid/tests/test_target_assign_op.py
python/paddle/v2/fluid/tests/test_target_assign_op.py
+0
-4
未找到文件。
paddle/operators/target_assign_op.cc
浏览文件 @
ee7d8421
/* Copyright (c) 201
6
PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 201
8
PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -61,10 +61,12 @@ class TargetAssignOp : public framework::OperatorWithKernel {
...
@@ -61,10 +61,12 @@ class TargetAssignOp : public framework::OperatorWithKernel {
"The rank of Input(NegIndices) must be 2."
);
"The rank of Input(NegIndices) must be 2."
);
PADDLE_ENFORCE_EQ
(
blabel_dims
[
0
],
slabel_dims
[
0
],
PADDLE_ENFORCE_EQ
(
blabel_dims
[
0
],
slabel_dims
[
0
],
"The 1st dimension of Input(EncodedGTBBox) and "
"The 1st dimension (means the total number of "
"Input(GTScoreLabel) must be the same."
);
"ground-truth bounding boxes) of Input(EncodedGTBBox) "
"and Input(GTScoreLabel) must be the same."
);
PADDLE_ENFORCE_EQ
(
blabel_dims
[
1
],
mi_dims
[
1
],
PADDLE_ENFORCE_EQ
(
blabel_dims
[
1
],
mi_dims
[
1
],
"The 2nd dimension of Input(EncodedGTBBox) and "
"The 2nd dimension (means the number of priod boxes) "
"of Input(EncodedGTBBox) and "
"Input(MatchIndices) must be the same."
);
"Input(MatchIndices) must be the same."
);
PADDLE_ENFORCE_EQ
(
blabel_dims
[
2
],
4
,
PADDLE_ENFORCE_EQ
(
blabel_dims
[
2
],
4
,
"The 3rd dimension of Input(EncodedGTBBox) must be 4."
);
"The 3rd dimension of Input(EncodedGTBBox) must be 4."
);
...
@@ -101,31 +103,31 @@ class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -101,31 +103,31 @@ class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
"labels with shape [Ng, 1], where the Ng is the same as it in "
"labels with shape [Ng, 1], where the Ng is the same as it in "
"the input of EncodedGTBBox."
);
"the input of EncodedGTBBox."
);
AddInput
(
"MatchIndices"
,
AddInput
(
"MatchIndices"
,
"(Tensor, default
LoD
Tensor<int>), The input matched indices "
"(Tensor, default Tensor<int>), The input matched indices "
"with shape [N, Np], where N is the batch size, Np is the same "
"with shape [N, Np], where N is the batch size, Np is the same "
"as it in the input of EncodedGTBBox. If MatchIndices[i][j] "
"as it in the input of EncodedGTBBox. If MatchIndices[i][j] "
"is -1, the j-th prior box is not matched to any ground-truh "
"is -1, the j-th prior box is not matched to any ground-truh "
"box in i-th instance."
);
"box in i-th instance."
);
AddInput
(
"NegIndices"
,
AddInput
(
"NegIndices"
,
"(LoDTensor, default LoDTensor<int>), The input negative example "
"(LoDTensor, default LoDTensor<int>), The input negative example "
"indics with shape [Neg, 1], where is the total number of "
"indic
e
s with shape [Neg, 1], where is the total number of "
"negative example indices."
);
"negative example indices."
);
AddAttr
<
int
>
(
"background_label"
,
AddAttr
<
int
>
(
"background_label"
,
"(int, default 0), Label i
d for
background class."
)
"(int, default 0), Label i
ndex of
background class."
)
.
SetDefault
(
0
);
.
SetDefault
(
0
);
AddOutput
(
"PredBBoxLabel"
,
AddOutput
(
"PredBBoxLabel"
,
"(Tensor), The output encoded ground-truth labels "
"(Tensor), The output encoded ground-truth labels "
"with shape [N, Np, 4], N is the batch size and Np, 4 is the "
"with shape [N, Np, 4], N is the batch size and Np, 4 is the "
"same as they in input of EncodedGTBBox. If MatchIndices[i][j] "
"same as they in input of EncodedGTBBox. If MatchIndices[i][j] "
"is -1, the PredBBoxLabel[i][j][:] is the encoded ground-truth "
"is -1, the PredBBoxLabel[i][j][:] is the encoded ground-truth "
"box for background_label
_id
in i-th instance."
);
"box for background_label in i-th instance."
);
AddOutput
(
"PredBBoxWeight"
,
AddOutput
(
"PredBBoxWeight"
,
"(Tensor), The weight for PredBBoxLabel with the shape "
"(Tensor), The weight for PredBBoxLabel with the shape "
"of [N, Np, 1]"
);
"of [N, Np, 1]"
);
AddOutput
(
"PredScoreLabel"
,
AddOutput
(
"PredScoreLabel"
,
"(Tensor, default Tensor<int>), The output score labels for "
"(Tensor, default Tensor<int>), The output score labels for "
"each predictions with shape [N, Np, 1]. If MatchIndices[i][j] "
"each predictions with shape [N, Np, 1]. If MatchIndices[i][j] "
"is -1, PredScoreLabel[i][j] = background_label
_id
."
);
"is -1, PredScoreLabel[i][j] = background_label."
);
AddOutput
(
"PredScoreWeight"
,
AddOutput
(
"PredScoreWeight"
,
"(Tensor), The weight for PredScoreLabel with the shape "
"(Tensor), The weight for PredScoreLabel with the shape "
"of [N, Np, 1]"
);
"of [N, Np, 1]"
);
...
@@ -136,19 +138,47 @@ and regression targets to each prior box as well as weights to each
...
@@ -136,19 +138,47 @@ and regression targets to each prior box as well as weights to each
prior box. The weights is used to specify which prior box would not contribute
prior box. The weights is used to specify which prior box would not contribute
to training loss.
to training loss.
TODO(dang qingqing) add an example.
For each instance, the output `PredBBoxLabel`, `PredBBoxWeight`,
`PredScoreLabel` and `PredScoreWeight` are assigned based on `MatchIndices`.
Assumed that the row offset for each instance in `EncodedGTBBox` is called lod,
this operato assigns classification/regression targets by performing the
following steps:
1. Assigning all outpts based on `MatchIndices`:
If id = MatchIndices[i][j] > 0,
PredBBoxLabel[i][j] = EncodedGTBBox[lod[i] + id][j]
PredBBoxWeight[i][j] = 1.
PredScoreLabel[i][j] = GTScoreLabel[lod[i] + id]
PredScoreWeight[i][j] = 1.
Otherwise,
PredBBoxLabel[j][j] = [0., 0., 0., 0.]
PredBBoxWeight[i][j] = 0.
PredScoreLabel[i][j] = background_label
PredScoreWeight[i][j] = 0.
2. Assigning PredScoreWeight based on `NegIndices`:
Assumed that the row offset for each instance in `NegIndices` is caleed neg_lod,
for i-th instance and all ids of NegIndices in this instance:
PredScoreLabel[i][id] = background_label
PredScoreWeight[i][id] = 1.0
)DOC"
);
)DOC"
);
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
UpdateTargetLabel
Functor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
NegTargetAssign
Functor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
int
*
neg_indices
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
int
*
neg_indices
,
const
size_t
*
lod
,
const
int
num
,
const
int
num_prior_box
,
const
size_t
*
lod
,
const
int
num
,
const
int
num_prior_box
,
const
int
background_label
,
int
*
out_label
,
T
*
out_label_wt
)
{
const
int
background_label
,
int
*
out_label
,
T
*
out_label_wt
)
{
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
for
(
in
t
j
=
lod
[
i
];
j
<
lod
[
i
+
1
];
++
j
)
{
for
(
size_
t
j
=
lod
[
i
];
j
<
lod
[
i
+
1
];
++
j
)
{
int
id
=
neg_indices
[
j
];
int
id
=
neg_indices
[
j
];
out_label
[
i
*
num_prior_box
+
id
]
=
background_label
;
out_label
[
i
*
num_prior_box
+
id
]
=
background_label
;
out_label_wt
[
i
*
num_prior_box
+
id
]
=
static_cast
<
T
>
(
1.0
);
out_label_wt
[
i
*
num_prior_box
+
id
]
=
static_cast
<
T
>
(
1.0
);
...
@@ -157,8 +187,8 @@ struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, T> {
...
@@ -157,8 +187,8 @@ struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, T> {
}
}
};
};
template
struct
UpdateTargetLabel
Functor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
NegTargetAssign
Functor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
UpdateTargetLabel
Functor
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
NegTargetAssign
Functor
<
platform
::
CPUDeviceContext
,
double
>;
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
...
paddle/operators/target_assign_op.cu
浏览文件 @
ee7d8421
/* Copyright (c) 201
6
PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 201
8
PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -18,38 +18,38 @@ namespace paddle {
...
@@ -18,38 +18,38 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
>
__global__
void
UpdateTargetLabelKernel
(
const
int
*
neg_indices
,
__global__
void
NegTargetAssignKernel
(
const
int
*
neg_indices
,
const
size_t
*
lod
,
const
size_t
*
lod
,
const
int
num
,
const
int
num
,
const
int
num_prior_box
,
const
int
num_prior_box
,
const
int
background_label
,
const
int
background_label
,
int
*
out_label
,
T
*
out_label_wt
)
{
int
*
out_label
,
T
*
out_label_wt
)
{
int
bidx
=
blockIdx
.
x
;
int
bidx
=
blockIdx
.
x
;
int
st
=
lod
[
bidx
];
int
st
=
lod
[
bidx
];
int
ed
=
lod
[
bidx
+
1
];
int
ed
=
lod
[
bidx
+
1
];
int
row_start
=
bidx
*
num_prior_box
;
for
(
int
i
=
st
+
threadIdx
.
x
;
i
<
ed
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
st
+
threadIdx
.
x
;
i
<
ed
;
i
+=
blockDim
.
x
)
{
int
id
=
neg_indices
[
i
];
int
id
=
row_start
+
neg_indices
[
i
];
out_label
[
bidx
*
num_prior_box
+
id
]
=
background_label
;
out_label
[
id
]
=
background_label
;
out_label_wt
[
bidx
*
num_prior_box
+
id
]
=
1.
;
out_label_wt
[
id
]
=
1.
;
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
struct
UpdateTargetLabel
Functor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
NegTargetAssign
Functor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
int
*
neg_indices
,
const
size_t
*
lod
,
const
int
num
,
const
int
*
neg_indices
,
const
size_t
*
lod
,
const
int
num
,
const
int
num_prior_box
,
const
int
background_label
,
const
int
num_prior_box
,
const
int
background_label
,
int
*
out_label
,
T
*
out_label_wt
)
{
int
*
out_label
,
T
*
out_label_wt
)
{
const
int
block_size
=
256
;
const
int
block_size
=
256
;
const
int
grid_size
=
num
;
const
int
grid_size
=
num
;
UpdateTargetLabel
Kernel
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
stream
()
>>>
(
NegTargetAssign
Kernel
<
T
><<<
grid_size
,
block_size
,
0
,
ctx
.
stream
()
>>>
(
neg_indices
,
lod
,
num
,
num_prior_box
,
background_label
,
out_label
,
neg_indices
,
lod
,
num
,
num_prior_box
,
background_label
,
out_label
,
out_label_wt
);
out_label_wt
);
}
}
};
};
template
struct
UpdateTargetLabel
Functor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
NegTargetAssign
Functor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
UpdateTargetLabel
Functor
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
NegTargetAssign
Functor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
...
paddle/operators/target_assign_op.h
浏览文件 @
ee7d8421
/* Copyright (c) 201
6
PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 201
8
PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -56,40 +56,41 @@ struct TargetAssignFunctor {
...
@@ -56,40 +56,41 @@ struct TargetAssignFunctor {
int
row
=
i
/
num_prior_box_
;
int
row
=
i
/
num_prior_box_
;
int
col
=
i
-
row
*
num_prior_box_
;
int
col
=
i
-
row
*
num_prior_box_
;
size_t
off
=
lod_
[
row
];
size_t
row_off
=
lod_
[
row
];
int
offset
=
row
*
num_prior_box_
+
col
;
int
id
=
match_indices_
[
row
*
num_prior_box_
+
col
];
int
id
=
match_indices_
[
offset
];
T
*
obox
=
out_box_
+
(
row
*
num_prior_box_
+
col
)
*
4
;
T
*
obox
=
out_box_
+
offset
*
4
;
int
*
olabel
=
out_label_
+
row
*
num_prior_box_
+
col
;
int
*
olabel
=
out_label_
+
offset
;
T
*
obox_wt
=
out_box_wt_
+
row
*
num_prior_box_
+
col
;
T
*
obox_wt
=
out_box_wt_
+
offset
;
T
*
olabel_wt
=
out_label_wt_
+
row
*
num_prior_box_
+
col
;
T
*
olabel_wt
=
out_label_wt_
+
offset
;
if
(
id
>
-
1
)
{
if
(
id
>
-
1
)
{
const
T
*
gtbox
=
gt_box_
+
((
off
+
id
)
*
num_prior_box_
+
col
)
*
4
;
const
T
*
gtbox
=
gt_box_
+
((
row_
off
+
id
)
*
num_prior_box_
+
col
)
*
4
;
obox
[
0
]
=
gtbox
[
0
];
obox
[
0
]
=
gtbox
[
0
];
obox
[
1
]
=
gtbox
[
1
];
obox
[
1
]
=
gtbox
[
1
];
obox
[
2
]
=
gtbox
[
2
];
obox
[
2
]
=
gtbox
[
2
];
obox
[
3
]
=
gtbox
[
3
];
obox
[
3
]
=
gtbox
[
3
];
olabel
[
0
]
=
gt_label_
[
off
+
id
];
olabel
[
0
]
=
gt_label_
[
row_
off
+
id
];
obox_wt
[
0
]
=
1.
;
obox_wt
[
0
]
=
static_cast
<
T
>
(
1.
)
;
olabel_wt
[
0
]
=
1.
;
olabel_wt
[
0
]
=
static_cast
<
T
>
(
1.
)
;
}
else
{
}
else
{
obox
[
0
]
=
0.
;
obox
[
0
]
=
static_cast
<
T
>
(
0.
)
;
obox
[
1
]
=
0.
;
obox
[
1
]
=
static_cast
<
T
>
(
0.
)
;
obox
[
2
]
=
0.
;
obox
[
2
]
=
static_cast
<
T
>
(
0.
)
;
obox
[
3
]
=
0.
;
obox
[
3
]
=
static_cast
<
T
>
(
0.
)
;
olabel
[
0
]
=
background_label_
;
olabel
[
0
]
=
background_label_
;
obox_wt
[
0
]
=
0.
;
obox_wt
[
0
]
=
static_cast
<
T
>
(
0.
)
;
olabel_wt
[
0
]
=
0.
;
olabel_wt
[
0
]
=
static_cast
<
T
>
(
0.
)
;
}
}
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
UpdateTargetLabel
Functor
{
struct
NegTargetAssign
Functor
{
void
operator
()(
const
platform
::
DeviceContext
&
ctx
,
const
int
*
neg_indices
,
void
operator
()(
const
platform
::
DeviceContext
&
ctx
,
const
int
*
neg_indices
,
const
size_t
*
lod
,
const
int
num
,
const
int
num_prior_box
,
const
size_t
*
lod
,
const
int
num
,
const
int
num_prior_box
,
const
int
background_label
,
int
*
out_label
,
const
int
background_label
,
int
*
out_label
,
...
@@ -130,7 +131,11 @@ class TargetAssignKernel : public framework::OpKernel<T> {
...
@@ -130,7 +131,11 @@ class TargetAssignKernel : public framework::OpKernel<T> {
int64_t
num_prior_box
=
match_indices
->
dims
()[
1
];
int64_t
num_prior_box
=
match_indices
->
dims
()[
1
];
auto
gt_lod
=
enc_gt_box
->
lod
().
back
();
auto
gt_lod
=
enc_gt_box
->
lod
().
back
();
auto
gt_label_lod
=
gt_label
->
lod
().
back
();
auto
neg_lod
=
neg_indices
->
lod
().
back
();
auto
neg_lod
=
neg_indices
->
lod
().
back
();
for
(
size_t
i
=
0
;
i
<
gt_lod
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
gt_lod
.
data
()[
i
],
gt_label_lod
.
data
()[
i
]);
}
size_t
*
gt_lod_data
=
gt_lod
.
data
(
ctx
.
GetPlace
());
size_t
*
gt_lod_data
=
gt_lod
.
data
(
ctx
.
GetPlace
());
size_t
*
neg_lod_data
=
neg_lod
.
data
(
ctx
.
GetPlace
());
size_t
*
neg_lod_data
=
neg_lod
.
data
(
ctx
.
GetPlace
());
...
@@ -145,9 +150,9 @@ class TargetAssignKernel : public framework::OpKernel<T> {
...
@@ -145,9 +150,9 @@ class TargetAssignKernel : public framework::OpKernel<T> {
num
*
num_prior_box
);
num
*
num_prior_box
);
for_range
(
functor
);
for_range
(
functor
);
UpdateTargetLabelFunctor
<
DeviceContext
,
T
>
update
_functor
;
NegTargetAssignFunctor
<
DeviceContext
,
T
>
neg_trg
_functor
;
update
_functor
(
device_ctx
,
neg_idx_data
,
neg_lod_data
,
num
,
num_prior_box
,
neg_trg
_functor
(
device_ctx
,
neg_idx_data
,
neg_lod_data
,
num
,
num_prior_box
,
background_label
,
olabel_data
,
olabel_wt_data
);
background_label
,
olabel_data
,
olabel_wt_data
);
}
}
};
};
...
...
python/paddle/v2/fluid/tests/test_target_assign_op.py
浏览文件 @
ee7d8421
...
@@ -14,8 +14,6 @@
...
@@ -14,8 +14,6 @@
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
import
math
import
sys
import
random
import
random
from
op_test
import
OpTest
from
op_test
import
OpTest
...
@@ -89,8 +87,6 @@ class TestTargetAssginOp(OpTest):
...
@@ -89,8 +87,6 @@ class TestTargetAssginOp(OpTest):
num_class
=
21
num_class
=
21
gt_lod
=
[
0
,
5
,
11
,
23
]
gt_lod
=
[
0
,
5
,
11
,
23
]
neg_lod
=
[
0
,
4
,
7
,
13
]
neg_lod
=
[
0
,
4
,
7
,
13
]
#gt_lod = [0, 2, 5]
#neg_lod = [0, 2, 4]
batch_size
=
len
(
gt_lod
)
-
1
batch_size
=
len
(
gt_lod
)
-
1
num_gt
=
gt_lod
[
-
1
]
num_gt
=
gt_lod
[
-
1
]
background_label
=
0
background_label
=
0
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录