Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
9b84dc91
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9b84dc91
编写于
1月 15, 2020
作者:
W
Wilber
提交者:
GitHub
1月 15, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix var_conv_2d to support cascading use. test=develop (#2766)
- 修复var_conv_2d级联使用中计算错误的bug - x86的var_conv_2d中显示指定lod level为3
上级
974c50db
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
69 addition
and
0 deletion
+69
-0
lite/kernels/cuda/var_conv_2d_compute.cu
lite/kernels/cuda/var_conv_2d_compute.cu
+67
-0
lite/kernels/cuda/var_conv_2d_compute.h
lite/kernels/cuda/var_conv_2d_compute.h
+1
-0
lite/kernels/x86/var_conv_2d_compute.h
lite/kernels/x86/var_conv_2d_compute.h
+1
-0
未找到文件。
lite/kernels/cuda/var_conv_2d_compute.cu
浏览文件 @
9b84dc91
...
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <functional>
#include <memory>
#include <vector>
#include "lite/backends/cuda/math/gemm.h"
...
...
@@ -38,6 +39,32 @@ inline int ConvOutputSize(int input_size,
return
output_size
;
}
// Eliminate the effects of pad, support batch > 1.
template
<
typename
dtype
>
__global__
void
eliminate_pad_effect
(
dtype
*
src
,
const
int64_t
*
offset
,
const
int
num_batch
,
const
int
batch_stride
,
const
int
num_channel
,
const
int
channel_stride
,
const
int
num_height
,
const
int
height_stride
,
const
int
num_width
,
const
int
width_stride
,
const
int
count
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
thread_num
=
blockDim
.
x
*
gridDim
.
x
;
for
(
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
tid
<
count
;
tid
+=
thread_num
)
{
int
batch_id
=
tid
/
batch_stride
;
int
width_id
=
tid
%
num_width
;
int
cur_len
=
offset
[
batch_id
+
1
]
-
offset
[
batch_id
];
if
(
width_id
>=
cur_len
)
{
src
[
tid
]
=
0.
;
}
}
}
void
VarConv2DCompute
::
PrepareForRun
()
{
auto
&
context
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
context
.
exec_stream
();
...
...
@@ -102,6 +129,46 @@ void VarConv2DCompute::Run() {
conv_param_
.
output
->
Resize
({
output_shape
});
conv_impl_
->
create
(
conv_param_
,
&
context
);
conv_impl_
->
run
(
conv_param_
);
// Avoid situations where cascading conv does not support multiple batch
// calculations
float
*
out_data
=
param
.
Out
->
mutable_data
<
float
>
();
const
int
batch_num
=
output_shape
[
1
]
*
output_shape
[
2
]
*
output_shape
[
3
];
std
::
vector
<
int64_t
>
lod
(
param
.
X
->
lod
()[
0
].
size
(),
0
);
for
(
size_t
i
=
0
;
i
<
param
.
X
->
lod
()[
0
].
size
();
++
i
)
{
lod
[
i
]
=
param
.
X
->
lod
()[
0
][
i
];
}
int
count
=
std
::
accumulate
(
output_shape
.
begin
(),
output_shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
int
width_stride
=
1
;
int
height_stride
=
output_shape
[
3
];
int
channel_stride
=
output_shape
[
2
]
*
output_shape
[
3
];
int
batch_stride
=
output_shape
[
1
]
*
output_shape
[
2
]
*
output_shape
[
3
];
int
threads
=
512
;
int
blocks
=
(
count
+
threads
-
1
)
/
threads
;
offset_
.
Resize
({
static_cast
<
int64_t
>
(
lod
.
size
())});
int64_t
*
d_offset
=
offset_
.
mutable_data
<
int64_t
>
(
TARGET
(
kCUDA
));
TargetWrapperCuda
::
MemcpyAsync
(
d_offset
,
lod
.
data
(),
sizeof
(
int64_t
)
*
lod
.
size
(),
IoDirection
::
HtoD
,
stream
);
eliminate_pad_effect
<
float
><<<
blocks
,
threads
,
0
,
stream
>>>
(
out_data
,
d_offset
,
output_shape
[
0
],
batch_stride
,
output_shape
[
1
],
channel_stride
,
output_shape
[
2
],
height_stride
,
output_shape
[
3
],
width_stride
,
count
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
ERROR
)
<<
cudaGetErrorString
(
error
);
}
}
// namespace cuda
...
...
lite/kernels/cuda/var_conv_2d_compute.h
浏览文件 @
9b84dc91
...
...
@@ -33,6 +33,7 @@ class VarConv2DCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
private:
mutable
operators
::
ConvParam
conv_param_
;
std
::
unique_ptr
<
lite
::
cuda
::
math
::
CudnnConv2D
<
PRECISION
(
kFloat
)
>>
conv_impl_
;
lite
::
Tensor
offset_
;
};
}
// namespace cuda
...
...
lite/kernels/x86/var_conv_2d_compute.h
浏览文件 @
9b84dc91
...
...
@@ -44,6 +44,7 @@ class VarConv2DCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
// 2-D lod info.
// const auto& offset_x = in_col->lod()[0];
// const auto& offset_y = in_row->lod()[0];
CHECK_EQ
(
param
.
X
->
lod
().
size
(),
3
)
<<
"input lod size should be 3!"
;
const
auto
&
offset_y
=
param
.
X
->
lod
()[
1
];
const
auto
&
offset_x
=
param
.
X
->
lod
()[
2
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录