Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
2402c742
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2402c742
编写于
6月 23, 2020
作者:
W
Wilber
提交者:
GitHub
6月 23, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add topk_pooling kernel. test=develop (#3813)
上级
e43081ff
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
505 addition
and
0 deletion
+505
-0
lite/backends/cuda/cuda_utils.h
lite/backends/cuda/cuda_utils.h
+2
-0
lite/kernels/cuda/CMakeLists.txt
lite/kernels/cuda/CMakeLists.txt
+2
-0
lite/kernels/cuda/topk_pooling_compute.cu
lite/kernels/cuda/topk_pooling_compute.cu
+200
-0
lite/kernels/cuda/topk_pooling_compute.h
lite/kernels/cuda/topk_pooling_compute.h
+45
-0
lite/kernels/cuda/topk_pooling_compute_test.cc
lite/kernels/cuda/topk_pooling_compute_test.cc
+145
-0
lite/operators/CMakeLists.txt
lite/operators/CMakeLists.txt
+1
-0
lite/operators/op_params.h
lite/operators/op_params.h
+9
-0
lite/operators/topk_pooling_op.cc
lite/operators/topk_pooling_op.cc
+55
-0
lite/operators/topk_pooling_op.h
lite/operators/topk_pooling_op.h
+46
-0
未找到文件。
lite/backends/cuda/cuda_utils.h
浏览文件 @
2402c742
...
...
@@ -41,6 +41,8 @@
<< "CUDA: " << cudaGetErrorString(e); \
}
#define CUDA_POST_KERNEL_CHECK CUDA_CALL(cudaPeekAtLastError())
#define CUBLAS_CALL(func) \
{ \
auto e = (func); \
...
...
lite/kernels/cuda/CMakeLists.txt
浏览文件 @
2402c742
...
...
@@ -44,6 +44,7 @@ add_kernel(match_matrix_tensor_compute_cuda CUDA extra SRCS match_matrix_tensor_
add_kernel
(
search_aligned_mat_mul_compute_cuda CUDA extra SRCS search_aligned_mat_mul_compute.cc DEPS
${
lite_kernel_deps
}
cuda_batched_gemm
)
add_kernel
(
search_seq_fc_compute_cuda CUDA extra SRCS search_seq_fc_compute.cu DEPS
${
lite_kernel_deps
}
cuda_gemm
)
add_kernel
(
var_conv_2d_compute_cuda CUDA extra SRCS var_conv_2d_compute.cu DEPS
${
lite_kernel_deps
}
${
math_cuda
}
)
add_kernel
(
topk_pooling_compute_cuda CUDA extra SRCS topk_pooling_compute.cu DEPS
${
lite_kernel_deps
}
)
# unit test
lite_cc_test
(
calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda
)
...
...
@@ -79,4 +80,5 @@ if(LITE_BUILD_EXTRA)
#nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
nv_test
(
sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda
)
#nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda)
nv_test
(
topk_pooling_compute_cuda_test SRCS topk_pooling_compute_test.cc DEPS topk_pooling_compute_cuda
)
endif
()
lite/kernels/cuda/topk_pooling_compute.cu
0 → 100644
浏览文件 @
2402c742
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/topk_pooling_compute.h"
#include <limits>
#include <vector>
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
template
<
typename
Dtype
>
__global__
void
top_k_pooling_batch_kernel_reduction
(
Dtype
*
output_data
,
const
Dtype
*
input
,
const
int
*
height_offset
,
const
int
*
width_offset
,
const
int
batch_size
,
const
int
channel_num
,
const
int
height_stride
,
const
int
width_stride
,
const
int
k
)
{
const
Dtype
*
input_start
=
input
+
(
blockIdx
.
x
*
channel_num
+
blockIdx
.
y
)
*
height_stride
*
width_stride
;
Dtype
*
output_start
=
output_data
+
(
blockIdx
.
x
*
channel_num
+
blockIdx
.
y
)
*
k
;
int
width
=
width_offset
[
blockIdx
.
x
+
1
]
-
width_offset
[
blockIdx
.
x
];
int
height
=
height_offset
[
blockIdx
.
x
+
1
]
-
height_offset
[
blockIdx
.
x
];
int
real_k
=
k
<
height
*
width
?
k
:
height
*
width
;
extern
__shared__
Dtype
smem
[];
Dtype
min_val
=
-
100000.0
f
;
for
(
int
j
=
threadIdx
.
x
;
j
<
height
*
width
;
j
+=
blockDim
.
x
)
{
int
index_tmp
=
(
j
/
width
)
*
width_stride
+
j
%
width
;
smem
[
j
]
=
input_start
[
index_tmp
];
}
__syncthreads
();
// get max val
int
t
=
0
;
for
(;
t
<
real_k
;
++
t
)
{
// reduction
for
(
int
gap
=
height
*
width
;
gap
>
1
;)
{
if
(
threadIdx
.
x
==
0
)
{
// edge cond
if
(
gap
%
2
!=
0
)
{
Dtype
value_first
=
smem
[
0
];
Dtype
value_gap
=
smem
[
gap
-
1
];
if
(
value_first
<
value_gap
)
{
smem
[
0
]
=
value_gap
;
smem
[
gap
-
1
]
=
value_first
;
}
}
}
gap
>>=
1
;
for
(
int
j
=
threadIdx
.
x
;
j
<
gap
;
j
+=
blockDim
.
x
)
{
Dtype
value_first
=
smem
[
j
];
Dtype
value_gap
=
smem
[
j
+
gap
];
if
(
value_first
<
value_gap
)
{
smem
[
j
]
=
value_gap
;
smem
[
j
+
gap
]
=
value_first
;
}
}
__syncthreads
();
}
if
(
threadIdx
.
x
==
0
)
{
output_start
[
t
]
=
smem
[
0
];
smem
[
0
]
=
min_val
;
}
__syncthreads
();
}
for
(
int
i
=
threadIdx
.
x
;
i
<
(
k
-
t
);
i
+=
blockDim
.
x
)
{
// output_start[t + i] = 0.0f;
}
}
template
<
typename
T
>
void
TopkPoolingCompute
<
T
>::
PrepareForRun
()
{
int
device_id
=
lite
::
TargetWrapperCuda
::
GetCurDevice
();
cudaDeviceProp
deviceProp
;
CUDA_CALL
(
cudaGetDeviceProperties
(
&
deviceProp
,
device_id
));
_shared_mem_size
=
deviceProp
.
sharedMemPerBlock
;
}
template
<
typename
T
>
void
TopkPoolingCompute
<
T
>::
Run
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
cuda_stream
=
ctx
.
exec_stream
();
CHECK
(
param
.
X
->
lod
().
size
()
>
0
&&
param
.
X
->
lod
()[
0
].
size
()
>
0
)
<<
"X sequence offset is not valid"
;
CHECK
(
param
.
Y
->
lod
().
size
()
>
0
&&
param
.
Y
->
lod
()[
0
].
size
()
>
0
)
<<
"Y sequence offset is not valid"
;
int
width_offset_len
=
param
.
X
->
lod
()[
0
].
size
();
lite
::
DDim
width_offset_shape
(
std
::
vector
<
int64_t
>
{
width_offset_len
});
_width_offset
.
Resize
(
width_offset_shape
);
std
::
vector
<
int
>
width_lod_0
(
width_offset_len
,
0
);
for
(
size_t
i
=
0
;
i
<
param
.
X
->
lod
()[
0
].
size
();
++
i
)
{
width_lod_0
[
i
]
=
static_cast
<
int
>
(
param
.
X
->
lod
()[
0
][
i
]);
}
lite
::
TargetWrapperCuda
::
MemcpyAsync
(
_width_offset
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
)),
width_lod_0
.
data
(),
sizeof
(
int
)
*
width_offset_len
,
lite
::
IoDirection
::
HtoD
,
cuda_stream
);
int
height_offset_len
=
param
.
Y
->
lod
()[
0
].
size
();
lite
::
DDim
height_offset_shape
(
std
::
vector
<
int64_t
>
{
height_offset_len
});
_height_offset
.
Resize
(
height_offset_shape
);
std
::
vector
<
int
>
height_lod_0
(
height_offset_len
,
0
);
for
(
size_t
i
=
0
;
i
<
param
.
Y
->
lod
()[
0
].
size
();
++
i
)
{
height_lod_0
[
i
]
=
static_cast
<
int
>
(
param
.
Y
->
lod
()[
0
][
i
]);
}
lite
::
TargetWrapperCuda
::
MemcpyAsync
(
_height_offset
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
)),
height_lod_0
.
data
(),
sizeof
(
int
)
*
height_offset_len
,
lite
::
IoDirection
::
HtoD
,
cuda_stream
);
const
Tensor
*
x_tensor
=
param
.
X
;
Tensor
*
out_tensor
=
param
.
Out
;
const
T
*
in_data
=
x_tensor
->
data
<
T
>
();
T
*
out_data
=
out_tensor
->
mutable_data
<
T
>
(
TARGET
(
kCUDA
));
int
num
=
x_tensor
->
dims
()[
0
];
int
channel
=
x_tensor
->
dims
()[
1
];
int
height
=
x_tensor
->
dims
()[
2
];
int
width
=
x_tensor
->
dims
()[
3
];
const
int
*
height_offset
=
_height_offset
.
data
<
int
>
();
const
int
*
width_offset
=
_width_offset
.
data
<
int
>
();
int
feat_map_size
=
height
*
width
;
if
(
feat_map_size
*
sizeof
(
T
)
<=
_shared_mem_size
)
{
dim3
blocks
(
num
,
channel
);
dim3
threads
(
32
,
1
);
top_k_pooling_batch_kernel_reduction
<
T
><<<
blocks
,
threads
,
feat_map_size
*
sizeof
(
T
),
cuda_stream
>>>
(
out_data
,
in_data
,
height_offset
,
width_offset
,
num
,
channel
,
height
,
width
,
param
.
top_k
);
}
else
{
LOG
(
FATAL
)
<<
"Not implemented. Exceeded the shared memory limit."
;
}
CUDA_POST_KERNEL_CHECK
;
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
topk_pooling
,
kCUDA
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
TopkPoolingCompute
<
float
>
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
Finalize
();
lite/kernels/cuda/topk_pooling_compute.h
0 → 100644
浏览文件 @
2402c742
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cudnn.h>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/kernel.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
template
<
typename
T
>
class
TopkPoolingCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)
>
{
public:
using
param_t
=
operators
::
TopkPoolingParam
;
void
Run
()
override
;
void
PrepareForRun
()
override
;
virtual
~
TopkPoolingCompute
()
=
default
;
protected:
lite
::
Tensor
_height_offset
;
lite
::
Tensor
_width_offset
;
int
_shared_mem_size
;
};
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/kernels/cuda/topk_pooling_compute_test.cc
0 → 100644
浏览文件 @
2402c742
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/topk_pooling_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/utils/float16.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
class
TopkPooingTest
:
public
::
testing
::
Test
{
protected:
TopkPooingTest
()
:
num
(
2
),
channels
(
4
),
height
(
4
),
width
(
4
),
top_k
(
2
),
feat_map_num
(
height
*
width
),
x_lod
({{
0
,
4
,
7
}}),
y_lod
({{
0
,
4
,
7
}}),
x_shape
({
num
,
channels
,
height
,
width
}),
out_shape
({
num
,
channels
*
top_k
})
{
CHECK_EQ
(
x_lod
[
0
].
size
(),
num
+
1
)
<<
"invalid input."
;
for
(
size_t
i
=
1
;
i
<
x_lod
[
0
].
size
();
++
i
)
{
CHECK_LE
(
x_lod
[
0
][
i
]
-
x_lod
[
0
][
i
-
1
],
height
)
<<
"invalid input."
;
}
X_gpu
.
Resize
(
lite
::
DDim
(
x_shape
));
X_ref
.
Resize
(
lite
::
DDim
(
x_shape
));
X_ref
.
set_lod
(
x_lod
);
Y_gpu
.
Resize
(
lite
::
DDim
(
x_shape
));
Y_ref
.
Resize
(
lite
::
DDim
(
x_shape
));
Y_ref
.
set_lod
(
y_lod
);
auto
x_ref_data
=
X_ref
.
mutable_data
<
float
>
();
auto
y_ref_data
=
Y_ref
.
mutable_data
<
float
>
();
// prepare input
for
(
int64_t
i
=
0
;
i
<
X_ref
.
numel
();
i
++
)
{
x_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
16
);
}
for
(
int64_t
i
=
0
;
i
<
Y_ref
.
numel
();
i
++
)
{
y_ref_data
[
i
]
=
static_cast
<
float
>
(
i
%
16
);
}
Out_ref
.
Resize
(
lite
::
DDim
(
out_shape
));
Out_gpu
.
Resize
(
lite
::
DDim
(
out_shape
));
Out_cpu
.
Resize
(
lite
::
DDim
(
out_shape
));
device_init
();
}
void
device_init
()
{
ctx
.
reset
(
new
KernelContext
);
cudaStreamCreate
(
&
stream
);
param
.
X
=
&
X_gpu
;
param
.
Y
=
&
Y_gpu
;
param
.
Out
=
&
Out_gpu
;
param
.
top_k
=
top_k
;
param
.
feat_map_num
=
feat_map_num
;
}
void
float_data_init
()
{
X_gpu
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
X_ref
.
data
<
float
>
(),
X_gpu
.
dims
());
X_gpu
.
set_lod
(
X_ref
.
lod
());
Y_gpu
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
Y_ref
.
data
<
float
>
(),
Y_gpu
.
dims
());
Y_gpu
.
set_lod
(
Y_ref
.
lod
());
}
void
half_data_init
()
{}
void
cpu_base
(
const
lite
::
Tensor
*
X
,
const
lite
::
Tensor
*
Y
,
lite
::
Tensor
*
Out
)
{}
int
num
,
channels
,
height
,
width
;
int
top_k
,
feat_map_num
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
x_lod
,
y_lod
;
std
::
vector
<
int64_t
>
x_shape
,
out_shape
;
lite
::
Tensor
X_ref
,
Y_ref
,
Out_ref
;
lite
::
Tensor
X_gpu
,
Y_gpu
;
lite
::
Tensor
Out_cpu
,
Out_gpu
;
operators
::
TopkPoolingParam
param
;
std
::
unique_ptr
<
KernelContext
>
ctx
;
cudaStream_t
stream
;
};
TEST_F
(
TopkPooingTest
,
fp32
)
{
float_data_init
();
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
TopkPoolingCompute
<
float
>
kernel
;
kernel
.
SetParam
(
param
);
kernel
.
SetContext
(
std
::
move
(
ctx
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
kernel
.
Launch
();
cudaDeviceSynchronize
();
}
auto
start
=
GetCurrentUS
();
kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp32, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
CopySync
<
TARGET
(
kCUDA
)
>
(
Out_cpu
.
mutable_data
<
float
>
(),
Out_gpu
.
data
<
float
>
(),
sizeof
(
float
)
*
Out_gpu
.
numel
(),
IoDirection
::
DtoH
);
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/operators/CMakeLists.txt
浏览文件 @
2402c742
...
...
@@ -147,6 +147,7 @@ add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
add_operator
(
sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
search_fc_op basic SRCS search_fc_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
lstm_op extra SRCS lstm_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
topk_pooling_op extra SRCS topk_pooling_op.cc DEPS
${
op_DEPS
}
)
# for deformable-convNet
add_operator
(
deformable_conv_op extra SRCS deformable_conv_op.cc DEPS
${
op_DEPS
}
)
...
...
lite/operators/op_params.h
浏览文件 @
2402c742
...
...
@@ -1344,6 +1344,15 @@ struct SequenceTopkAvgPoolingParam : ParamBase {
std
::
vector
<
int
>
topks
{};
};
/// --------------- topk_pooling operators ------------------
struct
TopkPoolingParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
int
top_k
{
1
};
int
feat_map_num
{
1
};
};
/// --------------- search_fc operators ------------------
struct
SearchFcParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
...
...
lite/operators/topk_pooling_op.cc
0 → 100644
浏览文件 @
2402c742
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/topk_pooling_op.h"
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
bool
TopkPoolingOp
::
CheckShape
()
const
{
CHECK_OR_FALSE
(
param_
.
X
);
CHECK_OR_FALSE
(
param_
.
Y
);
CHECK_OR_FALSE
(
param_
.
Out
);
return
true
;
}
bool
TopkPoolingOp
::
InferShapeImpl
()
const
{
auto
out_dims
=
param_
.
X
->
dims
();
out_dims
[
1
]
*=
param_
.
top_k
;
auto
out
=
param_
.
Out
;
out
->
Resize
(
out_dims
);
out
->
set_lod
(
param_
.
X
->
lod
());
return
true
;
}
bool
TopkPoolingOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
y
=
op_desc
.
Input
(
"Y"
).
front
();
param_
.
X
=
scope
->
FindTensor
(
x
);
param_
.
Y
=
scope
->
FindTensor
(
y
);
auto
output
=
op_desc
.
Output
(
"Out"
).
front
();
param_
.
Out
=
scope
->
FindMutableTensor
(
output
);
param_
.
top_k
=
op_desc
.
GetAttr
<
int
>
(
"top_k"
);
param_
.
feat_map_num
=
op_desc
.
GetAttr
<
int
>
(
"feat_map_num"
);
return
true
;
}
}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
topk_pooling
,
paddle
::
lite
::
operators
::
TopkPoolingOp
);
lite/operators/topk_pooling_op.h
0 → 100644
浏览文件 @
2402c742
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
TopkPoolingOp
:
public
OpLite
{
public:
TopkPoolingOp
()
{}
explicit
TopkPoolingOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShapeImpl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"topk_pooling"
;
}
private:
mutable
TopkPoolingParam
param_
;
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录