Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
201d4f2a
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
201d4f2a
编写于
10月 24, 2018
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into add_dropout_att_new
上级
a6e6bc45
e906c8e5
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
193 addition
and
286 deletion
+193
-286
paddle/fluid/framework/framework.proto
paddle/fluid/framework/framework.proto
+0
-1
paddle/fluid/framework/op_proto_maker.cc
paddle/fluid/framework/op_proto_maker.cc
+0
-53
paddle/fluid/framework/op_proto_maker.h
paddle/fluid/framework/op_proto_maker.h
+0
-11
paddle/fluid/framework/op_proto_maker_test.cc
paddle/fluid/framework/op_proto_maker_test.cc
+0
-117
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+4
-6
paddle/fluid/inference/api/demo_ci/run.sh
paddle/fluid/inference/api/demo_ci/run.sh
+1
-1
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+1
-1
paddle/fluid/operators/adam_op.cc
paddle/fluid/operators/adam_op.cc
+3
-3
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+3
-5
paddle/fluid/operators/conv_op.cc
paddle/fluid/operators/conv_op.cc
+2
-4
paddle/fluid/operators/detection/rpn_target_assign_op.cc
paddle/fluid/operators/detection/rpn_target_assign_op.cc
+52
-16
paddle/fluid/operators/elementwise_op.h
paddle/fluid/operators/elementwise_op.h
+0
-5
paddle/fluid/operators/mean_op.cc
paddle/fluid/operators/mean_op.cc
+1
-1
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+2
-4
paddle/fluid/operators/sgd_op.cc
paddle/fluid/operators/sgd_op.cc
+1
-2
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+1
-2
paddle/fluid/operators/sum_op.cc
paddle/fluid/operators/sum_op.cc
+1
-1
paddle/fluid/operators/top_k_op.cc
paddle/fluid/operators/top_k_op.cc
+1
-1
paddle/fluid/operators/top_k_op.cu
paddle/fluid/operators/top_k_op.cu
+16
-16
paddle/fluid/operators/top_k_op.h
paddle/fluid/operators/top_k_op.h
+1
-4
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+11
-5
python/paddle/fluid/tests/test_detection.py
python/paddle/fluid/tests/test_detection.py
+5
-2
python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py
...paddle/fluid/tests/unittests/test_rpn_target_assign_op.py
+34
-14
python/paddle/fluid/tests/unittests/test_top_k_op.py
python/paddle/fluid/tests/unittests/test_top_k_op.py
+53
-11
未找到文件。
paddle/fluid/framework/framework.proto
浏览文件 @
201d4f2a
...
...
@@ -80,7 +80,6 @@ message OpProto {
optional
bool
duplicable
=
3
[
default
=
false
];
optional
bool
intermediate
=
4
[
default
=
false
];
optional
bool
dispensable
=
5
[
default
=
false
];
optional
string
reuse
=
6
;
}
// AttrProto describes the C++ type Attribute.
...
...
paddle/fluid/framework/op_proto_maker.cc
浏览文件 @
201d4f2a
...
...
@@ -21,7 +21,6 @@ namespace framework {
void
OpProtoAndCheckerMaker
::
Validate
()
{
validated_
=
true
;
CheckNoDuplicatedInOutAttrs
();
CheckReuseVars
();
}
OpProtoAndCheckerMaker
::
VariableBuilder
OpProtoAndCheckerMaker
::
AddInput
(
...
...
@@ -40,40 +39,6 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
return
OpProtoAndCheckerMaker
::
VariableBuilder
{
output
};
}
void
OpProtoAndCheckerMaker
::
Reuse
(
const
std
::
string
&
name
,
const
std
::
string
&
reused_name
)
{
bool
found
=
false
;
proto
::
OpProto
::
Var
*
var
;
for
(
auto
&
var
:
proto_
->
inputs
())
{
if
(
var
.
name
()
==
reused_name
)
{
found
=
true
;
break
;
}
}
PADDLE_ENFORCE
(
found
==
true
,
"Input/Output name: %s reused_name: %s, one of them is not "
"exists or not matched."
,
name
,
reused_name
);
found
=
false
;
for
(
int
i
=
0
;
i
<
proto_
->
outputs
().
size
();
++
i
)
{
var
=
proto_
->
mutable_outputs
()
->
Mutable
(
i
);
if
(
var
->
name
()
==
name
)
{
PADDLE_ENFORCE
(
!
var
->
has_reuse
(),
"Output(%s) has been set reused var of %s"
,
name
,
var
->
reuse
());
found
=
true
;
var
->
set_reuse
(
reused_name
);
break
;
}
}
PADDLE_ENFORCE
(
found
==
true
,
"Input/Output name: %s reused_name: %s, one of them is not "
"exists or not matched."
,
name
,
reused_name
);
}
void
OpProtoAndCheckerMaker
::
CheckNoDuplicatedInOutAttrs
()
{
std
::
unordered_set
<
std
::
string
>
names
;
auto
checker
=
[
&
](
const
std
::
string
&
name
)
{
...
...
@@ -91,24 +56,6 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
}
}
void
OpProtoAndCheckerMaker
::
CheckReuseVars
()
{
std
::
unordered_set
<
std
::
string
>
names
;
for
(
auto
&
input
:
proto_
->
inputs
())
{
names
.
insert
(
input
.
name
());
}
auto
checker
=
[
&
](
const
std
::
string
&
name
,
const
std
::
string
&
reused
)
{
PADDLE_ENFORCE
(
names
.
count
(
reused
),
"Output [%s] reuse Input [%s], but the input is not registered."
,
name
,
reused
);
};
for
(
auto
&
output
:
proto_
->
outputs
())
{
if
(
output
.
has_reuse
())
{
checker
(
output
.
name
(),
output
.
reuse
());
}
}
}
void
OpProtoAndCheckerMaker
::
operator
()(
proto
::
OpProto
*
proto
,
OpAttrChecker
*
attr_checker
)
{
proto_
=
proto
;
...
...
paddle/fluid/framework/op_proto_maker.h
浏览文件 @
201d4f2a
...
...
@@ -14,8 +14,6 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
...
...
@@ -73,11 +71,6 @@ class OpProtoAndCheckerMaker {
var_
->
set_dispensable
(
true
);
return
*
this
;
}
VariableBuilder
&
Reuse
(
const
std
::
string
&
name
)
{
var_
->
set_reuse
(
name
);
return
*
this
;
}
};
VariableBuilder
AddInput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
);
...
...
@@ -85,8 +78,6 @@ class OpProtoAndCheckerMaker {
VariableBuilder
AddOutput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
);
void
Reuse
(
const
std
::
string
&
name
,
const
std
::
string
&
reused_name
);
template
<
typename
T
>
TypedAttrChecker
<
T
>
&
AddAttr
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
...
...
@@ -105,8 +96,6 @@ class OpProtoAndCheckerMaker {
void
CheckNoDuplicatedInOutAttrs
();
void
Validate
();
void
CheckReuseVars
();
proto
::
OpProto
*
proto_
;
OpAttrChecker
*
op_checker_
;
bool
validated_
{
false
};
...
...
paddle/fluid/framework/op_proto_maker_test.cc
浏览文件 @
201d4f2a
...
...
@@ -47,120 +47,3 @@ TEST(ProtoMaker, DuplicatedInOut) {
ASSERT_THROW
(
proto_maker
(
&
op_proto
,
&
op_checker
),
paddle
::
platform
::
EnforceNotMet
);
}
class
TestInplaceProtoMaker
:
public
paddle
::
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"input of test op"
);
AddOutput
(
"XOut"
,
"output of test op"
).
Reuse
(
"X"
);
}
};
class
TestInplaceProtoMaker2
:
public
paddle
::
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"input of test op"
);
AddOutput
(
"XOut"
,
"output of test op"
).
Reuse
(
"X"
);
AddOutput
(
"NoOut"
,
"output of test op"
).
Reuse
(
"NotExists"
);
}
};
TEST
(
ProtoMaker
,
InplaceOutput
)
{
paddle
::
framework
::
proto
::
OpProto
op_proto
,
op_proto2
;
paddle
::
framework
::
OpAttrChecker
op_checker
;
TestInplaceProtoMaker
proto_maker
;
TestInplaceProtoMaker2
proto_maker2
;
proto_maker
(
&
op_proto
,
&
op_checker
);
ASSERT_THROW
(
proto_maker2
(
&
op_proto2
,
&
op_checker
),
paddle
::
platform
::
EnforceNotMet
);
}
// normal reuse
class
TestReuseProtoMaker
:
public
paddle
::
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"input of test op"
);
AddInput
(
"Y"
,
"input of test op"
);
AddOutput
(
"Out"
,
"output of test op"
);
AddOutput
(
"XOut"
,
"output of test op"
);
// avoid destructor exception.
// Validate();
TestReuse
();
}
virtual
void
TestReuse
()
{}
};
// test duplicate reuse error
class
TestReuseProtoMaker2
:
public
TestReuseProtoMaker
{
public:
void
TestReuse
()
{
Reuse
(
"Out"
,
"X"
);
Reuse
(
"Out"
,
"Y"
);
}
};
// NotExists Input
class
TestReuseProtoMaker3
:
public
TestReuseProtoMaker
{
public:
void
TestReuse
()
{
Reuse
(
"Out"
,
"NotExists"
);
Reuse
(
"XOut"
,
"X"
);
}
};
// NotExists Output
class
TestReuseProtoMaker4
:
public
TestReuseProtoMaker
{
public:
void
TestReuse
()
{
Reuse
(
"NotExists"
,
"X"
);
}
};
TEST
(
ProtoMaker
,
Reuse
)
{
paddle
::
framework
::
proto
::
OpProto
op_proto
;
paddle
::
framework
::
OpAttrChecker
op_checker
;
TestReuseProtoMaker
proto_maker
;
proto_maker
(
&
op_proto
,
&
op_checker
);
}
// NOTE(dzhwinter):
// There is a Fatal CHECK on base class destructor, which will call abort inside
// instead of
// throw an exception. If we throw an exception in Make(), we will trigger the
// CHECK and terminate the tests.
//
// I had tried to replace the default CHECK with a exception, however, it's
// still not supported by glog.
// the details:
// https://github.com/google/glog/issues/249
// https://github.com/facebookresearch/TensorComprehensions/issues/351
/*
TEST(ProtoMaker, ReuseWithException) {
paddle::framework::proto::OpProto op_proto2, op_proto3, op_proto4;
paddle::framework::OpAttrChecker op_checker;
TestReuseProtoMaker2 proto_maker2;
TestReuseProtoMaker3 proto_maker3;
TestReuseProtoMaker4 proto_maker4;
EXPECT_THROW(proto_maker2(&op_proto2, &op_checker),
paddle::platform::EnforceNotMet);
EXPECT_THROW(proto_maker3(&op_proto3, &op_checker),
paddle::platform::EnforceNotMet);
EXPECT_THROW(proto_maker4(&op_proto4, &op_checker),
paddle::platform::EnforceNotMet);
}
void FailureFunction() {
throw std::runtime_error("Check failed in destructor.");
// return 0;
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
google::InstallFailureFunction(&FailureFunction);
return RUN_ALL_TESTS();
}
*/
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
201d4f2a
...
...
@@ -156,12 +156,10 @@ ParallelExecutor::ParallelExecutor(
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
);
#endif
if
(
VLOG_IS_ON
(
5
))
{
// If the loss_var_name is given, the number of graph should be only one.
if
(
loss_var_name
.
size
())
{
PADDLE_ENFORCE_EQ
(
ir
::
GraphNum
(
*
graph
),
1
,
"The number of graph should be only one"
);
}
// If the loss_var_name is given, the number of graph should be only one.
if
(
loss_var_name
.
size
())
{
PADDLE_ENFORCE_EQ
(
ir
::
GraphNum
(
*
graph
),
1
,
"The number of graph should be only one"
);
}
if
(
exec_strategy
.
type_
==
ExecutionStrategy
::
kDefault
)
{
...
...
paddle/fluid/inference/api/demo_ci/run.sh
浏览文件 @
201d4f2a
...
...
@@ -21,7 +21,7 @@ else
fi
USE_TENSORRT
=
OFF
if
[
[
-d
"
$TENSORRT_INCLUDE_DIR
"
]
-a
[
-d
"
$TENSORRT_LIB_DIR
"
]
]
;
then
if
[
-d
"
$TENSORRT_INCLUDE_DIR
"
-a
-d
"
$TENSORRT_LIB_DIR
"
]
;
then
USE_TENSORRT
=
ON
fi
...
...
paddle/fluid/operators/activation_op.cc
浏览文件 @
201d4f2a
...
...
@@ -28,7 +28,7 @@ using paddle::framework::Tensor;
public: \
void Make() override { \
AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator")
.Reuse("X");
\
AddOutput("Out", "Output of " #OP_NAME " operator")
;
\
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
...
...
paddle/fluid/operators/adam_op.cc
浏览文件 @
201d4f2a
...
...
@@ -92,9 +92,9 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Beta1Pow"
,
"(Tensor) Input beta1 power accumulator"
);
AddInput
(
"Beta2Pow"
,
"(Tensor) Input beta2 power accumulator"
);
AddOutput
(
"ParamOut"
,
"(Tensor) Output parameter"
)
.
Reuse
(
"Param"
)
;
AddOutput
(
"Moment1Out"
,
"(Tensor) Output first moment"
)
.
Reuse
(
"Moment1"
)
;
AddOutput
(
"Moment2Out"
,
"(Tensor) Output second moment"
)
.
Reuse
(
"Moment2"
)
;
AddOutput
(
"ParamOut"
,
"(Tensor) Output parameter"
);
AddOutput
(
"Moment1Out"
,
"(Tensor) Output first moment"
);
AddOutput
(
"Moment2Out"
,
"(Tensor) Output second moment"
);
AddAttr
<
float
>
(
"beta1"
,
"(float, default 0.9) "
...
...
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
201d4f2a
...
...
@@ -135,15 +135,13 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Variance"
,
"The global variance (for training) "
"or estimated Variance (for testing)"
);
AddOutput
(
"Y"
,
"result after normalization"
)
.
Reuse
(
"X"
)
;
AddOutput
(
"Y"
,
"result after normalization"
);
AddOutput
(
"MeanOut"
,
"Share memory with Mean. "
"Store the global mean when training"
)
.
Reuse
(
"Mean"
);
"Store the global mean when training"
);
AddOutput
(
"VarianceOut"
,
"Share memory with Variance. "
"Store the global Variance when training"
)
.
Reuse
(
"Variance"
);
"Store the global Variance when training"
);
AddOutput
(
"SavedMean"
,
"Mean of the current mini batch, "
"will apply to output when training"
)
...
...
paddle/fluid/operators/conv_op.cc
浏览文件 @
201d4f2a
...
...
@@ -130,8 +130,7 @@ void Conv2DOpMaker::Make() {
.
AsDispensable
();
AddOutput
(
"Output"
,
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW."
)
.
Reuse
(
"Input"
);
"The format of output tensor is also NCHW."
);
AddInput
(
"ResidualData"
,
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
...
...
@@ -238,8 +237,7 @@ void Conv3DOpMaker::Make() {
"input image channels divided by the groups."
);
AddOutput
(
"Output"
,
"(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW."
)
.
Reuse
(
"Input"
);
"The format of output tensor is also NCDHW."
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of "
...
...
paddle/fluid/operators/detection/rpn_target_assign_op.cc
浏览文件 @
201d4f2a
...
...
@@ -52,6 +52,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"TargetBBox"
),
"Output(TargetBBox) of RpnTargetAssignOp should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BBoxInsideWeight"
),
"Output(BBoxInsideWeight) of RpnTargetAssignOp should not be null"
);
auto
anchor_dims
=
ctx
->
GetInputDim
(
"Anchor"
);
auto
gt_boxes_dims
=
ctx
->
GetInputDim
(
"GtBoxes"
);
...
...
@@ -68,6 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
"ScoreIndex"
,
{
-
1
});
ctx
->
SetOutputDim
(
"TargetLabel"
,
{
-
1
,
1
});
ctx
->
SetOutputDim
(
"TargetBBox"
,
{
-
1
,
4
});
ctx
->
SetOutputDim
(
"BBoxInsideWeight"
,
{
-
1
,
4
});
}
protected:
...
...
@@ -169,6 +173,7 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
const
float
rpn_positive_overlap
,
const
float
rpn_negative_overlap
,
std
::
vector
<
int
>*
fg_inds
,
std
::
vector
<
int
>*
bg_inds
,
std
::
vector
<
int
>*
tgt_lbl
,
std
::
vector
<
int
>*
fg_fake
,
std
::
vector
<
T
>*
bbox_inside_weight
,
std
::
minstd_rand
engine
,
bool
use_random
)
{
float
epsilon
=
0.00001
;
int
anchor_num
=
anchor_to_gt_max
.
dims
()[
0
];
...
...
@@ -201,12 +206,12 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
// Reservoir Sampling
int
fg_num
=
static_cast
<
int
>
(
rpn_fg_fraction
*
rpn_batch_size_per_im
);
ReservoirSampling
(
fg_num
,
&
fg_inds_fake
,
engine
,
use_random
);
fg
_num
=
static_cast
<
int
>
(
fg_inds_fake
.
size
());
for
(
int64_t
i
=
0
;
i
<
fg_num
;
++
i
)
{
int
fg_fake
_num
=
static_cast
<
int
>
(
fg_inds_fake
.
size
());
for
(
int64_t
i
=
0
;
i
<
fg_
fake_
num
;
++
i
)
{
target_label
[
fg_inds_fake
[
i
]]
=
1
;
}
int
bg_num
=
rpn_batch_size_per_im
-
fg_num
;
int
bg_num
=
rpn_batch_size_per_im
-
fg_
fake_
num
;
for
(
int64_t
i
=
0
;
i
<
anchor_num
;
++
i
)
{
if
(
anchor_to_gt_max_data
[
i
]
<
rpn_negative_overlap
)
{
bg_inds_fake
.
push_back
(
i
);
...
...
@@ -214,12 +219,28 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
}
ReservoirSampling
(
bg_num
,
&
bg_inds_fake
,
engine
,
use_random
);
bg_num
=
static_cast
<
int
>
(
bg_inds_fake
.
size
());
int
fake_num
=
0
;
for
(
int64_t
i
=
0
;
i
<
bg_num
;
++
i
)
{
// fg fake found
if
(
target_label
[
bg_inds_fake
[
i
]]
==
1
)
{
fake_num
++
;
fg_fake
->
emplace_back
(
fg_inds_fake
[
0
]);
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
bbox_inside_weight
->
emplace_back
(
T
(
0.
));
}
}
target_label
[
bg_inds_fake
[
i
]]
=
0
;
}
for
(
int64_t
i
=
0
;
i
<
(
fg_fake_num
-
fake_num
)
*
4
;
++
i
)
{
bbox_inside_weight
->
emplace_back
(
T
(
1.
));
}
for
(
int64_t
i
=
0
;
i
<
anchor_num
;
++
i
)
{
if
(
target_label
[
i
]
==
1
)
fg_inds
->
emplace_back
(
i
);
if
(
target_label
[
i
]
==
1
)
{
fg_inds
->
emplace_back
(
i
);
fg_fake
->
emplace_back
(
i
);
}
if
(
target_label
[
i
]
==
0
)
bg_inds
->
emplace_back
(
i
);
}
fg_num
=
fg_inds
->
size
();
...
...
@@ -248,7 +269,8 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx,
std
::
vector
<
int
>
bg_inds
;
std
::
vector
<
int
>
gt_inds
;
std
::
vector
<
int
>
tgt_lbl
;
std
::
vector
<
int
>
fg_fake
;
std
::
vector
<
T
>
bbox_inside_weight
;
// Calculate the max IoU between anchors and gt boxes
// Map from anchor to gt box that has highest overlap
auto
place
=
ctx
.
GetPlace
();
...
...
@@ -275,32 +297,37 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx,
// Follow the Faster RCNN's implementation
ScoreAssign
(
anchor_by_gt_overlap_data
,
anchor_to_gt_max
,
gt_to_anchor_max
,
rpn_batch_size_per_im
,
rpn_fg_fraction
,
rpn_positive_overlap
,
rpn_negative_overlap
,
&
fg_inds
,
&
bg_inds
,
&
tgt_lbl
,
engin
e
,
use_random
);
rpn_negative_overlap
,
&
fg_inds
,
&
bg_inds
,
&
tgt_lbl
,
&
fg_fak
e
,
&
bbox_inside_weight
,
engine
,
use_random
);
int
fg_num
=
fg_inds
.
size
();
int
bg_num
=
bg_inds
.
size
();
gt_inds
.
reserve
(
fg_num
);
for
(
int
i
=
0
;
i
<
fg_num
;
++
i
)
{
gt_inds
.
emplace_back
(
argmax
[
fg_inds
[
i
]]);
int
fg_fake_num
=
fg_fake
.
size
();
gt_inds
.
reserve
(
fg_fake_num
);
for
(
int
i
=
0
;
i
<
fg_fake_num
;
++
i
)
{
gt_inds
.
emplace_back
(
argmax
[
fg_fake
[
i
]]);
}
Tensor
loc_index_t
,
score_index_t
,
tgt_lbl_t
,
gt_inds_t
;
int
*
loc_index_data
=
loc_index_t
.
mutable_data
<
int
>
({
fg_num
},
place
);
Tensor
loc_index_t
,
score_index_t
,
tgt_lbl_t
,
gt_inds_t
,
bbox_inside_weight_t
;
int
*
loc_index_data
=
loc_index_t
.
mutable_data
<
int
>
({
fg_fake_num
},
place
);
int
*
score_index_data
=
score_index_t
.
mutable_data
<
int
>
({
fg_num
+
bg_num
},
place
);
int
*
tgt_lbl_data
=
tgt_lbl_t
.
mutable_data
<
int
>
({
fg_num
+
bg_num
},
place
);
int
*
gt_inds_data
=
gt_inds_t
.
mutable_data
<
int
>
({
fg_num
},
place
);
std
::
copy
(
fg_inds
.
begin
(),
fg_inds
.
end
(),
loc_index_data
);
int
*
gt_inds_data
=
gt_inds_t
.
mutable_data
<
int
>
({
fg_fake_num
},
place
);
T
*
bbox_inside_weight_data
=
bbox_inside_weight_t
.
mutable_data
<
T
>
({
fg_fake_num
,
4
},
place
);
std
::
copy
(
fg_fake
.
begin
(),
fg_fake
.
end
(),
loc_index_data
);
std
::
copy
(
fg_inds
.
begin
(),
fg_inds
.
end
(),
score_index_data
);
std
::
copy
(
bg_inds
.
begin
(),
bg_inds
.
end
(),
score_index_data
+
fg_num
);
std
::
copy
(
tgt_lbl
.
begin
(),
tgt_lbl
.
end
(),
tgt_lbl_data
);
std
::
copy
(
gt_inds
.
begin
(),
gt_inds
.
end
(),
gt_inds_data
);
std
::
copy
(
bbox_inside_weight
.
begin
(),
bbox_inside_weight
.
end
(),
bbox_inside_weight_data
);
std
::
vector
<
Tensor
>
loc_score_tgtlbl_gt
;
loc_score_tgtlbl_gt
.
emplace_back
(
loc_index_t
);
loc_score_tgtlbl_gt
.
emplace_back
(
score_index_t
);
loc_score_tgtlbl_gt
.
emplace_back
(
tgt_lbl_t
);
loc_score_tgtlbl_gt
.
emplace_back
(
gt_inds_t
);
loc_score_tgtlbl_gt
.
emplace_back
(
bbox_inside_weight_t
);
return
loc_score_tgtlbl_gt
;
}
...
...
@@ -318,6 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
auto
*
score_index
=
context
.
Output
<
LoDTensor
>
(
"ScoreIndex"
);
auto
*
tgt_bbox
=
context
.
Output
<
LoDTensor
>
(
"TargetBBox"
);
auto
*
tgt_lbl
=
context
.
Output
<
LoDTensor
>
(
"TargetLabel"
);
auto
*
bbox_inside_weight
=
context
.
Output
<
LoDTensor
>
(
"BBoxInsideWeight"
);
PADDLE_ENFORCE_EQ
(
gt_boxes
->
lod
().
size
(),
1UL
,
"RpnTargetAssignOp gt_boxes needs 1 level of LoD"
);
...
...
@@ -340,7 +368,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
score_index
->
mutable_data
<
int
>
({
max_num
},
place
);
tgt_bbox
->
mutable_data
<
T
>
({
max_num
,
4
},
place
);
tgt_lbl
->
mutable_data
<
int
>
({
max_num
,
1
},
place
);
bbox_inside_weight
->
mutable_data
<
T
>
({
max_num
,
4
},
place
);
auto
&
dev_ctx
=
context
.
device_context
<
platform
::
CPUDeviceContext
>
();
std
::
random_device
rnd
;
...
...
@@ -394,6 +422,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
Tensor
sampled_score_index
=
loc_score_tgtlbl_gt
[
1
];
Tensor
sampled_tgtlbl
=
loc_score_tgtlbl_gt
[
2
];
Tensor
sampled_gt_index
=
loc_score_tgtlbl_gt
[
3
];
Tensor
sampled_bbox_inside_weight
=
loc_score_tgtlbl_gt
[
4
];
int
loc_num
=
sampled_loc_index
.
dims
()[
0
];
int
score_num
=
sampled_score_index
.
dims
()[
0
];
...
...
@@ -432,6 +461,8 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
AppendRpns
<
int
>
(
score_index
,
total_score_num
,
&
sampled_score_index_unmap
);
AppendRpns
<
T
>
(
tgt_bbox
,
total_loc_num
*
4
,
&
sampled_tgt_bbox
);
AppendRpns
<
int
>
(
tgt_lbl
,
total_score_num
,
&
sampled_tgtlbl
);
AppendRpns
<
T
>
(
bbox_inside_weight
,
total_loc_num
*
4
,
&
sampled_bbox_inside_weight
);
total_loc_num
+=
loc_num
;
total_score_num
+=
score_num
;
...
...
@@ -448,10 +479,12 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
score_index
->
set_lod
(
loc_score
);
tgt_bbox
->
set_lod
(
lod_loc
);
tgt_lbl
->
set_lod
(
loc_score
);
bbox_inside_weight
->
set_lod
(
lod_loc
);
loc_index
->
Resize
({
total_loc_num
});
score_index
->
Resize
({
total_score_num
});
tgt_bbox
->
Resize
({
total_loc_num
,
4
});
tgt_lbl
->
Resize
({
total_score_num
,
1
});
bbox_inside_weight
->
Resize
({
total_loc_num
,
4
});
}
};
...
...
@@ -514,6 +547,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
"TargetLabel"
,
"(Tensor<int>), The target labels of each anchor with shape "
"[F + B, 1], F and B are sampled foreground and backgroud number."
);
AddOutput
(
"BBoxInsideWeight"
,
"(Tensor), The bbox inside weight with shape "
"[F, 4], F is the sampled foreground number."
);
AddComment
(
R"DOC(
This operator can be, for a given set of ground truth bboxes and the
anchors, to assign classification and regression targets to each prediction.
...
...
paddle/fluid/operators/elementwise_op.h
浏览文件 @
201d4f2a
...
...
@@ -80,8 +80,6 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
void
Make
()
final
{
AddInput
(
"X"
,
"(Tensor), The first input tensor of elementwise op."
);
AddInput
(
"Y"
,
"(Tensor), The second input tensor of elementwise op."
);
// AddOutput("SavedShape", "(Tensor), save X, Y shape for grad to save
// memory.").AsIntermediate();
AddOutput
(
"Out"
,
"The output of elementwise op."
);
AddAttr
<
int
>
(
"axis"
,
"(int, default -1). The start dimension index "
...
...
@@ -129,13 +127,11 @@ But the output only shares the LoD information with the input $X$.
)DOC"
,
GetName
(),
GetEquation
()));
SetReuse
();
}
protected:
virtual
std
::
string
GetName
()
const
=
0
;
virtual
std
::
string
GetEquation
()
const
=
0
;
virtual
void
SetReuse
()
{}
};
class
ElementwiseOpGrad
:
public
framework
::
OperatorWithKernel
{
...
...
@@ -269,7 +265,6 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
virtual void SetReuse() { Reuse(__VA_ARGS__); } \
}; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
...
...
paddle/fluid/operators/mean_op.cc
浏览文件 @
201d4f2a
...
...
@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) The input of mean op"
);
AddOutput
(
"Out"
,
"(Tensor) The output of mean op"
)
.
Reuse
(
"X"
)
;
AddOutput
(
"Out"
,
"(Tensor) The output of mean op"
);
AddComment
(
R"DOC(
Mean Operator calculates the mean of all elements in X.
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
201d4f2a
...
...
@@ -151,8 +151,7 @@ void Pool2dOpMaker::Make() {
"The format of output tensor is also NCHW, "
"where N is batch size, C is the number of channels, "
"H is the height of the feature, "
"and W is the width of the feature."
)
.
Reuse
(
"X"
);
"and W is the width of the feature."
);
AddAttr
<
std
::
string
>
(
"pooling_type"
,
"(string), pooling type, can be
\"
max
\"
for max-pooling "
...
...
@@ -252,8 +251,7 @@ void Pool3dOpMaker::Make() {
"The format of output tensor is also NCDHW, "
"where N is batch size, C is "
"the number of channels, and D, H and W is the depth, height and "
"width of the feature, respectively."
)
.
Reuse
(
"X"
);
"width of the feature, respectively."
);
AddAttr
<
std
::
string
>
(
"pooling_type"
,
"(string) Pooling type, can be
\"
max
\"
for max-pooling "
...
...
paddle/fluid/operators/sgd_op.cc
浏览文件 @
201d4f2a
...
...
@@ -77,8 +77,7 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Grad"
,
"(Tensor or SelectedRows) Input gradient"
);
AddOutput
(
"ParamOut"
,
"(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param"
)
.
Reuse
(
"Param"
);
"Output parameter, should share the same memory with Param"
);
AddComment
(
R"DOC(
SGD operator
...
...
paddle/fluid/operators/softmax_op.cc
浏览文件 @
201d4f2a
...
...
@@ -80,8 +80,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"X"
,
"The input tensor of softmax, "
"whose last dimension is the input_feature_dimensions."
);
AddOutput
(
"Out"
,
"The normalized values with the same shape as X."
)
.
Reuse
(
"X"
);
AddOutput
(
"Out"
,
"The normalized values with the same shape as X."
);
AddAttr
<
bool
>
(
"use_cudnn"
,
"(bool, default false) Only used in cudnn kernel, need install cudnn"
)
...
...
paddle/fluid/operators/sum_op.cc
浏览文件 @
201d4f2a
...
...
@@ -132,7 +132,7 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
void
Make
()
override
{
AddInput
(
"X"
,
"(vector<Tensor>) The input tensors of sum operator."
)
.
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) The output tensor of sum operator."
)
.
Reuse
(
"X"
)
;
AddOutput
(
"Out"
,
"(Tensor) The output tensor of sum operator."
);
AddAttr
<
bool
>
(
"use_mkldnn"
,
"(bool, default false) Only used in mkldnn kernel"
)
.
SetDefault
(
false
);
...
...
paddle/fluid/operators/top_k_op.cc
浏览文件 @
201d4f2a
...
...
@@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) The input of Topk op"
);
AddOutput
(
"Out"
,
"(Tensor) The output tensor of Topk op"
)
.
Reuse
(
"X"
)
;
AddOutput
(
"Out"
,
"(Tensor) The output tensor of Topk op"
);
AddOutput
(
"Indices"
,
"(Tensor) The indices of Topk elements of input"
);
AddComment
(
R"DOC(
Top K operator
...
...
paddle/fluid/operators/top_k_op.cu
浏览文件 @
201d4f2a
...
...
@@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
const
T
*
src
,
int
lds
,
int
dim
,
int
k
,
int
grid_dim
,
int
num
)
{
__shared__
Pair
<
T
>
sh_topk
[
BlockSize
];
__shared__
int
maxid
[
BlockSize
/
2
];
const
int
tid
=
threadIdx
.
x
;
const
int
warp
=
threadIdx
.
x
/
32
;
const
int
bid
=
blockIdx
.
x
;
for
(
int
i
=
bid
;
i
<
num
;
i
+=
grid_dim
)
{
output
+=
i
*
output_stride
;
indices
+=
i
*
k
;
int
top_num
=
k
;
__shared__
int
maxid
[
BlockSize
/
2
];
T
*
out
=
output
+
i
*
output_stride
;
int64_t
*
inds
=
indices
+
i
*
k
;
Pair
<
T
>
topk
[
MaxLength
];
int
beam
=
MaxLength
;
Pair
<
T
>
max
;
bool
is_empty
=
false
;
bool
firststep
=
true
;
for
(
int
k
=
0
;
k
<
MaxLength
;
k
++
)
{
topk
[
k
].
set
(
-
INFINITY
,
-
1
);
for
(
int
j
=
0
;
j
<
MaxLength
;
j
++
)
{
topk
[
j
].
set
(
-
INFINITY
,
-
1
);
}
while
(
k
)
{
while
(
top_num
)
{
ThreadGetTopK
<
T
,
MaxLength
,
BlockSize
>
(
topk
,
&
beam
,
k
,
src
+
i
*
lds
,
&
firststep
,
&
is_empty
,
&
max
,
dim
,
tid
);
sh_topk
[
tid
]
=
topk
[
0
];
BlockReduce
<
T
,
MaxLength
,
BlockSize
>
(
sh_topk
,
maxid
,
topk
,
&
out
put
,
&
indices
,
&
beam
,
&
k
,
tid
,
warp
);
BlockReduce
<
T
,
MaxLength
,
BlockSize
>
(
sh_topk
,
maxid
,
topk
,
&
out
,
&
inds
,
&
beam
,
&
top_num
,
tid
,
warp
);
}
}
}
...
...
@@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
size_t
k
=
static_cast
<
int
>
(
ctx
.
Attr
<
int
>
(
"k"
));
const
T
*
input_data
=
input
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// FIXME(typhoonzero): data is always converted to type T?
int64_t
*
indices_data
=
indices
->
mutable_data
<
int64_t
>
(
ctx
.
GetPlace
());
size_t
input_height
=
input
->
dims
()[
0
];
size_t
input_width
=
input
->
dims
()[
1
];
framework
::
DDim
inputdims
=
input
->
dims
();
const
size_t
input_height
=
framework
::
product
(
framework
::
slice_ddim
(
inputdims
,
0
,
inputdims
.
size
()
-
1
));
const
size_t
input_width
=
inputdims
[
inputdims
.
size
()
-
1
];
if
(
k
>
input_width
)
k
=
input_width
;
// NOTE: pass lds and dim same to input width.
...
...
@@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
const
int
kMaxHeight
=
2048
;
int
gridx
=
input_height
<
kMaxHeight
?
input_height
:
kMaxHeight
;
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
switch
(
GetDesiredBlockDim
(
input_width
))
{
FIXED_BLOCK_DIM
(
KeMatrixTopK
<
T
,
5
,
kBlockDim
><<<
gridx
,
kBlockDim
,
0
,
dev_ctx
.
stream
()
>>>
(
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
input_width
,
static_cast
<
int
>
(
k
),
gridx
,
input_height
));
output_data
,
k
,
indices_data
,
input_data
,
input_width
,
input_width
,
static_cast
<
int
>
(
k
),
gridx
,
input_height
));
default:
PADDLE_THROW
(
"Error"
);
}
...
...
paddle/fluid/operators/top_k_op.h
浏览文件 @
201d4f2a
...
...
@@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
// Get the top k elements of each row of input tensor
// FIXME: only deal with matrix(2d tensor).
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
indices
=
ctx
.
Output
<
Tensor
>
(
"Indices"
);
...
...
@@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel<T> {
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
*
indices_data
=
indices
->
mutable_data
<
int64_t
>
(
ctx
.
GetPlace
());
auto
eg_input
=
EigenMatrix
<
T
>::
From
(
*
input
);
// reshape input to a flattern matrix(like flat_inner_dims)
framework
::
DDim
inputdims
=
input
->
dims
();
const
size_t
row
=
framework
::
product
(
...
...
@@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel<T> {
const
size_t
col
=
inputdims
[
inputdims
.
size
()
-
1
];
Eigen
::
DSizes
<
int
,
2
>
flat2dims
(
row
,
col
);
// NOTE: eigen shape doesn't affect paddle tensor.
eg_input
.
reshape
(
flat2dims
);
auto
eg_input
=
EigenMatrix
<
T
>::
Reshape
(
*
input
,
inputdims
.
size
()
-
1
);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
201d4f2a
...
...
@@ -116,8 +116,8 @@ def rpn_target_assign(bbox_pred,
Returns:
tuple:
A tuple(predicted_scores, predicted_location, target_label,
target_bbox
) is returned. The predicted_scores and
predicted_location is the predicted result of the RPN.
target_bbox
, bbox_inside_weight) is returned. The predicted_scores
and
predicted_location is the predicted result of the RPN.
The target_label and target_bbox is the ground truth,
respectively. The predicted_location is a 2D Tensor with shape
[F, 4], and the shape of target_bbox is same as the shape of
...
...
@@ -126,6 +126,8 @@ def rpn_target_assign(bbox_pred,
[F + B, 1], and the shape of target_label is same as the shape
of the predicted_scores, B is the number of the background
anchors, the F and B is depends on the input of this operator.
Bbox_inside_weight represents whether the predicted loc is fake_fg
or not and the shape is [F, 4].
Examples:
.. code-block:: python
...
...
@@ -138,7 +140,7 @@ def rpn_target_assign(bbox_pred,
append_batch_size=False, dtype='float32')
gt_boxes = layers.data(name='gt_boxes', shape=[10, 4],
append_batch_size=False, dtype='float32')
loc_pred, score_pred, loc_target, score_target =
loc_pred, score_pred, loc_target, score_target
, bbox_inside_weight
=
fluid.layers.rpn_target_assign(bbox_pred=bbox_pred,
cls_logits=cls_logits,
anchor_box=anchor_box,
...
...
@@ -152,6 +154,8 @@ def rpn_target_assign(bbox_pred,
target_label
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int32'
)
target_bbox
=
helper
.
create_variable_for_type_inference
(
dtype
=
anchor_box
.
dtype
)
bbox_inside_weight
=
helper
.
create_variable_for_type_inference
(
dtype
=
anchor_box
.
dtype
)
helper
.
append_op
(
type
=
"rpn_target_assign"
,
inputs
=
{
...
...
@@ -164,7 +168,8 @@ def rpn_target_assign(bbox_pred,
'LocationIndex'
:
loc_index
,
'ScoreIndex'
:
score_index
,
'TargetLabel'
:
target_label
,
'TargetBBox'
:
target_bbox
'TargetBBox'
:
target_bbox
,
'BBoxInsideWeight'
:
bbox_inside_weight
},
attrs
=
{
'rpn_batch_size_per_im'
:
rpn_batch_size_per_im
,
...
...
@@ -179,13 +184,14 @@ def rpn_target_assign(bbox_pred,
score_index
.
stop_gradient
=
True
target_label
.
stop_gradient
=
True
target_bbox
.
stop_gradient
=
True
bbox_inside_weight
.
stop_gradient
=
True
cls_logits
=
nn
.
reshape
(
x
=
cls_logits
,
shape
=
(
-
1
,
1
))
bbox_pred
=
nn
.
reshape
(
x
=
bbox_pred
,
shape
=
(
-
1
,
4
))
predicted_cls_logits
=
nn
.
gather
(
cls_logits
,
score_index
)
predicted_bbox_pred
=
nn
.
gather
(
bbox_pred
,
loc_index
)
return
predicted_cls_logits
,
predicted_bbox_pred
,
target_label
,
target_bbox
return
predicted_cls_logits
,
predicted_bbox_pred
,
target_label
,
target_bbox
,
bbox_inside_weight
def
detection_output
(
loc
,
...
...
python/paddle/fluid/tests/test_detection.py
浏览文件 @
201d4f2a
...
...
@@ -301,7 +301,7 @@ class TestRpnTargetAssign(unittest.TestCase):
dtype
=
'float32'
,
lod_level
=
1
,
append_batch_size
=
False
)
pred_scores
,
pred_loc
,
tgt_lbl
,
tgt_bbox
=
layers
.
rpn_target_assign
(
pred_scores
,
pred_loc
,
tgt_lbl
,
tgt_bbox
,
bbox_inside_weight
=
layers
.
rpn_target_assign
(
bbox_pred
=
bbox_pred
,
cls_logits
=
cls_logits
,
anchor_box
=
anchor_box
,
...
...
@@ -313,15 +313,18 @@ class TestRpnTargetAssign(unittest.TestCase):
rpn_straddle_thresh
=
0.0
,
rpn_fg_fraction
=
0.5
,
rpn_positive_overlap
=
0.7
,
rpn_negative_overlap
=
0.3
)
rpn_negative_overlap
=
0.3
,
use_random
=
False
)
self
.
assertIsNotNone
(
pred_scores
)
self
.
assertIsNotNone
(
pred_loc
)
self
.
assertIsNotNone
(
tgt_lbl
)
self
.
assertIsNotNone
(
tgt_bbox
)
self
.
assertIsNotNone
(
bbox_inside_weight
)
assert
pred_scores
.
shape
[
1
]
==
1
assert
pred_loc
.
shape
[
1
]
==
4
assert
pred_loc
.
shape
[
1
]
==
tgt_bbox
.
shape
[
1
]
print
(
str
(
program
))
class
TestGenerateProposals
(
unittest
.
TestCase
):
...
...
python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py
浏览文件 @
201d4f2a
...
...
@@ -50,8 +50,10 @@ def rpn_target_assign(anchor_by_gt_overlap,
fg_inds
,
size
=
(
len
(
fg_inds
)
-
num_fg
),
replace
=
False
)
else
:
disable_inds
=
fg_inds
[
num_fg
:]
labels
[
disable_inds
]
=
-
1
fg_inds
=
np
.
where
(
labels
==
1
)[
0
]
bbox_inside_weight
=
np
.
zeros
((
len
(
fg_inds
),
4
),
dtype
=
np
.
float32
)
num_bg
=
rpn_batch_size_per_im
-
np
.
sum
(
labels
==
1
)
bg_inds
=
np
.
where
(
anchor_to_gt_max
<
rpn_negative_overlap
)[
0
]
...
...
@@ -59,18 +61,27 @@ def rpn_target_assign(anchor_by_gt_overlap,
enable_inds
=
bg_inds
[
np
.
random
.
randint
(
len
(
bg_inds
),
size
=
num_bg
)]
else
:
enable_inds
=
bg_inds
[:
num_bg
]
fg_fake_inds
=
np
.
array
([],
np
.
int32
)
fg_value
=
np
.
array
([
fg_inds
[
0
]],
np
.
int32
)
fake_num
=
0
for
bg_id
in
enable_inds
:
if
bg_id
in
fg_inds
:
fake_num
+=
1
fg_fake_inds
=
np
.
hstack
([
fg_fake_inds
,
fg_value
])
labels
[
enable_inds
]
=
0
bbox_inside_weight
[
fake_num
:,
:]
=
1
fg_inds
=
np
.
where
(
labels
==
1
)[
0
]
bg_inds
=
np
.
where
(
labels
==
0
)[
0
]
loc_index
=
fg_inds
score_index
=
np
.
hstack
((
fg_inds
,
bg_inds
))
loc_index
=
np
.
hstack
([
fg_fake_inds
,
fg_inds
])
score_index
=
np
.
hstack
([
fg_inds
,
bg_inds
])
labels
=
labels
[
score_index
]
assert
not
np
.
any
(
labels
==
-
1
),
"Wrong labels with -1"
gt_inds
=
anchor_to_gt_argmax
[
fg_inds
]
gt_inds
=
anchor_to_gt_argmax
[
loc_index
]
return
loc_index
,
score_index
,
labels
,
gt_inds
return
loc_index
,
score_index
,
labels
,
gt_inds
,
bbox_inside_weight
def
get_anchor
(
n
,
c
,
h
,
w
):
...
...
@@ -123,9 +134,12 @@ def rpn_target_assign_in_python(all_anchors,
gt_boxes_slice
=
gt_boxes_slice
[
not_crowd_inds
]
iou
=
_bbox_overlaps
(
inside_anchors
,
gt_boxes_slice
)
loc_inds
,
score_inds
,
labels
,
gt_inds
=
rpn_target_assign
(
iou
,
rpn_batch_size_per_im
,
rpn_positive_overlap
,
rpn_negative_overlap
,
rpn_fg_fraction
,
use_random
)
loc_inds
,
score_inds
,
labels
,
gt_inds
,
bbox_inside_weight
=
\
rpn_target_assign
(
iou
,
rpn_batch_size_per_im
,
rpn_positive_overlap
,
rpn_negative_overlap
,
rpn_fg_fraction
,
use_random
)
# unmap to all anchor
loc_inds
=
inds_inside
[
loc_inds
]
score_inds
=
inds_inside
[
score_inds
]
...
...
@@ -139,6 +153,7 @@ def rpn_target_assign_in_python(all_anchors,
score_indexes
=
score_inds
tgt_labels
=
labels
tgt_bboxes
=
box_deltas
bbox_inside_weights
=
bbox_inside_weight
else
:
loc_indexes
=
np
.
concatenate
(
[
loc_indexes
,
loc_inds
+
i
*
anchor_num
])
...
...
@@ -146,8 +161,10 @@ def rpn_target_assign_in_python(all_anchors,
[
score_indexes
,
score_inds
+
i
*
anchor_num
])
tgt_labels
=
np
.
concatenate
([
tgt_labels
,
labels
])
tgt_bboxes
=
np
.
vstack
([
tgt_bboxes
,
box_deltas
])
bbox_inside_weights
=
np
.
vstack
([
bbox_inside_weights
,
\
bbox_inside_weight
])
return
loc_indexes
,
score_indexes
,
tgt_bboxes
,
tgt_labels
return
loc_indexes
,
score_indexes
,
tgt_bboxes
,
tgt_labels
,
bbox_inside_weights
class
TestRpnTargetAssignOp
(
OpTest
):
...
...
@@ -182,10 +199,12 @@ class TestRpnTargetAssignOp(OpTest):
rpn_fg_fraction
=
0.5
use_random
=
False
loc_index
,
score_index
,
tgt_bbox
,
labels
=
rpn_target_assign_in_python
(
all_anchors
,
gt_boxes
,
is_crowd
,
im_info
,
lod
,
rpn_straddle_thresh
,
rpn_batch_size_per_im
,
rpn_positive_overlap
,
rpn_negative_overlap
,
rpn_fg_fraction
,
use_random
)
loc_index
,
score_index
,
tgt_bbox
,
labels
,
bbox_inside_weights
=
\
rpn_target_assign_in_python
(
all_anchors
,
gt_boxes
,
is_crowd
,
im_info
,
lod
,
rpn_straddle_thresh
,
rpn_batch_size_per_im
,
rpn_positive_overlap
,
rpn_negative_overlap
,
rpn_fg_fraction
,
use_random
)
labels
=
labels
[:,
np
.
newaxis
]
self
.
op_type
=
"rpn_target_assign"
...
...
@@ -207,7 +226,8 @@ class TestRpnTargetAssignOp(OpTest):
'LocationIndex'
:
loc_index
.
astype
(
'int32'
),
'ScoreIndex'
:
score_index
.
astype
(
'int32'
),
'TargetBBox'
:
tgt_bbox
.
astype
(
'float32'
),
'TargetLabel'
:
labels
.
astype
(
'int32'
)
'TargetLabel'
:
labels
.
astype
(
'int32'
),
'BBoxInsideWeight'
:
bbox_inside_weights
.
astype
(
'float32'
)
}
def
test_check_output
(
self
):
...
...
python/paddle/fluid/tests/unittests/test_top_k_op.py
浏览文件 @
201d4f2a
...
...
@@ -21,22 +21,27 @@ from op_test import OpTest
class
TestTopkOp
(
OpTest
):
def
setUp
(
self
):
self
.
set_args
()
self
.
op_type
=
"top_k"
k
=
1
input
=
np
.
random
.
random
((
32
,
84
)).
astype
(
"float32"
)
output
=
np
.
ndarray
((
32
,
k
))
indices
=
np
.
ndarray
((
32
,
k
)).
astype
(
"int64"
)
k
=
self
.
top_k
input
=
np
.
random
.
random
((
self
.
row
,
k
)).
astype
(
"float32"
)
output
=
np
.
ndarray
((
self
.
row
,
k
))
indices
=
np
.
ndarray
((
self
.
row
,
k
)).
astype
(
"int64"
)
self
.
inputs
=
{
'X'
:
input
}
self
.
attrs
=
{
'k'
:
k
}
for
rowid
in
range
(
32
):
for
rowid
in
range
(
self
.
row
):
row
=
input
[
rowid
]
output
[
rowid
]
=
np
.
sort
(
row
)[
-
k
:
]
indices
[
rowid
]
=
row
.
argsort
()[
-
k
:
]
output
[
rowid
]
=
np
.
sort
(
row
)[
::
-
1
][:
k
]
indices
[
rowid
]
=
row
.
argsort
()[
::
-
1
][:
k
]
self
.
outputs
=
{
'Out'
:
output
,
'Indices'
:
indices
}
def
set_args
(
self
):
self
.
row
=
32
self
.
top_k
=
1
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -50,14 +55,39 @@ class TestTopkOp3d(OpTest):
output
=
np
.
ndarray
((
64
,
k
))
indices
=
np
.
ndarray
((
64
,
k
)).
astype
(
"int64"
)
# FIXME: should use 'X': input for a 3d input
self
.
inputs
=
{
'X'
:
input_flat_2d
}
self
.
inputs
=
{
'X'
:
input
}
self
.
attrs
=
{
'k'
:
k
}
for
rowid
in
range
(
64
):
row
=
input_flat_2d
[
rowid
]
output
[
rowid
]
=
np
.
sort
(
row
)[
-
k
:]
indices
[
rowid
]
=
row
.
argsort
()[
-
k
:]
output
[
rowid
]
=
np
.
sort
(
row
)[::
-
1
][:
k
]
indices
[
rowid
]
=
row
.
argsort
()[::
-
1
][:
k
]
self
.
outputs
=
{
'Out'
:
output
.
reshape
((
32
,
2
,
k
)),
'Indices'
:
indices
.
reshape
((
32
,
2
,
k
))
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestTopkOp2
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"top_k"
k
=
1
m
=
2056
input
=
np
.
random
.
random
((
m
,
84
)).
astype
(
"float32"
)
output
=
np
.
ndarray
((
m
,
k
))
indices
=
np
.
ndarray
((
m
,
k
)).
astype
(
"int64"
)
self
.
inputs
=
{
'X'
:
input
}
self
.
attrs
=
{
'k'
:
k
}
for
rowid
in
range
(
m
):
row
=
input
[
rowid
]
output
[
rowid
]
=
-
np
.
sort
(
-
row
)[:
k
]
indices
[
rowid
]
=
(
-
row
).
argsort
()[:
k
]
self
.
outputs
=
{
'Out'
:
output
,
'Indices'
:
indices
}
...
...
@@ -65,5 +95,17 @@ class TestTopkOp3d(OpTest):
self
.
check_output
()
class
TestTopkOp3
(
TestTopkOp
):
def
set_args
(
self
):
self
.
row
=
2056
self
.
top_k
=
3
class
TestTopkOp4
(
TestTopkOp
):
def
set_args
(
self
):
self
.
row
=
40000
self
.
top_k
=
1
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录