Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
479ad4bb
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
479ad4bb
编写于
9月 27, 2018
作者:
Q
qingqing01
提交者:
GitHub
9月 27, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into quantize_transpiler_update
上级
bd0a9fb7
4e81e228
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
595 addition
and
10 deletion
+595
-10
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+16
-0
paddle/fluid/inference/api/paddle_inference_api.h
paddle/fluid/inference/api/paddle_inference_api.h
+3
-2
paddle/fluid/operators/detection/CMakeLists.txt
paddle/fluid/operators/detection/CMakeLists.txt
+7
-1
paddle/fluid/operators/detection/generate_proposals_op.cc
paddle/fluid/operators/detection/generate_proposals_op.cc
+4
-3
paddle/fluid/operators/detection/generate_proposals_op.cu
paddle/fluid/operators/detection/generate_proposals_op.cu
+449
-0
python/paddle/fluid/contrib/__init__.py
python/paddle/fluid/contrib/__init__.py
+7
-2
python/paddle/fluid/contrib/op_frequence.py
python/paddle/fluid/contrib/op_frequence.py
+104
-0
python/paddle/fluid/tests/unittests/test_generate_proposals_op.py
...addle/fluid/tests/unittests/test_generate_proposals_op.py
+1
-2
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+3
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
479ad4bb
...
...
@@ -302,6 +302,7 @@ paddle.fluid.contrib.QuantizeTranspiler.__init__ ArgSpec(args=['self', 'weight_b
paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.QuantizeTranspiler.freeze_program ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.contrib.QuantizeTranspiler.training_transpile ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.contrib.op_freq_statistic ArgSpec(args=['program'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
...
...
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
浏览文件 @
479ad4bb
...
...
@@ -257,6 +257,22 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
PDPattern
external_pattern
,
subblock_pattern
;
// Use the following variables to tell whether this model is RNN1.
// This fuse can only works on the RNN1 model.
std
::
unordered_set
<
std
::
string
>
specified_vars
({
"data_lod_attention"
,
"cell_init"
,
"hidden_init"
,
"data"
,
"week"
,
"minute"
});
int
count
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsVar
()
&&
specified_vars
.
count
(
node
->
Name
()))
{
++
count
;
}
}
if
(
count
<
specified_vars
.
size
())
{
return
graph
;
}
// Continue to fuse.
FindWhileOp
(
graph
.
get
());
return
graph
;
}
...
...
paddle/fluid/inference/api/paddle_inference_api.h
浏览文件 @
479ad4bb
...
...
@@ -212,10 +212,11 @@ struct AnalysisConfig : public NativeConfig {
kExclude
// Specify the disabled passes in `ir_passes`.
};
// Determine whether to perform graph optimization.
bool
enable_ir_optim
=
true
;
// Manually determine the IR passes to run.
IrPassMode
ir_mode
{
IrPassMode
::
kExclude
};
// attention lstm fuse works only on some specific models, disable as default.
std
::
vector
<
std
::
string
>
ir_passes
{
"attention_lstm_fuse_pass"
};
std
::
vector
<
std
::
string
>
ir_passes
;
// NOTE this is just for internal development, please not use it.
bool
_use_mkldnn
{
false
};
...
...
paddle/fluid/operators/detection/CMakeLists.txt
浏览文件 @
479ad4bb
...
...
@@ -30,7 +30,13 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
polygon_box_transform_op.cu
)
detection_library
(
rpn_target_assign_op SRCS rpn_target_assign_op.cc
)
detection_library
(
generate_proposal_labels_op SRCS generate_proposal_labels_op.cc
)
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc
)
if
(
WITH_GPU
)
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub
)
else
()
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc
)
endif
()
detection_library
(
roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu
)
#Export local libraries to parent
set
(
DETECTION_LIBRARY
${
LOCAL_DETECTION_LIBS
}
PARENT_SCOPE
)
paddle/fluid/operators/detection/generate_proposals_op.cc
浏览文件 @
479ad4bb
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
...
...
@@ -69,7 +70,7 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Anchors"
)
->
type
()),
platform
::
CPUPlace
());
ctx
.
device_context
());
}
};
...
...
@@ -162,7 +163,7 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
const
T
*
im_info_data
=
im_info
.
data
<
T
>
();
T
*
boxes_data
=
boxes
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
im_scale
=
im_info_data
[
2
];
keep
->
Resize
({
boxes
->
dims
()[
0
]
,
1
});
keep
->
Resize
({
boxes
->
dims
()[
0
]});
min_size
=
std
::
max
(
min_size
,
1.0
f
);
int
*
keep_data
=
keep
->
mutable_data
<
int
>
(
ctx
.
GetPlace
());
...
...
@@ -463,7 +464,7 @@ class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
int
>
(
"post_nms_topN"
,
"post_nms_topN"
);
AddAttr
<
float
>
(
"nms_thresh"
,
"nms_thres"
);
AddAttr
<
float
>
(
"min_size"
,
"min size"
);
AddAttr
<
float
>
(
"eta"
,
"
eta
"
);
AddAttr
<
float
>
(
"eta"
,
"
The parameter for adaptive NMS.
"
);
AddComment
(
R"DOC(
Generate Proposals OP
...
...
paddle/fluid/operators/detection/generate_proposals_op.cu
0 → 100644
浏览文件 @
479ad4bb
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
namespace
{
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
int
const
kThreadsPerBlock
=
sizeof
(
uint64_t
)
*
8
;
template
<
typename
T
>
__global__
void
RangeInitKernel
(
const
T
start
,
const
T
delta
,
const
int
size
,
T
*
out
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
size
)
{
out
[
i
]
=
start
+
i
*
delta
;
}
}
template
<
typename
T
>
void
SortDescending
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
value
,
Tensor
*
value_out
,
Tensor
*
index_out
)
{
int
num
=
value
.
numel
();
Tensor
index_in_t
;
int
*
idx_in
=
index_in_t
.
mutable_data
<
int
>
({
num
},
ctx
.
GetPlace
());
int
block
=
512
;
auto
stream
=
ctx
.
stream
();
RangeInitKernel
<<<
DIVUP
(
num
,
block
),
block
,
0
,
stream
>>>
(
0
,
1
,
num
,
idx_in
);
int
*
idx_out
=
index_out
->
mutable_data
<
int
>
({
num
},
ctx
.
GetPlace
());
const
T
*
keys_in
=
value
.
data
<
T
>
();
T
*
keys_out
=
value_out
->
mutable_data
<
T
>
({
num
},
ctx
.
GetPlace
());
// Determine temporary device storage requirements
void
*
d_temp_storage
=
NULL
;
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
d_temp_storage
,
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
// Allocate temporary storage
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
);
// Run sorting operation
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
d_temp_storage
,
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
memory
::
Free
(
place
,
d_temp_storage
);
}
template
<
typename
T
>
__device__
__forceinline__
T
Min
(
T
x
,
T
y
)
{
return
x
<
y
?
x
:
y
;
}
template
<
typename
T
>
__device__
__forceinline__
T
Max
(
T
x
,
T
y
)
{
return
x
>
y
?
x
:
y
;
}
template
<
typename
T
>
__global__
void
BoxDecodeAndClipKernel
(
const
T
*
anchor
,
const
T
*
deltas
,
const
T
*
var
,
const
int
*
index
,
const
T
*
im_info
,
const
int
num
,
T
*
proposals
)
{
T
kBBoxClipDefault
=
log
(
1000.0
/
16.0
);
CUDA_1D_KERNEL_LOOP
(
i
,
num
)
{
int
k
=
index
[
i
]
*
4
;
T
axmin
=
anchor
[
k
];
T
aymin
=
anchor
[
k
+
1
];
T
axmax
=
anchor
[
k
+
2
];
T
aymax
=
anchor
[
k
+
3
];
T
w
=
axmax
-
axmin
+
1.0
;
T
h
=
aymax
-
aymin
+
1.0
;
T
cx
=
axmin
+
0.5
*
w
;
T
cy
=
aymin
+
0.5
*
h
;
T
dxmin
=
deltas
[
k
];
T
dymin
=
deltas
[
k
+
1
];
T
dxmax
=
deltas
[
k
+
2
];
T
dymax
=
deltas
[
k
+
3
];
T
d_cx
=
0.
,
d_cy
=
0.
,
d_w
=
0.
,
d_h
=
0.
;
if
(
var
)
{
d_cx
=
cx
+
dxmin
*
w
*
var
[
k
];
d_cy
=
cy
+
dymin
*
h
*
var
[
k
+
1
];
d_w
=
exp
(
Min
<
T
>
(
dxmax
*
var
[
k
+
2
],
kBBoxClipDefault
))
*
w
;
d_h
=
exp
(
Min
<
T
>
(
dymax
*
var
[
k
+
3
],
kBBoxClipDefault
))
*
h
;
}
else
{
d_cx
=
cx
+
dxmin
*
w
;
d_cy
=
cy
+
dymin
*
h
;
d_w
=
exp
(
Min
<
T
>
(
dxmax
,
kBBoxClipDefault
))
*
w
;
d_h
=
exp
(
Min
<
T
>
(
dymax
,
kBBoxClipDefault
))
*
h
;
}
T
oxmin
=
d_cx
-
d_w
*
0.5
;
T
oymin
=
d_cy
-
d_h
*
0.5
;
T
oxmax
=
d_cx
+
d_w
*
0.5
-
1.
;
T
oymax
=
d_cy
+
d_h
*
0.5
-
1.
;
proposals
[
i
*
4
]
=
Max
<
T
>
(
Min
<
T
>
(
oxmin
,
im_info
[
1
]
-
1.
),
0.
);
proposals
[
i
*
4
+
1
]
=
Max
<
T
>
(
Min
<
T
>
(
oymin
,
im_info
[
0
]
-
1.
),
0.
);
proposals
[
i
*
4
+
2
]
=
Max
<
T
>
(
Min
<
T
>
(
oxmax
,
im_info
[
1
]
-
1.
),
0.
);
proposals
[
i
*
4
+
3
]
=
Max
<
T
>
(
Min
<
T
>
(
oymax
,
im_info
[
0
]
-
1.
),
0.
);
}
}
template
<
typename
T
,
int
BlockSize
>
__global__
void
FilterBBoxes
(
const
T
*
bboxes
,
const
T
*
im_info
,
const
T
min_size
,
const
int
num
,
int
*
keep_num
,
int
*
keep
)
{
T
im_h
=
im_info
[
0
];
T
im_w
=
im_info
[
1
];
T
im_scale
=
im_info
[
2
];
int
cnt
=
0
;
__shared__
int
keep_index
[
BlockSize
];
CUDA_1D_KERNEL_LOOP
(
i
,
num
)
{
keep_index
[
threadIdx
.
x
]
=
-
1
;
__syncthreads
();
int
k
=
i
*
4
;
T
xmin
=
bboxes
[
k
];
T
ymin
=
bboxes
[
k
+
1
];
T
xmax
=
bboxes
[
k
+
2
];
T
ymax
=
bboxes
[
k
+
3
];
T
w
=
xmax
-
xmin
+
1.0
;
T
h
=
ymax
-
ymin
+
1.0
;
T
cx
=
xmin
+
w
/
2.
;
T
cy
=
ymin
+
h
/
2.
;
T
w_s
=
(
xmax
-
xmin
)
/
im_scale
+
1.
;
T
h_s
=
(
ymax
-
ymin
)
/
im_scale
+
1.
;
if
(
w_s
>=
min_size
&&
h_s
>=
min_size
&&
cx
<=
im_w
&&
cy
<=
im_h
)
{
keep_index
[
threadIdx
.
x
]
=
i
;
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
int
size
=
(
num
-
i
)
<
BlockSize
?
num
-
i
:
BlockSize
;
for
(
int
j
=
0
;
j
<
size
;
++
j
)
{
if
(
keep_index
[
j
]
>
-
1
)
{
keep
[
cnt
++
]
=
keep_index
[
j
];
}
}
}
__syncthreads
();
}
if
(
threadIdx
.
x
==
0
)
{
keep_num
[
0
]
=
cnt
;
}
}
__device__
inline
float
IoU
(
const
float
*
a
,
const
float
*
b
)
{
float
left
=
max
(
a
[
0
],
b
[
0
]),
right
=
min
(
a
[
2
],
b
[
2
]);
float
top
=
max
(
a
[
1
],
b
[
1
]),
bottom
=
min
(
a
[
3
],
b
[
3
]);
float
width
=
max
(
right
-
left
+
1
,
0.
f
),
height
=
max
(
bottom
-
top
+
1
,
0.
f
);
float
inter_s
=
width
*
height
;
float
s_a
=
(
a
[
2
]
-
a
[
0
]
+
1
)
*
(
a
[
3
]
-
a
[
1
]
+
1
);
float
s_b
=
(
b
[
2
]
-
b
[
0
]
+
1
)
*
(
b
[
3
]
-
b
[
1
]
+
1
);
return
inter_s
/
(
s_a
+
s_b
-
inter_s
);
}
__global__
void
NMSKernel
(
const
int
n_boxes
,
const
float
nms_overlap_thresh
,
const
float
*
dev_boxes
,
uint64_t
*
dev_mask
)
{
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
const
int
row_size
=
min
(
n_boxes
-
row_start
*
kThreadsPerBlock
,
kThreadsPerBlock
);
const
int
col_size
=
min
(
n_boxes
-
col_start
*
kThreadsPerBlock
,
kThreadsPerBlock
);
__shared__
float
block_boxes
[
kThreadsPerBlock
*
4
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
4
+
0
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
0
];
block_boxes
[
threadIdx
.
x
*
4
+
1
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
1
];
block_boxes
[
threadIdx
.
x
*
4
+
2
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
2
];
block_boxes
[
threadIdx
.
x
*
4
+
3
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
3
];
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
kThreadsPerBlock
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
dev_boxes
+
cur_box_idx
*
4
;
int
i
=
0
;
uint64_t
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
IoU
(
cur_box
,
block_boxes
+
i
*
4
)
>
nms_overlap_thresh
)
{
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
n_boxes
,
kThreadsPerBlock
);
dev_mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
template
<
typename
T
>
void
NMS
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
proposals
,
const
Tensor
&
sorted_indices
,
const
T
nms_threshold
,
Tensor
*
keep_out
)
{
int
boxes_num
=
proposals
.
dims
()[
0
];
PADDLE_ENFORCE_EQ
(
boxes_num
,
sorted_indices
.
dims
()[
0
]);
const
int
col_blocks
=
DIVUP
(
boxes_num
,
kThreadsPerBlock
);
dim3
blocks
(
DIVUP
(
boxes_num
,
kThreadsPerBlock
),
DIVUP
(
boxes_num
,
kThreadsPerBlock
));
dim3
threads
(
kThreadsPerBlock
);
const
T
*
boxes
=
proposals
.
data
<
T
>
();
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
int
size_bytes
=
boxes_num
*
col_blocks
*
sizeof
(
uint64_t
);
uint64_t
*
d_mask
=
reinterpret_cast
<
uint64_t
*>
(
memory
::
Alloc
(
place
,
size_bytes
));
NMSKernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_threshold
,
boxes
,
d_mask
);
uint64_t
*
h_mask
=
reinterpret_cast
<
uint64_t
*>
(
memory
::
Alloc
(
platform
::
CPUPlace
(),
size_bytes
));
memory
::
Copy
(
platform
::
CPUPlace
(),
h_mask
,
place
,
d_mask
,
size_bytes
,
0
);
std
::
vector
<
uint64_t
>
remv
(
col_blocks
);
memset
(
&
remv
[
0
],
0
,
sizeof
(
uint64_t
)
*
col_blocks
);
std
::
vector
<
int
>
keep_vec
;
int
num_to_keep
=
0
;
for
(
int
i
=
0
;
i
<
boxes_num
;
i
++
)
{
int
nblock
=
i
/
kThreadsPerBlock
;
int
inblock
=
i
%
kThreadsPerBlock
;
if
(
!
(
remv
[
nblock
]
&
(
1ULL
<<
inblock
)))
{
++
num_to_keep
;
keep_vec
.
push_back
(
i
);
uint64_t
*
p
=
&
h_mask
[
0
]
+
i
*
col_blocks
;
for
(
int
j
=
nblock
;
j
<
col_blocks
;
j
++
)
{
remv
[
j
]
|=
p
[
j
];
}
}
}
int
*
keep
=
keep_out
->
mutable_data
<
int
>
({
num_to_keep
},
ctx
.
GetPlace
());
memory
::
Copy
(
place
,
keep
,
platform
::
CPUPlace
(),
keep_vec
.
data
(),
sizeof
(
int
)
*
num_to_keep
,
0
);
memory
::
Free
(
place
,
d_mask
);
memory
::
Free
(
platform
::
CPUPlace
(),
h_mask
);
}
template
<
typename
T
>
std
::
pair
<
Tensor
,
Tensor
>
ProposalForOneImage
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
im_info
,
const
Tensor
&
anchors
,
const
Tensor
&
variances
,
const
Tensor
&
bbox_deltas
,
// [M, 4]
const
Tensor
&
scores
,
// [N, 1]
int
pre_nms_top_n
,
int
post_nms_top_n
,
float
nms_thresh
,
float
min_size
,
float
eta
)
{
// 1. pre nms
Tensor
scores_sort
,
index_sort
;
SortDescending
<
T
>
(
ctx
,
scores
,
&
scores_sort
,
&
index_sort
);
int
num
=
scores
.
numel
();
int
pre_nms_num
=
(
pre_nms_top_n
<=
0
||
pre_nms_top_n
>
num
)
?
scores
.
numel
()
:
pre_nms_top_n
;
scores_sort
.
Resize
({
pre_nms_num
,
1
});
index_sort
.
Resize
({
pre_nms_num
,
1
});
// 2. box decode and clipping
Tensor
proposals
;
proposals
.
mutable_data
<
T
>
({
pre_nms_num
,
4
},
ctx
.
GetPlace
());
int
block
=
512
;
auto
stream
=
ctx
.
stream
();
BoxDecodeAndClipKernel
<
T
><<<
DIVUP
(
pre_nms_num
,
block
),
block
,
0
,
stream
>>>
(
anchors
.
data
<
T
>
(),
bbox_deltas
.
data
<
T
>
(),
variances
.
data
<
T
>
(),
index_sort
.
data
<
int
>
(),
im_info
.
data
<
T
>
(),
pre_nms_num
,
proposals
.
data
<
T
>
());
// 3. filter
Tensor
keep_index
,
keep_num_t
;
keep_index
.
mutable_data
<
int
>
({
pre_nms_num
},
ctx
.
GetPlace
());
keep_num_t
.
mutable_data
<
int
>
({
1
},
ctx
.
GetPlace
());
min_size
=
std
::
max
(
min_size
,
1.0
f
);
FilterBBoxes
<
T
,
512
><<<
1
,
512
,
0
,
stream
>>>
(
proposals
.
data
<
T
>
(),
im_info
.
data
<
T
>
(),
min_size
,
pre_nms_num
,
keep_num_t
.
data
<
int
>
(),
keep_index
.
data
<
int
>
());
int
keep_num
;
const
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
memory
::
Copy
(
platform
::
CPUPlace
(),
&
keep_num
,
gpu_place
,
keep_num_t
.
data
<
int
>
(),
sizeof
(
int
),
0
);
keep_index
.
Resize
({
keep_num
});
Tensor
scores_filter
,
proposals_filter
;
proposals_filter
.
mutable_data
<
T
>
({
keep_num
,
4
},
ctx
.
GetPlace
());
scores_filter
.
mutable_data
<
T
>
({
keep_num
,
1
},
ctx
.
GetPlace
());
GPUGather
<
T
>
(
ctx
,
proposals
,
keep_index
,
&
proposals_filter
);
GPUGather
<
T
>
(
ctx
,
scores_sort
,
keep_index
,
&
scores_filter
);
if
(
nms_thresh
<=
0
)
{
return
std
::
make_pair
(
proposals_filter
,
scores_filter
);
}
// 4. nms
Tensor
keep_nms
;
NMS
<
T
>
(
ctx
,
proposals_filter
,
keep_index
,
nms_thresh
,
&
keep_nms
);
if
(
post_nms_top_n
>
0
&&
post_nms_top_n
<
keep_nms
.
numel
())
{
keep_nms
.
Resize
({
post_nms_top_n
});
}
Tensor
scores_nms
,
proposals_nms
;
proposals_nms
.
mutable_data
<
T
>
({
keep_nms
.
numel
(),
4
},
ctx
.
GetPlace
());
scores_nms
.
mutable_data
<
T
>
({
keep_nms
.
numel
(),
1
},
ctx
.
GetPlace
());
GPUGather
<
T
>
(
ctx
,
proposals_filter
,
keep_nms
,
&
proposals_nms
);
GPUGather
<
T
>
(
ctx
,
scores_filter
,
keep_nms
,
&
scores_nms
);
return
std
::
make_pair
(
proposals_nms
,
scores_nms
);
}
}
// namespace
template
<
typename
DeviceContext
,
typename
T
>
class
CUDAGenerateProposalsKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
scores
=
context
.
Input
<
Tensor
>
(
"Scores"
);
auto
*
bbox_deltas
=
context
.
Input
<
Tensor
>
(
"BboxDeltas"
);
auto
*
im_info
=
context
.
Input
<
Tensor
>
(
"ImInfo"
);
auto
*
anchors
=
context
.
Input
<
Tensor
>
(
"Anchors"
);
auto
*
variances
=
context
.
Input
<
Tensor
>
(
"Variances"
);
auto
*
rpn_rois
=
context
.
Output
<
LoDTensor
>
(
"RpnRois"
);
auto
*
rpn_roi_probs
=
context
.
Output
<
LoDTensor
>
(
"RpnRoiProbs"
);
int
pre_nms_top_n
=
context
.
Attr
<
int
>
(
"pre_nms_topN"
);
int
post_nms_top_n
=
context
.
Attr
<
int
>
(
"post_nms_topN"
);
float
nms_thresh
=
context
.
Attr
<
float
>
(
"nms_thresh"
);
float
min_size
=
context
.
Attr
<
float
>
(
"min_size"
);
float
eta
=
context
.
Attr
<
float
>
(
"eta"
);
PADDLE_ENFORCE_GE
(
eta
,
1.
,
"Not support adaptive NMS."
);
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
scores_dim
=
scores
->
dims
();
int64_t
num
=
scores_dim
[
0
];
int64_t
c_score
=
scores_dim
[
1
];
int64_t
h_score
=
scores_dim
[
2
];
int64_t
w_score
=
scores_dim
[
3
];
auto
bbox_dim
=
bbox_deltas
->
dims
();
int64_t
c_bbox
=
bbox_dim
[
1
];
int64_t
h_bbox
=
bbox_dim
[
2
];
int64_t
w_bbox
=
bbox_dim
[
3
];
Tensor
bbox_deltas_swap
,
scores_swap
;
bbox_deltas_swap
.
mutable_data
<
T
>
({
num
,
h_bbox
,
w_bbox
,
c_bbox
},
dev_ctx
.
GetPlace
());
scores_swap
.
mutable_data
<
T
>
({
num
,
h_score
,
w_score
,
c_score
},
dev_ctx
.
GetPlace
());
math
::
Transpose
<
DeviceContext
,
T
,
4
>
trans
;
std
::
vector
<
int
>
axis
=
{
0
,
2
,
3
,
1
};
trans
(
dev_ctx
,
*
bbox_deltas
,
&
bbox_deltas_swap
,
axis
);
trans
(
dev_ctx
,
*
scores
,
&
scores_swap
,
axis
);
Tensor
*
anchor
=
const_cast
<
framework
::
Tensor
*>
(
anchors
);
anchor
->
Resize
({
anchors
->
numel
()
/
4
,
4
});
Tensor
*
var
=
const_cast
<
framework
::
Tensor
*>
(
variances
);
var
->
Resize
({
var
->
numel
()
/
4
,
4
});
rpn_rois
->
mutable_data
<
T
>
({
bbox_deltas
->
numel
()
/
4
,
4
},
context
.
GetPlace
());
rpn_roi_probs
->
mutable_data
<
T
>
({
scores
->
numel
(),
1
},
context
.
GetPlace
());
T
*
rpn_rois_data
=
rpn_rois
->
data
<
T
>
();
T
*
rpn_roi_probs_data
=
rpn_roi_probs
->
data
<
T
>
();
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx
.
GetPlace
());
int64_t
num_proposals
=
0
;
std
::
vector
<
size_t
>
offset
(
1
,
0
);
for
(
int64_t
i
=
0
;
i
<
num
;
++
i
)
{
Tensor
im_info_slice
=
im_info
->
Slice
(
i
,
i
+
1
);
Tensor
bbox_deltas_slice
=
bbox_deltas_swap
.
Slice
(
i
,
i
+
1
);
Tensor
scores_slice
=
scores_swap
.
Slice
(
i
,
i
+
1
);
bbox_deltas_slice
.
Resize
({
h_bbox
*
w_bbox
*
c_bbox
/
4
,
4
});
scores_slice
.
Resize
({
h_score
*
w_score
*
c_score
,
1
});
std
::
pair
<
Tensor
,
Tensor
>
box_score_pair
=
ProposalForOneImage
<
T
>
(
dev_ctx
,
im_info_slice
,
*
anchor
,
*
var
,
bbox_deltas_slice
,
scores_slice
,
pre_nms_top_n
,
post_nms_top_n
,
nms_thresh
,
min_size
,
eta
);
Tensor
proposals
=
box_score_pair
.
first
;
Tensor
scores
=
box_score_pair
.
second
;
memory
::
Copy
(
place
,
rpn_rois_data
+
num_proposals
*
4
,
place
,
proposals
.
data
<
T
>
(),
sizeof
(
T
)
*
proposals
.
numel
(),
0
);
memory
::
Copy
(
place
,
rpn_roi_probs_data
+
num_proposals
,
place
,
scores
.
data
<
T
>
(),
sizeof
(
T
)
*
scores
.
numel
(),
0
);
num_proposals
+=
proposals
.
dims
()[
0
];
offset
.
emplace_back
(
num_proposals
);
}
framework
::
LoD
lod
;
lod
.
emplace_back
(
offset
);
rpn_rois
->
set_lod
(
lod
);
rpn_roi_probs
->
set_lod
(
lod
);
rpn_rois
->
Resize
({
num_proposals
,
4
});
rpn_roi_probs
->
Resize
({
num_proposals
,
1
});
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
generate_proposals
,
ops
::
CUDAGenerateProposalsKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
python/paddle/fluid/contrib/__init__.py
浏览文件 @
479ad4bb
...
...
@@ -18,8 +18,13 @@ from . import decoder
from
.decoder
import
*
from
.
import
memory_usage_calc
from
.memory_usage_calc
import
*
from
.
import
op_frequence
from
.op_frequence
import
*
from
.
import
quantize
from
.quantize
import
*
__all__
=
decoder
.
__all__
+
memory_usage_calc
.
__all__
__all__
=
[]
__all__
+=
decoder
.
__all__
__all__
+=
memory_usage_calc
.
__all__
__all__
+=
op_frequence
.
__all__
__all__
+=
quantize
.
__all__
\ No newline at end of file
python/paddle/fluid/contrib/op_frequence.py
0 → 100644
浏览文件 @
479ad4bb
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
collections
import
OrderedDict
from
..framework
import
Program
__all__
=
[
'op_freq_statistic'
]
def
op_freq_statistic
(
program
):
"""
Statistics of Op frequency.
Args:
program(Program): The current Program.
Returns:
uni_op_freq(dict): the single op frequency.
adj_2_op_freq(dict): the two adjacent ops frequency.
Examples:
>>> import paddle.fluid as fluid
>>> uni_op_freq, adj_2_op_freq = fluid.contrib.op_freq_statistic(
>>> fluid.default_main_program())
>>> for op_type, op_num in uni_op_freq:
>>> print("%s
\t
%d" % (op_type, op_num))
>>> for op_type, op_num in adj_2_op_freq:
>>> print("%s
\t
%d" % (op_type, op_num))
"""
if
not
isinstance
(
program
,
Program
):
raise
TypeError
(
"The input type should be Porgram."
"But you passed in %s"
%
(
type
(
program
)))
uni_op_freq
=
OrderedDict
()
adj_2_op_freq
=
OrderedDict
()
op_in_ops
=
OrderedDict
()
parameters
=
[
p
.
name
for
p
in
program
.
blocks
[
0
].
all_parameters
()]
# get uni_op_freq
for
op
in
program
.
global_block
().
ops
:
had_recorded
=
False
for
var_name
in
op
.
output_arg_names
:
if
var_name
in
parameters
:
continue
if
not
had_recorded
and
uni_op_freq
.
has_key
(
op
.
type
):
uni_op_freq
[
op
.
type
]
+=
1
had_recorded
=
True
elif
not
had_recorded
:
uni_op_freq
[
op
.
type
]
=
1
had_recorded
=
True
# get adj_2_op_freq
var_gen_op
=
{}
for
op
in
program
.
global_block
().
ops
:
for
var_name
in
op
.
input_arg_names
:
if
var_name
in
parameters
:
continue
if
var_gen_op
.
has_key
(
var_name
):
assert
len
(
var_gen_op
[
var_name
])
>
0
if
op_in_ops
.
has_key
(
op
.
type
):
op_in_ops
[
op
.
type
].
append
(
var_gen_op
[
var_name
][
-
1
])
else
:
op_in_ops
[
op
.
type
]
=
[
var_gen_op
[
var_name
][
-
1
]]
else
:
print
(
"Var's generate op is not found,%s, %s"
%
(
var_name
,
op
.
type
))
for
var_name
in
op
.
output_arg_names
:
if
var_gen_op
.
has_key
(
var_name
):
var_gen_op
[
var_name
].
append
(
op
.
type
)
else
:
var_gen_op
[
var_name
]
=
[
op
.
type
]
for
op
,
in_ops
in
op_in_ops
.
iteritems
():
for
in_op
in
in_ops
:
op_op
=
in_op
+
"->"
+
op
if
adj_2_op_freq
.
has_key
(
op_op
):
adj_2_op_freq
[
op_op
]
+=
1
else
:
adj_2_op_freq
[
op_op
]
=
1
uni_op_freq
=
sorted
(
uni_op_freq
.
items
(),
key
=
lambda
item
:
item
[
1
],
reverse
=
True
)
adj_2_op_freq
=
sorted
(
adj_2_op_freq
.
items
(),
key
=
lambda
item
:
item
[
1
],
reverse
=
True
)
return
uni_op_freq
,
adj_2_op_freq
python/paddle/fluid/tests/unittests/test_generate_proposals_op.py
浏览文件 @
479ad4bb
...
...
@@ -277,7 +277,6 @@ class TestGenerateProposalsOp(OpTest):
'eta'
:
self
.
eta
}
print
(
"lod = "
,
self
.
lod
)
self
.
outputs
=
{
'RpnRois'
:
(
self
.
rpn_rois
[
0
],
[
self
.
lod
]),
'RpnRoiProbs'
:
(
self
.
rpn_roi_probs
[
0
],
[
self
.
lod
])
...
...
@@ -295,7 +294,7 @@ class TestGenerateProposalsOp(OpTest):
self
.
post_nms_topN
=
5000
# train 6000, test 1000
self
.
nms_thresh
=
0.7
self
.
min_size
=
3.0
self
.
eta
=
0.8
self
.
eta
=
1.
def
init_test_input
(
self
):
batch_size
=
1
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
479ad4bb
...
...
@@ -470,7 +470,10 @@ class DistributeTranspiler(object):
"""
# remove optimize ops and add a send op to main_program
# FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay?
lr_ops
=
self
.
_get_lr_ops
()
delete_ops
(
self
.
origin_program
.
global_block
(),
self
.
optimize_ops
)
delete_ops
(
self
.
origin_program
.
global_block
(),
lr_ops
)
self
.
origin_program
.
__str__
()
if
wait_port
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录