Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
1f075a8b
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
1f075a8b
编写于
11月 14, 2019
作者:
L
liu zhengxi
提交者:
GitHub
11月 14, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the compile error for cuda and fix unit test (#2424)
* fix the compile error for cuda and fix unit test
上级
bfd2a950
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
51 addition
and
44 deletion
+51
-44
lite/kernels/cuda/bilinear_interp_compute.cu
lite/kernels/cuda/bilinear_interp_compute.cu
+2
-5
lite/kernels/cuda/bilinear_interp_compute_test.cc
lite/kernels/cuda/bilinear_interp_compute_test.cc
+24
-18
lite/kernels/cuda/nearest_interp_compute.cu
lite/kernels/cuda/nearest_interp_compute.cu
+2
-2
lite/kernels/cuda/nearest_interp_compute_test.cc
lite/kernels/cuda/nearest_interp_compute_test.cc
+22
-18
lite/operators/op_params.h
lite/operators/op_params.h
+1
-1
未找到文件。
lite/kernels/cuda/bilinear_interp_compute.cu
浏览文件 @
1f075a8b
...
...
@@ -29,7 +29,7 @@ inline std::vector<int> get_new_shape(
auto
tensor
=
list_new_shape_tensor
[
i
];
lite
::
Tensor
temp
;
auto
temp_data
=
temp
.
mutable_data
<
int32_t
>
();
auto
tensor_data
=
tensor
->
data
<
int32_t
>
(
TARGET
(
kCUDA
)
);
auto
tensor_data
=
tensor
->
data
<
int32_t
>
();
cudaMemcpy
(
temp_data
,
tensor_data
,
tensor
->
dims
().
production
()
*
sizeof
(
float
),
...
...
@@ -44,7 +44,7 @@ inline std::vector<int> get_new_shape(
template
<
typename
T
>
inline
std
::
vector
<
T
>
get_new_data_from_tensor
(
const
Tensor
*
new_data_tensor
)
{
std
::
vector
<
T
>
vec_new_data
;
auto
*
new_data
=
new_data_tensor
->
data
<
T
>
(
kCUDA
);
auto
*
new_data
=
new_data_tensor
->
data
<
T
>
();
lite
::
Tensor
cpu_starts_tensor
;
auto
cpu_starts_tensor_data
=
cpu_starts_tensor
.
mutable_data
<
T
>
();
cudaMemcpy
(
cpu_starts_tensor_data
,
...
...
@@ -141,7 +141,6 @@ void BilinearInterpCompute::Run() {
int
out_w
=
param
.
out_w
;
float
scale
=
param
.
scale
;
bool
align_corners
=
param
.
align_corners
;
auto
align_mode
=
param
.
align_mode
;
auto
list_new_shape_tensor
=
param
.
SizeTensor
;
if
(
list_new_shape_tensor
.
size
()
>
0
)
{
...
...
@@ -159,7 +158,6 @@ void BilinearInterpCompute::Run() {
out_h
=
static_cast
<
int
>
(
in_h
*
scale
);
out_w
=
static_cast
<
int
>
(
in_w
*
scale
);
}
if
(
out_size
!=
nullptr
)
{
lite
::
Tensor
sizes
;
float
*
size_data
=
sizes
.
mutable_data
<
float
>
();
...
...
@@ -172,7 +170,6 @@ void BilinearInterpCompute::Run() {
}
auto
output_data
=
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
if
(
in_h
==
out_h
&&
in_w
==
out_w
)
{
cudaMemcpy
(
output_data
,
input_data
,
...
...
lite/kernels/cuda/bilinear_interp_compute_test.cc
浏览文件 @
1f075a8b
...
...
@@ -106,10 +106,11 @@ TEST(bilinear_interp, update) {
operators
::
InterpolateParam
param
;
std
::
vector
<
Tensor
*>
size_tensor
(
2
),
size_tensor_cpu
(
2
),
size_tensor_ref
(
2
);
std
::
vector
<
Tensor
>
size_tensor
(
2
);
std
::
vector
<
Tensor
>
size_tensor_cpu
(
2
),
size_tensor_ref
(
2
);
Tensor
x
,
input_scale
,
osz
,
out
;
Tensor
x_cpu
,
input_scale_cpu
,
osz_cpu
,
out_cpu
;
Tensor
x_ref
,
size_tensor_ref
,
input_scale_ref
,
osz_ref
,
out_ref
;
Tensor
x_ref
,
input_scale_ref
,
osz_ref
,
out_ref
;
int
n
=
1
,
c
=
1
,
in_h
=
3
,
in_w
=
3
;
int
out_h
=
6
,
out_w
=
6
;
...
...
@@ -122,22 +123,22 @@ TEST(bilinear_interp, update) {
param
.
align_mode
=
0
;
x
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor
[
0
]
->
Resize
({
1
});
size_tensor
[
1
]
->
Resize
({
1
});
size_tensor
[
0
]
.
Resize
({
1
});
size_tensor
[
1
]
.
Resize
({
1
});
input_scale
.
Resize
({
1
});
osz
.
Resize
({
2
});
out
.
Resize
({
n
,
c
,
out_h
,
out_w
});
x_cpu
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor_cpu
[
0
]
->
Resize
({
1
});
size_tensor_cpu
[
1
]
->
Resize
({
1
});
size_tensor_cpu
[
0
]
.
Resize
({
1
});
size_tensor_cpu
[
1
]
.
Resize
({
1
});
input_scale_cpu
.
Resize
({
1
});
osz_cpu
.
Resize
({
2
});
out_cpu
.
Resize
({
n
,
c
,
out_h
,
out_w
});
x_ref
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor_ref
[
0
]
->
Resize
({
1
});
size_tensor_ref
[
1
]
->
Resize
({
1
});
size_tensor_ref
[
0
]
.
Resize
({
1
});
size_tensor_ref
[
1
]
.
Resize
({
1
});
input_scale_ref
.
Resize
({
1
});
osz_ref
.
Resize
({
2
});
out_ref
.
Resize
({
n
,
c
,
out_h
,
out_w
});
...
...
@@ -145,15 +146,15 @@ TEST(bilinear_interp, update) {
auto
*
out_data
=
out
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
float
*
x_cpu_data
=
x_cpu
.
mutable_data
<
float
>
();
float
*
size_tensor0_cpu_data
=
size_tensor_cpu
[
0
]
->
mutable_data
<
float
>
();
float
*
size_tensor1_cpu_data
=
size_tensor_cpu
[
1
]
->
mutable_data
<
float
>
();
float
*
size_tensor0_cpu_data
=
size_tensor_cpu
[
0
]
.
mutable_data
<
float
>
();
float
*
size_tensor1_cpu_data
=
size_tensor_cpu
[
1
]
.
mutable_data
<
float
>
();
float
*
input_scale_cpu_data
=
input_scale_cpu
.
mutable_data
<
float
>
();
float
*
osz_cpu_data
=
osz_cpu
.
mutable_data
<
float
>
();
float
*
out_cpu_data
=
out_cpu
.
mutable_data
<
float
>
();
float
*
x_ref_data
=
x_ref
.
mutable_data
<
float
>
();
float
*
size_tensor0_ref_data
=
size_tensor_ref
[
0
]
->
mutable_data
<
float
>
();
float
*
size_tensor1_ref_data
=
size_tensor_ref
[
1
]
->
mutable_data
<
float
>
();
float
*
size_tensor0_ref_data
=
size_tensor_ref
[
0
]
.
mutable_data
<
float
>
();
float
*
size_tensor1_ref_data
=
size_tensor_ref
[
1
]
.
mutable_data
<
float
>
();
float
*
input_scale_ref_data
=
input_scale_ref
.
mutable_data
<
float
>
();
float
*
osz_ref_data
=
osz_ref
.
mutable_data
<
float
>
();
...
...
@@ -161,6 +162,7 @@ TEST(bilinear_interp, update) {
x_cpu_data
[
i
]
=
i
+
5.0
;
x_ref_data
[
i
]
=
i
+
5.0
;
}
osz_cpu_data
[
0
]
=
out_h
;
osz_cpu_data
[
1
]
=
out_w
;
size_tensor0_cpu_data
[
0
]
=
out_h
;
...
...
@@ -173,19 +175,23 @@ TEST(bilinear_interp, update) {
input_scale_ref_data
[
0
]
=
scale
;
x
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_cpu_data
,
x_cpu
.
dims
());
size_tensor
[
0
]
->
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor0_cpu_data
,
{
1
}
);
size_tensor
[
1
]
->
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor1_cpu_data
,
{
1
}
);
size_tensor
[
0
]
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor0_cpu_data
,
size_tensor
[
0
].
dims
()
);
size_tensor
[
1
]
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor1_cpu_data
,
size_tensor
[
1
].
dims
()
);
input_scale
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
input_scale_cpu_data
,
{
1
}
);
input_scale
.
dims
()
);
osz
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
osz_cpu_data
,
osz_cpu
.
dims
());
param
.
X
=
&
x
;
param
.
SizeTensor
=
size_tensor
;
param
.
SizeTensor
.
emplace_back
(
reinterpret_cast
<
const
Tensor
*>
(
&
size_tensor
[
0
]));
param
.
SizeTensor
.
emplace_back
(
reinterpret_cast
<
const
Tensor
*>
(
&
size_tensor
[
1
]));
param
.
Scale
=
&
input_scale
;
param
.
OutSize
=
&
osz
;
param
.
Out
=
&
out
;
bilinear_interp_kernel
.
SetParam
(
param
);
cudaStream_t
stream
;
...
...
lite/kernels/cuda/nearest_interp_compute.cu
浏览文件 @
1f075a8b
...
...
@@ -29,7 +29,7 @@ inline std::vector<int> get_new_shape(
auto
tensor
=
list_new_shape_tensor
[
i
];
lite
::
Tensor
temp
;
auto
temp_data
=
temp
.
mutable_data
<
int32_t
>
();
auto
tensor_data
=
tensor
->
data
<
int32_t
>
(
TARGET
(
kCUDA
)
);
auto
tensor_data
=
tensor
->
data
<
int32_t
>
();
cudaMemcpy
(
temp_data
,
tensor_data
,
tensor
->
dims
().
production
()
*
sizeof
(
float
),
...
...
@@ -44,7 +44,7 @@ inline std::vector<int> get_new_shape(
template
<
typename
T
>
inline
std
::
vector
<
T
>
get_new_data_from_tensor
(
const
Tensor
*
new_data_tensor
)
{
std
::
vector
<
T
>
vec_new_data
;
auto
*
new_data
=
new_data_tensor
->
data
<
T
>
(
kCUDA
);
auto
*
new_data
=
new_data_tensor
->
data
<
T
>
();
lite
::
Tensor
cpu_starts_tensor
;
auto
cpu_starts_tensor_data
=
cpu_starts_tensor
.
mutable_data
<
T
>
();
cudaMemcpy
(
cpu_starts_tensor_data
,
...
...
lite/kernels/cuda/nearest_interp_compute_test.cc
浏览文件 @
1f075a8b
...
...
@@ -151,10 +151,11 @@ TEST(nearest_interp, update) {
operators
::
InterpolateParam
param
;
std
::
vector
<
Tensor
*>
size_tensor
(
2
),
size_tensor_cpu
(
2
),
size_tensor_ref
(
2
);
std
::
vector
<
Tensor
>
size_tensor
(
2
);
std
::
vector
<
Tensor
>
size_tensor_cpu
(
2
),
size_tensor_ref
(
2
);
Tensor
x
,
input_scale
,
osz
,
out
;
Tensor
x_cpu
,
input_scale_cpu
,
osz_cpu
,
out_cpu
;
Tensor
x_ref
,
size_tensor_ref
,
input_scale_ref
,
osz_ref
,
out_ref
;
Tensor
x_ref
,
input_scale_ref
,
osz_ref
,
out_ref
;
int
n
=
1
,
c
=
3
,
in_h
=
40
,
in_w
=
40
;
int
out_h
=
80
,
out_w
=
80
;
...
...
@@ -167,22 +168,22 @@ TEST(nearest_interp, update) {
param
.
align_mode
=
0
;
x
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor
[
0
]
->
Resize
({
1
});
size_tensor
[
1
]
->
Resize
({
1
});
size_tensor
[
0
]
.
Resize
({
1
});
size_tensor
[
1
]
.
Resize
({
1
});
input_scale
.
Resize
({
1
});
osz
.
Resize
({
2
});
out
.
Resize
({
n
,
c
,
out_h
,
out_w
});
x_cpu
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor_cpu
[
0
]
->
Resize
({
1
});
size_tensor_cpu
[
1
]
->
Resize
({
1
});
size_tensor_cpu
[
0
]
.
Resize
({
1
});
size_tensor_cpu
[
1
]
.
Resize
({
1
});
input_scale_cpu
.
Resize
({
1
});
osz_cpu
.
Resize
({
2
});
out_cpu
.
Resize
({
n
,
c
,
out_h
,
out_w
});
x_ref
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor_ref
[
0
]
->
Resize
({
1
});
size_tensor_ref
[
1
]
->
Resize
({
1
});
size_tensor_ref
[
0
]
.
Resize
({
1
});
size_tensor_ref
[
1
]
.
Resize
({
1
});
input_scale_ref
.
Resize
({
1
});
osz_ref
.
Resize
({
2
});
out_ref
.
Resize
({
n
,
c
,
out_h
,
out_w
});
...
...
@@ -190,15 +191,15 @@ TEST(nearest_interp, update) {
auto
*
out_data
=
out
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
float
*
x_cpu_data
=
x_cpu
.
mutable_data
<
float
>
();
float
*
size_tensor0_cpu_data
=
size_tensor_cpu
[
0
]
->
mutable_data
<
float
>
();
float
*
size_tensor1_cpu_data
=
size_tensor_cpu
[
1
]
->
mutable_data
<
float
>
();
float
*
size_tensor0_cpu_data
=
size_tensor_cpu
[
0
]
.
mutable_data
<
float
>
();
float
*
size_tensor1_cpu_data
=
size_tensor_cpu
[
1
]
.
mutable_data
<
float
>
();
float
*
input_scale_cpu_data
=
input_scale_cpu
.
mutable_data
<
float
>
();
float
*
osz_cpu_data
=
osz_cpu
.
mutable_data
<
float
>
();
float
*
out_cpu_data
=
out_cpu
.
mutable_data
<
float
>
();
float
*
x_ref_data
=
x_ref
.
mutable_data
<
float
>
();
float
*
size_tensor0_ref_data
=
size_tensor_ref
[
0
]
->
mutable_data
<
float
>
();
float
*
size_tensor1_ref_data
=
size_tensor_ref
[
1
]
->
mutable_data
<
float
>
();
float
*
size_tensor0_ref_data
=
size_tensor_ref
[
0
]
.
mutable_data
<
float
>
();
float
*
size_tensor1_ref_data
=
size_tensor_ref
[
1
]
.
mutable_data
<
float
>
();
float
*
input_scale_ref_data
=
input_scale_ref
.
mutable_data
<
float
>
();
float
*
osz_ref_data
=
osz_ref
.
mutable_data
<
float
>
();
...
...
@@ -218,16 +219,19 @@ TEST(nearest_interp, update) {
input_scale_ref_data
[
0
]
=
scale
;
x
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_cpu_data
,
x_cpu
.
dims
());
size_tensor
[
0
]
->
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor0_cpu_data
,
{
1
}
);
size_tensor
[
1
]
->
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor1_cpu_data
,
{
1
}
);
size_tensor
[
0
]
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor0_cpu_data
,
size_tensor
[
0
].
dims
()
);
size_tensor
[
1
]
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor1_cpu_data
,
size_tensor
[
1
].
dims
()
);
input_scale
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
input_scale_cpu_data
,
{
1
}
);
input_scale
.
dims
()
);
osz
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
osz_cpu_data
,
osz_cpu
.
dims
());
param
.
X
=
&
x
;
param
.
SizeTensor
=
size_tensor
;
param
.
SizeTensor
.
emplace_back
(
reinterpret_cast
<
const
Tensor
*>
(
&
size_tensor
[
0
]));
param
.
SizeTensor
.
emplace_back
(
reinterpret_cast
<
const
Tensor
*>
(
&
size_tensor
[
1
]));
param
.
Scale
=
&
input_scale
;
param
.
OutSize
=
&
osz
;
param
.
Out
=
&
out
;
...
...
lite/operators/op_params.h
浏览文件 @
1f075a8b
...
...
@@ -95,7 +95,7 @@ struct InterpolateParam {
lite
::
Tensor
*
OutSize
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
const
lite
::
Tensor
*>
SizeTensor
;
lite
::
Tensor
*
Scale
;
lite
::
Tensor
*
Scale
{}
;
float
scale
{
0.
f
};
int
out_h
{
-
1
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录