Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
518a87ef
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
518a87ef
编写于
11月 13, 2019
作者:
L
liu zhengxi
提交者:
GitHub
11月 13, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update the ops to fluid (#2406)
align the lite nearest, bilinear op to fluid on arm and cuda
上级
f4e06650
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
635 addition
and
88 deletion
+635
-88
lite/backends/arm/math/interpolate.cc
lite/backends/arm/math/interpolate.cc
+53
-12
lite/backends/arm/math/interpolate.h
lite/backends/arm/math/interpolate.h
+3
-2
lite/kernels/arm/interpolate_compute.cc
lite/kernels/arm/interpolate_compute.cc
+12
-2
lite/kernels/cuda/bilinear_interp_compute.cu
lite/kernels/cuda/bilinear_interp_compute.cu
+73
-12
lite/kernels/cuda/bilinear_interp_compute_test.cc
lite/kernels/cuda/bilinear_interp_compute_test.cc
+105
-0
lite/kernels/cuda/nearest_interp_compute.cu
lite/kernels/cuda/nearest_interp_compute.cu
+74
-13
lite/kernels/cuda/nearest_interp_compute_test.cc
lite/kernels/cuda/nearest_interp_compute_test.cc
+105
-0
lite/operators/interpolate_op.cc
lite/operators/interpolate_op.cc
+46
-9
lite/operators/op_params.h
lite/operators/op_params.h
+3
-0
lite/tests/kernels/bilinear_interp_compute_test.cc
lite/tests/kernels/bilinear_interp_compute_test.cc
+79
-21
lite/tests/kernels/nearest_interp_compute_test.cc
lite/tests/kernels/nearest_interp_compute_test.cc
+76
-8
lite/tests/kernels/shuffle_channel_compute_test.cc
lite/tests/kernels/shuffle_channel_compute_test.cc
+6
-9
未找到文件。
lite/backends/arm/math/interpolate.cc
浏览文件 @
518a87ef
...
...
@@ -22,6 +22,28 @@ namespace lite {
namespace
arm
{
namespace
math
{
inline
std
::
vector
<
int
>
get_new_shape
(
std
::
vector
<
const
lite
::
Tensor
*>
list_new_shape_tensor
)
{
// get tensor from
std
::
vector
<
int
>
vec_new_shape
;
for
(
size_t
i
=
0
;
i
<
list_new_shape_tensor
.
size
();
++
i
)
{
auto
tensor
=
list_new_shape_tensor
[
i
];
vec_new_shape
.
push_back
(
static_cast
<
int32_t
>
(
*
tensor
->
data
<
int32_t
>
()));
}
return
vec_new_shape
;
}
template
<
typename
T
>
inline
std
::
vector
<
T
>
get_new_data_from_tensor
(
const
Tensor
*
new_data_tensor
)
{
std
::
vector
<
T
>
vec_new_data
;
auto
*
new_data
=
new_data_tensor
->
data
<
T
>
();
lite
::
Tensor
cpu_starts_tensor
;
vec_new_data
=
std
::
vector
<
T
>
(
new_data
,
new_data
+
new_data_tensor
->
dims
().
production
());
return
vec_new_data
;
}
// The following function bilinear_interp is partially base on
// https://github.com/Tencent/ncnn/blob/master/src/layer/arm/interp_arm.cpp
// Tencent is pleased to support the open source community by making ncnn
...
...
@@ -472,33 +494,52 @@ void nearest_interp(const float* src,
void
interpolate
(
lite
::
Tensor
*
X
,
lite
::
Tensor
*
OutSize
,
std
::
vector
<
const
lite
::
Tensor
*>
SizeTensor
,
lite
::
Tensor
*
Scale
,
lite
::
Tensor
*
Out
,
int
out_height
,
int
out_width
,
float
height_scale
,
float
width_scale
,
float
scale
,
bool
with_align
,
std
::
string
interpolate_type
)
{
int
in_h
=
X
->
dims
()[
2
];
int
in_w
=
X
->
dims
()[
3
];
if
(
SizeTensor
.
size
()
>
0
)
{
auto
new_size
=
get_new_shape
(
SizeTensor
);
out_height
=
new_size
[
0
];
out_width
=
new_size
[
1
];
}
else
{
auto
scale_tensor
=
Scale
;
if
(
scale_tensor
!=
nullptr
)
{
auto
scale_data
=
get_new_data_from_tensor
<
float
>
(
scale_tensor
);
scale
=
scale_data
[
0
];
}
if
(
scale
>
0
)
{
out_height
=
static_cast
<
int
>
(
in_h
*
scale
);
out_width
=
static_cast
<
int
>
(
in_w
*
scale
);
}
auto
out_size
=
OutSize
;
if
(
out_size
!=
nullptr
)
{
auto
out_size_data
=
get_new_data_from_tensor
<
float
>
(
out_size
);
out_height
=
static_cast
<
int
>
(
out_size_data
[
0
]);
out_width
=
static_cast
<
int
>
(
out_size_data
[
1
]);
}
}
float
height_scale
=
scale
;
float
width_scale
=
scale
;
if
(
out_width
>
0
&&
out_height
>
0
)
{
height_scale
=
static_cast
<
float
>
(
out_height
/
X
->
dims
()[
2
]);
width_scale
=
static_cast
<
float
>
(
out_width
/
X
->
dims
()[
3
]);
}
if
(
OutSize
!=
nullptr
)
{
auto
OutSize_data
=
OutSize
->
data
<
int
>
();
int
h_out
=
OutSize_data
[
0
];
// HW
int
w_out
=
OutSize_data
[
1
];
// HW
int
num_cout
=
Out
->
dims
()[
0
];
int
c_cout
=
Out
->
dims
()[
1
];
Out
->
Resize
({
num_cout
,
c_cout
,
h_out
,
w_out
});
}
int
num_cout
=
X
->
dims
()[
0
];
int
c_cout
=
X
->
dims
()[
1
];
Out
->
Resize
({
num_cout
,
c_cout
,
out_height
,
out_width
});
float
*
dout
=
Out
->
mutable_data
<
float
>
();
const
float
*
din
=
X
->
data
<
float
>
();
int
out_num
=
Out
->
dims
()[
0
];
int
out_c
=
Out
->
dims
()[
1
];
int
count
=
out_num
*
out_c
;
int
in_h
=
X
->
dims
()[
2
];
int
in_w
=
X
->
dims
()[
3
];
int
out_h
=
Out
->
dims
()[
2
];
int
out_w
=
Out
->
dims
()[
3
];
int
spatial_in
=
in_h
*
in_w
;
...
...
lite/backends/arm/math/interpolate.h
浏览文件 @
518a87ef
...
...
@@ -44,11 +44,12 @@ void nearest_interp(const float* src,
void
interpolate
(
lite
::
Tensor
*
X
,
lite
::
Tensor
*
OutSize
,
std
::
vector
<
const
lite
::
Tensor
*>
SizeTensor
,
lite
::
Tensor
*
Scale
,
lite
::
Tensor
*
Out
,
int
out_height
,
int
out_width
,
float
height_scale
,
float
width_scale
,
float
scale
,
bool
with_align
,
std
::
string
interpolate_type
);
...
...
lite/kernels/arm/interpolate_compute.cc
浏览文件 @
518a87ef
...
...
@@ -28,6 +28,8 @@ void BilinearInterpCompute::Run() {
auto
&
param
=
Param
<
operators
::
InterpolateParam
>
();
lite
::
Tensor
*
X
=
param
.
X
;
lite
::
Tensor
*
OutSize
=
param
.
OutSize
;
auto
SizeTensor
=
param
.
SizeTensor
;
auto
Scale
=
param
.
Scale
;
lite
::
Tensor
*
Out
=
param
.
Out
;
float
scale
=
param
.
scale
;
int
out_w
=
param
.
out_w
;
...
...
@@ -36,11 +38,12 @@ void BilinearInterpCompute::Run() {
std
::
string
interp_method
=
"Bilinear"
;
lite
::
arm
::
math
::
interpolate
(
X
,
OutSize
,
SizeTensor
,
Scale
,
Out
,
out_h
,
out_w
,
scale
,
scale
,
align_corners
,
interp_method
);
}
...
...
@@ -49,6 +52,8 @@ void NearestInterpCompute::Run() {
auto
&
param
=
Param
<
operators
::
InterpolateParam
>
();
lite
::
Tensor
*
X
=
param
.
X
;
lite
::
Tensor
*
OutSize
=
param
.
OutSize
;
auto
SizeTensor
=
param
.
SizeTensor
;
auto
Scale
=
param
.
Scale
;
lite
::
Tensor
*
Out
=
param
.
Out
;
float
scale
=
param
.
scale
;
int
out_w
=
param
.
out_w
;
...
...
@@ -57,11 +62,12 @@ void NearestInterpCompute::Run() {
std
::
string
interp_method
=
"Nearest"
;
lite
::
arm
::
math
::
interpolate
(
X
,
OutSize
,
SizeTensor
,
Scale
,
Out
,
out_h
,
out_w
,
scale
,
scale
,
align_corners
,
interp_method
);
}
...
...
@@ -79,6 +85,8 @@ REGISTER_LITE_KERNEL(bilinear_interp,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"OutSize"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"SizeTensor"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Scale"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
...
...
@@ -90,5 +98,7 @@ REGISTER_LITE_KERNEL(nearest_interp,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"OutSize"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"SizeTensor"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Scale"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
lite/kernels/cuda/bilinear_interp_compute.cu
浏览文件 @
518a87ef
...
...
@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/bilinear_interp_compute.h"
...
...
@@ -20,6 +21,43 @@ namespace kernels {
namespace
cuda
{
using
Tensor
=
lite
::
Tensor
;
inline
std
::
vector
<
int
>
get_new_shape
(
std
::
vector
<
const
lite
::
Tensor
*>
list_new_shape_tensor
)
{
// get tensor from
std
::
vector
<
int
>
vec_new_shape
;
for
(
size_t
i
=
0
;
i
<
list_new_shape_tensor
.
size
();
++
i
)
{
auto
tensor
=
list_new_shape_tensor
[
i
];
lite
::
Tensor
temp
;
auto
temp_data
=
temp
.
mutable_data
<
int32_t
>
();
auto
tensor_data
=
tensor
->
data
<
int32_t
>
(
TARGET
(
kCUDA
));
cudaMemcpy
(
temp_data
,
tensor_data
,
tensor
->
dims
().
production
()
*
sizeof
(
float
),
cudaMemcpyDeviceToHost
);
vec_new_shape
.
push_back
(
static_cast
<
int32_t
>
(
*
temp_data
));
}
return
vec_new_shape
;
}
template
<
typename
T
>
inline
std
::
vector
<
T
>
get_new_data_from_tensor
(
const
Tensor
*
new_data_tensor
)
{
std
::
vector
<
T
>
vec_new_data
;
auto
*
new_data
=
new_data_tensor
->
data
<
T
>
(
kCUDA
);
lite
::
Tensor
cpu_starts_tensor
;
auto
cpu_starts_tensor_data
=
cpu_starts_tensor
.
mutable_data
<
T
>
();
cudaMemcpy
(
cpu_starts_tensor_data
,
new_data
,
new_data_tensor
->
dims
().
production
()
*
sizeof
(
T
),
cudaMemcpyDeviceToHost
);
auto
new_data_
=
cpu_starts_tensor
.
data
<
T
>
();
vec_new_data
=
std
::
vector
<
T
>
(
new_data_
,
new_data_
+
new_data_tensor
->
dims
().
production
());
return
vec_new_data
;
}
template
<
typename
T
>
__global__
void
BilinearInterp
(
const
T
*
in
,
const
size_t
in_img_h
,
...
...
@@ -103,19 +141,34 @@ void BilinearInterpCompute::Run() {
int
out_w
=
param
.
out_w
;
float
scale
=
param
.
scale
;
bool
align_corners
=
param
.
align_corners
;
if
(
scale
>
0
)
{
out_h
=
static_cast
<
int
>
(
in_h
*
scale
);
out_w
=
static_cast
<
int
>
(
in_w
*
scale
);
}
auto
align_mode
=
param
.
align_mode
;
auto
list_new_shape_tensor
=
param
.
SizeTensor
;
if
(
list_new_shape_tensor
.
size
()
>
0
)
{
// have size tensor
auto
new_size
=
get_new_shape
(
list_new_shape_tensor
);
out_h
=
new_size
[
0
];
out_w
=
new_size
[
1
];
}
else
{
auto
scale_tensor
=
param
.
Scale
;
if
(
scale_tensor
!=
nullptr
)
{
auto
scale_data
=
get_new_data_from_tensor
<
float
>
(
scale_tensor
);
scale
=
scale_data
[
0
];
}
if
(
scale
>
0
)
{
out_h
=
static_cast
<
int
>
(
in_h
*
scale
);
out_w
=
static_cast
<
int
>
(
in_w
*
scale
);
}
if
(
out_size
!=
nullptr
)
{
Tensor
sizes
;
float
*
size_data
=
sizes
.
mutable_data
<
float
>
();
float
*
outsize_data
=
out_size
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
cudaMemcpy
(
size_data
,
outsize_data
,
sizeof
(
float
)
*
2
,
cudaMemcpyDeviceToHost
);
out_h
=
static_cast
<
int
>
(
size_data
[
0
]);
out_w
=
static_cast
<
int
>
(
size_data
[
1
]);
if
(
out_size
!=
nullptr
)
{
lite
::
Tensor
sizes
;
float
*
size_data
=
sizes
.
mutable_data
<
float
>
();
float
*
outsize_data
=
out_size
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
cudaMemcpy
(
size_data
,
outsize_data
,
sizeof
(
float
)
*
2
,
cudaMemcpyDeviceToHost
);
out_h
=
static_cast
<
int
>
(
size_data
[
0
]);
out_w
=
static_cast
<
int
>
(
size_data
[
1
]);
}
}
auto
output_data
=
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
...
...
@@ -188,6 +241,14 @@ REGISTER_LITE_KERNEL(bilinear_interp,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindInput
(
"SizeTensor"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindInput
(
"Scale"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
...
...
lite/kernels/cuda/bilinear_interp_compute_test.cc
浏览文件 @
518a87ef
...
...
@@ -16,6 +16,7 @@
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
namespace
paddle
{
namespace
lite
{
...
...
@@ -98,6 +99,110 @@ TEST(bilinear_interp, normal) {
}
}
TEST
(
bilinear_interp
,
update
)
{
BilinearInterpCompute
bilinear_interp_kernel
;
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
operators
::
InterpolateParam
param
;
std
::
vector
<
Tensor
*>
size_tensor
(
2
),
size_tensor_cpu
(
2
),
size_tensor_ref
(
2
);
Tensor
x
,
input_scale
,
osz
,
out
;
Tensor
x_cpu
,
input_scale_cpu
,
osz_cpu
,
out_cpu
;
Tensor
x_ref
,
size_tensor_ref
,
input_scale_ref
,
osz_ref
,
out_ref
;
int
n
=
1
,
c
=
1
,
in_h
=
3
,
in_w
=
3
;
int
out_h
=
6
,
out_w
=
6
;
float
scale
=
2.0
;
param
.
out_h
=
out_h
;
param
.
out_w
=
out_w
;
param
.
scale
=
scale
;
param
.
align_corners
=
false
;
param
.
align_mode
=
0
;
x
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor
[
0
]
->
Resize
({
1
});
size_tensor
[
1
]
->
Resize
({
1
});
input_scale
.
Resize
({
1
});
osz
.
Resize
({
2
});
out
.
Resize
({
n
,
c
,
out_h
,
out_w
});
x_cpu
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor_cpu
[
0
]
->
Resize
({
1
});
size_tensor_cpu
[
1
]
->
Resize
({
1
});
input_scale_cpu
.
Resize
({
1
});
osz_cpu
.
Resize
({
2
});
out_cpu
.
Resize
({
n
,
c
,
out_h
,
out_w
});
x_ref
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor_ref
[
0
]
->
Resize
({
1
});
size_tensor_ref
[
1
]
->
Resize
({
1
});
input_scale_ref
.
Resize
({
1
});
osz_ref
.
Resize
({
2
});
out_ref
.
Resize
({
n
,
c
,
out_h
,
out_w
});
auto
*
out_data
=
out
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
float
*
x_cpu_data
=
x_cpu
.
mutable_data
<
float
>
();
float
*
size_tensor0_cpu_data
=
size_tensor_cpu
[
0
]
->
mutable_data
<
float
>
();
float
*
size_tensor1_cpu_data
=
size_tensor_cpu
[
1
]
->
mutable_data
<
float
>
();
float
*
input_scale_cpu_data
=
input_scale_cpu
.
mutable_data
<
float
>
();
float
*
osz_cpu_data
=
osz_cpu
.
mutable_data
<
float
>
();
float
*
out_cpu_data
=
out_cpu
.
mutable_data
<
float
>
();
float
*
x_ref_data
=
x_ref
.
mutable_data
<
float
>
();
float
*
size_tensor0_ref_data
=
size_tensor_ref
[
0
]
->
mutable_data
<
float
>
();
float
*
size_tensor1_ref_data
=
size_tensor_ref
[
1
]
->
mutable_data
<
float
>
();
float
*
input_scale_ref_data
=
input_scale_ref
.
mutable_data
<
float
>
();
float
*
osz_ref_data
=
osz_ref
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_cpu
.
numel
();
++
i
)
{
x_cpu_data
[
i
]
=
i
+
5.0
;
x_ref_data
[
i
]
=
i
+
5.0
;
}
osz_cpu_data
[
0
]
=
out_h
;
osz_cpu_data
[
1
]
=
out_w
;
size_tensor0_cpu_data
[
0
]
=
out_h
;
size_tensor1_cpu_data
[
0
]
=
out_w
;
input_scale_cpu_data
[
0
]
=
scale
;
osz_ref_data
[
0
]
=
out_h
;
osz_ref_data
[
1
]
=
out_w
;
size_tensor0_ref_data
[
0
]
=
out_h
;
size_tensor1_ref_data
[
0
]
=
out_w
;
input_scale_ref_data
[
0
]
=
scale
;
x
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_cpu_data
,
x_cpu
.
dims
());
size_tensor
[
0
]
->
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor0_cpu_data
,
{
1
});
size_tensor
[
1
]
->
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor1_cpu_data
,
{
1
});
input_scale
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
input_scale_cpu_data
,
{
1
});
osz
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
osz_cpu_data
,
osz_cpu
.
dims
());
param
.
X
=
&
x
;
param
.
SizeTensor
=
size_tensor
;
param
.
Scale
=
&
input_scale
;
param
.
OutSize
=
&
osz
;
param
.
Out
=
&
out
;
bilinear_interp_kernel
.
SetParam
(
param
);
cudaStream_t
stream
;
cudaStreamCreate
(
&
stream
);
context
.
SetExecStream
(
stream
);
bilinear_interp_kernel
.
SetContext
(
std
::
move
(
ctx
));
bilinear_interp_kernel
.
Launch
();
cudaDeviceSynchronize
();
CopySync
<
TARGET
(
kCUDA
)
>
(
out_cpu_data
,
out_data
,
sizeof
(
float
)
*
out
.
numel
(),
IoDirection
::
DtoH
);
for
(
int
i
=
0
;
i
<
out
.
numel
();
i
++
)
{
LOG
(
INFO
)
<<
out_cpu_data
[
i
];
}
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
...
...
lite/kernels/cuda/nearest_interp_compute.cu
浏览文件 @
518a87ef
...
...
@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/nearest_interp_compute.h"
...
...
@@ -20,6 +21,43 @@ namespace kernels {
namespace
cuda
{
using
Tensor
=
lite
::
Tensor
;
inline
std
::
vector
<
int
>
get_new_shape
(
std
::
vector
<
const
lite
::
Tensor
*>
list_new_shape_tensor
)
{
// get tensor from
std
::
vector
<
int
>
vec_new_shape
;
for
(
size_t
i
=
0
;
i
<
list_new_shape_tensor
.
size
();
++
i
)
{
auto
tensor
=
list_new_shape_tensor
[
i
];
lite
::
Tensor
temp
;
auto
temp_data
=
temp
.
mutable_data
<
int32_t
>
();
auto
tensor_data
=
tensor
->
data
<
int32_t
>
(
TARGET
(
kCUDA
));
cudaMemcpy
(
temp_data
,
tensor_data
,
tensor
->
dims
().
production
()
*
sizeof
(
float
),
cudaMemcpyDeviceToHost
);
vec_new_shape
.
push_back
(
static_cast
<
int32_t
>
(
*
temp_data
));
}
return
vec_new_shape
;
}
template
<
typename
T
>
inline
std
::
vector
<
T
>
get_new_data_from_tensor
(
const
Tensor
*
new_data_tensor
)
{
std
::
vector
<
T
>
vec_new_data
;
auto
*
new_data
=
new_data_tensor
->
data
<
T
>
(
kCUDA
);
lite
::
Tensor
cpu_starts_tensor
;
auto
cpu_starts_tensor_data
=
cpu_starts_tensor
.
mutable_data
<
T
>
();
cudaMemcpy
(
cpu_starts_tensor_data
,
new_data
,
new_data_tensor
->
dims
().
production
()
*
sizeof
(
T
),
cudaMemcpyDeviceToHost
);
auto
new_data_
=
cpu_starts_tensor
.
data
<
T
>
();
vec_new_data
=
std
::
vector
<
T
>
(
new_data_
,
new_data_
+
new_data_tensor
->
dims
().
production
());
return
vec_new_data
;
}
__global__
void
KeNearestNeighborInterp
(
const
float
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
...
...
@@ -79,19 +117,34 @@ void NearestInterpCompute::Run() {
int
out_w
=
param
.
out_w
;
float
scale
=
param
.
scale
;
bool
align_corners
=
param
.
align_corners
;
if
(
scale
>
0
)
{
out_h
=
static_cast
<
int
>
(
in_h
*
scale
);
out_w
=
static_cast
<
int
>
(
in_w
*
scale
);
}
if
(
out_size
!=
nullptr
)
{
Tensor
sizes
;
float
*
size_data
=
sizes
.
mutable_data
<
float
>
();
float
*
outsize_data
=
out_size
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
cudaMemcpy
(
size_data
,
outsize_data
,
sizeof
(
float
)
*
2
,
cudaMemcpyDeviceToHost
);
out_h
=
static_cast
<
int
>
(
size_data
[
0
]);
out_w
=
static_cast
<
int
>
(
size_data
[
1
]);
auto
align_mode
=
param
.
align_mode
;
auto
list_new_shape_tensor
=
param
.
SizeTensor
;
if
(
list_new_shape_tensor
.
size
()
>
0
)
{
// have size tensor
auto
new_size
=
get_new_shape
(
list_new_shape_tensor
);
out_h
=
new_size
[
0
];
out_w
=
new_size
[
1
];
}
else
{
auto
scale_tensor
=
param
.
Scale
;
if
(
scale_tensor
!=
nullptr
)
{
auto
scale_data
=
get_new_data_from_tensor
<
float
>
(
scale_tensor
);
scale
=
scale_data
[
0
];
}
if
(
scale
>
0
)
{
out_h
=
static_cast
<
int
>
(
in_h
*
scale
);
out_w
=
static_cast
<
int
>
(
in_w
*
scale
);
}
if
(
out_size
!=
nullptr
)
{
lite
::
Tensor
sizes
;
float
*
size_data
=
sizes
.
mutable_data
<
float
>
();
float
*
outsize_data
=
out_size
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
cudaMemcpy
(
size_data
,
outsize_data
,
sizeof
(
float
)
*
2
,
cudaMemcpyDeviceToHost
);
out_h
=
static_cast
<
int
>
(
size_data
[
0
]);
out_w
=
static_cast
<
int
>
(
size_data
[
1
]);
}
}
auto
output_data
=
output
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
...
...
@@ -162,6 +215,14 @@ REGISTER_LITE_KERNEL(nearest_interp,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindInput
(
"SizeTensor"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindInput
(
"Scale"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
...
...
lite/kernels/cuda/nearest_interp_compute_test.cc
浏览文件 @
518a87ef
...
...
@@ -16,6 +16,7 @@
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
namespace
paddle
{
namespace
lite
{
...
...
@@ -143,6 +144,110 @@ TEST(nearest_interp, normal) {
}
}
TEST
(
nearest_interp
,
update
)
{
NearestInterpCompute
nearest_interp_kernel
;
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
operators
::
InterpolateParam
param
;
std
::
vector
<
Tensor
*>
size_tensor
(
2
),
size_tensor_cpu
(
2
),
size_tensor_ref
(
2
);
Tensor
x
,
input_scale
,
osz
,
out
;
Tensor
x_cpu
,
input_scale_cpu
,
osz_cpu
,
out_cpu
;
Tensor
x_ref
,
size_tensor_ref
,
input_scale_ref
,
osz_ref
,
out_ref
;
int
n
=
1
,
c
=
3
,
in_h
=
40
,
in_w
=
40
;
int
out_h
=
80
,
out_w
=
80
;
float
scale
=
2.0
;
param
.
out_h
=
out_h
;
param
.
out_w
=
out_w
;
param
.
scale
=
scale
;
param
.
align_corners
=
false
;
param
.
align_mode
=
0
;
x
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor
[
0
]
->
Resize
({
1
});
size_tensor
[
1
]
->
Resize
({
1
});
input_scale
.
Resize
({
1
});
osz
.
Resize
({
2
});
out
.
Resize
({
n
,
c
,
out_h
,
out_w
});
x_cpu
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor_cpu
[
0
]
->
Resize
({
1
});
size_tensor_cpu
[
1
]
->
Resize
({
1
});
input_scale_cpu
.
Resize
({
1
});
osz_cpu
.
Resize
({
2
});
out_cpu
.
Resize
({
n
,
c
,
out_h
,
out_w
});
x_ref
.
Resize
({
n
,
c
,
in_h
,
in_w
});
size_tensor_ref
[
0
]
->
Resize
({
1
});
size_tensor_ref
[
1
]
->
Resize
({
1
});
input_scale_ref
.
Resize
({
1
});
osz_ref
.
Resize
({
2
});
out_ref
.
Resize
({
n
,
c
,
out_h
,
out_w
});
auto
*
out_data
=
out
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
float
*
x_cpu_data
=
x_cpu
.
mutable_data
<
float
>
();
float
*
size_tensor0_cpu_data
=
size_tensor_cpu
[
0
]
->
mutable_data
<
float
>
();
float
*
size_tensor1_cpu_data
=
size_tensor_cpu
[
1
]
->
mutable_data
<
float
>
();
float
*
input_scale_cpu_data
=
input_scale_cpu
.
mutable_data
<
float
>
();
float
*
osz_cpu_data
=
osz_cpu
.
mutable_data
<
float
>
();
float
*
out_cpu_data
=
out_cpu
.
mutable_data
<
float
>
();
float
*
x_ref_data
=
x_ref
.
mutable_data
<
float
>
();
float
*
size_tensor0_ref_data
=
size_tensor_ref
[
0
]
->
mutable_data
<
float
>
();
float
*
size_tensor1_ref_data
=
size_tensor_ref
[
1
]
->
mutable_data
<
float
>
();
float
*
input_scale_ref_data
=
input_scale_ref
.
mutable_data
<
float
>
();
float
*
osz_ref_data
=
osz_ref
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
x_cpu
.
numel
();
++
i
)
{
x_cpu_data
[
i
]
=
i
+
5.0
;
x_ref_data
[
i
]
=
i
+
5.0
;
}
osz_cpu_data
[
0
]
=
out_h
;
osz_cpu_data
[
1
]
=
out_w
;
size_tensor0_cpu_data
[
0
]
=
out_h
;
size_tensor1_cpu_data
[
0
]
=
out_w
;
input_scale_cpu_data
[
0
]
=
scale
;
osz_ref_data
[
0
]
=
out_h
;
osz_ref_data
[
1
]
=
out_w
;
size_tensor0_ref_data
[
0
]
=
out_h
;
size_tensor1_ref_data
[
0
]
=
out_w
;
input_scale_ref_data
[
0
]
=
scale
;
x
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_cpu_data
,
x_cpu
.
dims
());
size_tensor
[
0
]
->
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor0_cpu_data
,
{
1
});
size_tensor
[
1
]
->
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
size_tensor1_cpu_data
,
{
1
});
input_scale
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
input_scale_cpu_data
,
{
1
});
osz
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
osz_cpu_data
,
osz_cpu
.
dims
());
param
.
X
=
&
x
;
param
.
SizeTensor
=
size_tensor
;
param
.
Scale
=
&
input_scale
;
param
.
OutSize
=
&
osz
;
param
.
Out
=
&
out
;
nearest_interp_kernel
.
SetParam
(
param
);
cudaStream_t
stream
;
cudaStreamCreate
(
&
stream
);
context
.
SetExecStream
(
stream
);
nearest_interp_kernel
.
SetContext
(
std
::
move
(
ctx
));
nearest_interp_kernel
.
Launch
();
cudaDeviceSynchronize
();
CopySync
<
TARGET
(
kCUDA
)
>
(
out_cpu_data
,
out_data
,
sizeof
(
float
)
*
out
.
numel
(),
IoDirection
::
DtoH
);
for
(
int
i
=
0
;
i
<
out
.
numel
();
i
++
)
{
LOG
(
INFO
)
<<
out_cpu_data
[
i
];
}
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
...
...
lite/operators/interpolate_op.cc
浏览文件 @
518a87ef
...
...
@@ -45,23 +45,42 @@ bool InterpolateOp::InferShape() const {
int
out_h
;
int
out_w
;
if
(
OutSize
!=
nullptr
)
{
auto
outsize_data
=
OutSize
->
data
<
int
>
();
int
h_out
=
outsize_data
[
0
];
// HW
int
w_out
=
outsize_data
[
1
];
// HW
param_
.
Out
->
Resize
({
n
,
c
,
h_out
,
w_out
});
auto
SizeTensor
=
param_
.
SizeTensor
;
if
(
!
SizeTensor
.
empty
())
{
CHECK
(
SizeTensor
.
size
()
==
2
)
<<
"Input(SizeTensor)'size of Op(interpolate) must be 2. "
"Attr(out_shape)'s length must be 2 for 4-D input tensor."
;
out_h
=
param_
.
out_h
;
out_w
=
param_
.
out_w
;
param_
.
Out
->
Resize
({
n
,
c
,
out_h
,
out_w
});
return
true
;
}
auto
Scale
=
param_
.
Scale
;
if
(
Scale
)
{
auto
scale_dims
=
Scale
->
dims
();
CHECK
(
scale_dims
.
size
()
==
1
)
<<
"Scale's dimension size must be 1."
;
out_h
=
-
1
;
out_w
=
-
1
;
}
else
{
if
(
0
>=
param_
.
out_h
&&
0
>=
param_
.
out_w
)
{
out_h
=
h
*
param_
.
scale
;
out_w
=
w
*
param_
.
scale
;
auto
scale
=
param_
.
scale
;
if
(
scale
>
0
)
{
out_h
=
static_cast
<
int
>
(
h
*
scale
);
out_w
=
static_cast
<
int
>
(
w
*
scale
);
out_h
=
out_h
>
0
?
out_h
:
-
1
;
out_w
=
out_w
>
0
?
out_w
:
-
1
;
}
else
{
out_h
=
param_
.
out_h
;
out_w
=
param_
.
out_w
;
}
param_
.
Out
->
Resize
({
n
,
c
,
out_h
,
out_w
});
}
if
(
OutSize
!=
nullptr
)
{
auto
out_lod
=
param_
.
Out
->
mutable_lod
();
*
out_lod
=
param_
.
X
->
lod
();
}
param_
.
Out
->
Resize
({
n
,
c
,
out_h
,
out_w
});
return
true
;
}
...
...
@@ -76,6 +95,24 @@ bool InterpolateOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
}
else
{
param_
.
OutSize
=
nullptr
;
}
if
(
op_desc
.
HasInput
(
"SizeTensor"
))
{
auto
size_tensor
=
op_desc
.
Input
(
"SizeTensor"
);
for
(
auto
var
:
size_tensor
)
{
param_
.
SizeTensor
.
push_back
(
scope
->
FindVar
(
var
)
->
GetMutable
<
lite
::
Tensor
>
());
}
}
if
(
op_desc
.
HasInput
(
"Scale"
))
{
auto
scale_var_names
=
op_desc
.
Input
(
"Scale"
);
if
(
scale_var_names
.
size
()
>
0
)
{
param_
.
Scale
=
scope
->
FindVar
(
scale_var_names
.
front
())
->
GetMutable
<
lite
::
Tensor
>
();
}
}
else
{
param_
.
Scale
=
nullptr
;
}
auto
Out
=
op_desc
.
Output
(
"Out"
).
front
();
param_
.
X
=
scope
->
FindVar
(
X
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
Out
=
scope
->
FindVar
(
Out
)
->
GetMutable
<
lite
::
Tensor
>
();
...
...
lite/operators/op_params.h
浏览文件 @
518a87ef
...
...
@@ -94,6 +94,8 @@ struct InterpolateParam {
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
OutSize
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
const
lite
::
Tensor
*>
SizeTensor
;
lite
::
Tensor
*
Scale
;
float
scale
{
0.
f
};
int
out_h
{
-
1
};
...
...
@@ -101,6 +103,7 @@ struct InterpolateParam {
bool
align_corners
{
true
};
int
align_mode
{
1
};
std
::
string
interp_method
{
"Nearest"
};
DataLayoutType
data_layout
{
DATALAYOUT
(
kNCHW
)};
};
// For Mul Op
...
...
lite/tests/kernels/bilinear_interp_compute_test.cc
浏览文件 @
518a87ef
...
...
@@ -22,6 +22,27 @@
namespace
paddle
{
namespace
lite
{
inline
std
::
vector
<
int
>
get_new_shape
(
std
::
vector
<
const
lite
::
Tensor
*>
list_new_shape_tensor
)
{
// get tensor from
std
::
vector
<
int
>
vec_new_shape
;
for
(
size_t
i
=
0
;
i
<
list_new_shape_tensor
.
size
();
++
i
)
{
auto
tensor
=
list_new_shape_tensor
[
i
];
vec_new_shape
.
push_back
(
static_cast
<
int32_t
>
(
*
(
tensor
->
data
<
int32_t
>
())));
}
return
vec_new_shape
;
}
template
<
typename
T
>
inline
std
::
vector
<
T
>
get_new_data_from_tensor
(
const
Tensor
*
new_data_tensor
)
{
std
::
vector
<
T
>
vec_new_data
;
auto
*
new_data
=
new_data_tensor
->
data
<
T
>
();
lite
::
Tensor
cpu_starts_tensor
;
vec_new_data
=
std
::
vector
<
T
>
(
new_data
,
new_data
+
new_data_tensor
->
dims
().
production
());
return
vec_new_data
;
}
template
<
typename
dtype
>
void
resize_bilinear_align
(
std
::
vector
<
const
lite
::
Tensor
*>
inputs
,
lite
::
Tensor
*
output
)
{
...
...
@@ -149,6 +170,9 @@ class BilinearInterpComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std
::
string
input0_
=
"X"
;
std
::
string
sizetensor0_
=
"SizeTensor0"
;
std
::
string
sizetensor1_
=
"SizeTensor1"
;
std
::
string
input_scale_
=
"Scale"
;
std
::
string
input1_
=
"OutSize"
;
std
::
string
output_
=
"Out"
;
...
...
@@ -162,6 +186,8 @@ class BilinearInterpComputeTester : public arena::TestCase {
std
::
string
interp_method_
=
"Bilinear"
;
DDim
_dims0_
{{
1
,
1
,
16
,
16
}};
DDim
_dims1_
{{
2
}};
DDim
sizetensor_dims_
{{
1
}};
DDim
scale_dims_
{{
1
}};
public:
BilinearInterpComputeTester
(
const
Place
&
place
,
...
...
@@ -190,33 +216,48 @@ class BilinearInterpComputeTester : public arena::TestCase {
if
(
outsize_height_
>
0
&&
outsize_width_
>
0
)
{
inputs
.
emplace_back
(
scope
->
FindTensor
(
input1_
));
}
std
::
vector
<
const
lite
::
Tensor
*>
SizeTensor
;
if
(
outsize_height_
>
0
&&
outsize_width_
>
0
)
{
SizeTensor
.
emplace_back
(
scope
->
FindTensor
(
sizetensor0_
));
SizeTensor
.
emplace_back
(
scope
->
FindTensor
(
sizetensor1_
));
}
const
lite
::
Tensor
*
input_scale
=
scope
->
FindTensor
(
input_scale_
);
float
scale
=
height_scale_
;
int
in_h
=
inputs
[
0
]
->
dims
()[
2
];
int
in_w
=
inputs
[
0
]
->
dims
()[
3
];
if
(
SizeTensor
.
size
()
>
0
)
{
auto
new_size
=
get_new_shape
(
SizeTensor
);
out_height_
=
new_size
[
0
];
out_width_
=
new_size
[
1
];
}
else
{
auto
scale_tensor
=
input_scale
;
if
(
scale_tensor
!=
nullptr
)
{
auto
scale_data
=
get_new_data_from_tensor
<
float
>
(
scale_tensor
);
scale
=
scale_data
[
0
];
}
if
(
scale
>
0
)
{
out_height_
=
static_cast
<
int
>
(
in_h
*
scale
);
out_width_
=
static_cast
<
int
>
(
in_w
*
scale
);
}
if
(
inputs
.
size
()
>
1
)
{
auto
out_size
=
inputs
[
1
];
auto
out_size_data
=
get_new_data_from_tensor
<
int
>
(
out_size
);
out_height_
=
out_size_data
[
0
];
out_width_
=
out_size_data
[
1
];
}
}
height_scale_
=
scale
;
width_scale_
=
scale
;
if
(
out_width_
!=
-
1
&&
out_height_
!=
-
1
)
{
height_scale_
=
static_cast
<
float
>
(
out_height_
/
inputs
[
0
]
->
dims
()[
2
]);
width_scale_
=
static_cast
<
float
>
(
out_width_
/
inputs
[
0
]
->
dims
()[
3
]);
}
auto
*
outputs
=
scope
->
NewTensor
(
output_
);
CHECK
(
outputs
);
if
(
inputs
.
size
()
>
1
)
{
auto
outsize_data
=
inputs
[
1
]
->
data
<
int
>
();
int
h_out
=
outsize_data
[
0
];
// HW
int
w_out
=
outsize_data
[
1
];
// HW
int
num_cout
=
inputs
[
0
]
->
dims
()[
0
];
int
c_cout
=
inputs
[
0
]
->
dims
()[
1
];
outputs
->
Resize
({
num_cout
,
c_cout
,
h_out
,
w_out
});
}
else
{
int
out_h
;
int
out_w
;
if
(
-
1
==
out_height_
&&
-
1
==
out_width_
)
{
out_h
=
inputs
[
0
]
->
dims
()[
2
]
*
height_scale_
;
out_w
=
inputs
[
0
]
->
dims
()[
3
]
*
width_scale_
;
}
else
{
out_h
=
out_height_
;
out_w
=
out_width_
;
}
outputs
->
Resize
(
{
inputs
[
0
]
->
dims
()[
0
],
inputs
[
0
]
->
dims
()[
1
],
out_h
,
out_w
});
}
int
num_cout
=
inputs
[
0
]
->
dims
()[
0
];
int
c_cout
=
inputs
[
0
]
->
dims
()[
1
];
outputs
->
Resize
({
num_cout
,
c_cout
,
out_height_
,
out_width_
});
if
(
align_corners_
)
{
resize_bilinear_align
<
float
>
(
inputs
,
outputs
);
}
else
{
...
...
@@ -229,6 +270,10 @@ class BilinearInterpComputeTester : public arena::TestCase {
op_desc
->
SetInput
(
"X"
,
{
input0_
});
if
(
outsize_height_
>
0
&&
outsize_width_
>
0
)
{
op_desc
->
SetInput
(
"OutSize"
,
{
input1_
});
op_desc
->
SetInput
(
"SizeTensor"
,
{
sizetensor0_
,
sizetensor1_
});
}
if
(
height_scale_
>
0
)
{
op_desc
->
SetInput
(
"Scale"
,
{
input_scale_
});
}
op_desc
->
SetOutput
(
"Out"
,
{
output_
});
op_desc
->
SetAttr
(
"scale"
,
height_scale_
);
...
...
@@ -250,6 +295,19 @@ class BilinearInterpComputeTester : public arena::TestCase {
data1
[
0
]
=
outsize_height_
;
data1
[
1
]
=
outsize_width_
;
SetCommonTensor
(
input1_
,
_dims1_
,
data1
.
data
());
std
::
vector
<
int
>
sizetensor_data
(
1
);
sizetensor_data
[
0
]
=
outsize_height_
;
SetCommonTensor
(
sizetensor0_
,
sizetensor_dims_
,
sizetensor_data
.
data
());
sizetensor_data
[
0
]
=
outsize_width_
;
SetCommonTensor
(
sizetensor1_
,
sizetensor_dims_
,
sizetensor_data
.
data
());
}
if
(
height_scale_
>
0
)
{
std
::
vector
<
float
>
scale_data
(
1
);
scale_data
[
0
]
=
height_scale_
;
SetCommonTensor
(
input_scale_
,
scale_dims_
,
scale_data
.
data
());
}
}
};
...
...
lite/tests/kernels/nearest_interp_compute_test.cc
浏览文件 @
518a87ef
...
...
@@ -22,6 +22,28 @@
namespace
paddle
{
namespace
lite
{
inline
std
::
vector
<
int
>
get_new_shape
(
const
std
::
vector
<
const
lite
::
Tensor
*>&
list_new_shape_tensor
)
{
// get tensor from
std
::
vector
<
int
>
vec_new_shape
;
for
(
size_t
i
=
0
;
i
<
list_new_shape_tensor
.
size
();
++
i
)
{
auto
tensor
=
list_new_shape_tensor
[
i
];
vec_new_shape
.
push_back
(
static_cast
<
int32_t
>
(
*
tensor
->
data
<
int32_t
>
()));
}
return
vec_new_shape
;
}
template
<
typename
T
>
inline
std
::
vector
<
T
>
get_new_data_from_tensor
(
const
Tensor
*
new_data_tensor
)
{
std
::
vector
<
T
>
vec_new_data
;
auto
*
new_data
=
new_data_tensor
->
data
<
T
>
();
lite
::
Tensor
cpu_starts_tensor
;
vec_new_data
=
std
::
vector
<
T
>
(
new_data
,
new_data
+
new_data_tensor
->
dims
().
production
());
return
vec_new_data
;
}
template
<
typename
dtype
>
void
resize_nearest_align
(
std
::
vector
<
const
lite
::
Tensor
*>
inputs
,
lite
::
Tensor
*
output
,
...
...
@@ -73,6 +95,9 @@ class NearestInterpComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std
::
string
input0_
=
"X"
;
std
::
string
sizetensor0_
=
"SizeTensor0"
;
std
::
string
sizetensor1_
=
"SizeTensor1"
;
std
::
string
input_scale_
=
"Scale"
;
std
::
string
input1_
=
"OutSize"
;
std
::
string
output_
=
"Out"
;
...
...
@@ -85,6 +110,8 @@ class NearestInterpComputeTester : public arena::TestCase {
DDim
dims_
{{
2
,
3
}};
DDim
_dims0_
{{
2
,
3
,
3
,
2
}};
DDim
_dims1_
{{
2
}};
DDim
sizetensor_dims_
{{
1
}};
DDim
scale_dims_
{{
1
}};
public:
NearestInterpComputeTester
(
const
Place
&
place
,
...
...
@@ -112,24 +139,54 @@ class NearestInterpComputeTester : public arena::TestCase {
inputs
.
emplace_back
(
scope
->
FindTensor
(
input0_
));
inputs
.
emplace_back
(
scope
->
FindTensor
(
input1_
));
auto
outsize_data
=
inputs
[
1
]
->
data
<
int
>
();
std
::
vector
<
const
lite
::
Tensor
*>
SizeTensor
(
2
);
SizeTensor
[
0
]
=
scope
->
FindTensor
(
sizetensor0_
);
SizeTensor
[
1
]
=
scope
->
FindTensor
(
sizetensor1_
);
const
lite
::
Tensor
*
input_scale
=
scope
->
FindTensor
(
input_scale_
);
float
scale
=
height_scale_
;
int
in_h
=
inputs
[
0
]
->
dims
()[
2
];
int
in_w
=
inputs
[
0
]
->
dims
()[
3
];
if
(
SizeTensor
.
size
()
>
0
)
{
auto
new_size
=
get_new_shape
(
SizeTensor
);
out_height_
=
new_size
[
0
];
out_width_
=
new_size
[
1
];
}
else
{
auto
scale_tensor
=
input_scale
;
if
(
scale_tensor
!=
nullptr
)
{
auto
scale_data
=
get_new_data_from_tensor
<
float
>
(
scale_tensor
);
scale
=
scale_data
[
0
];
}
if
(
scale
>
0
)
{
out_height_
=
static_cast
<
int
>
(
in_h
*
scale
);
out_width_
=
static_cast
<
int
>
(
in_w
*
scale
);
}
auto
out_size
=
inputs
[
1
];
if
(
out_size
!=
nullptr
)
{
auto
out_size_data
=
get_new_data_from_tensor
<
int
>
(
out_size
);
out_height_
=
out_size_data
[
0
];
out_width_
=
out_size_data
[
1
];
}
}
height_scale_
=
scale
;
width_scale_
=
scale
;
if
(
out_width_
!=
-
1
&&
out_height_
!=
-
1
)
{
height_scale_
=
static_cast
<
float
>
(
out_height_
/
inputs
[
0
]
->
dims
()[
2
]);
width_scale_
=
static_cast
<
float
>
(
out_width_
/
inputs
[
0
]
->
dims
()[
3
]);
}
if
(
inputs
.
size
()
>
1
)
{
int
h_out
=
outsize_data
[
0
];
// HW
int
w_out
=
outsize_data
[
1
];
// HW
int
num_cout
=
outputs
->
dims
()[
0
];
int
c_cout
=
outputs
->
dims
()[
1
];
outputs
->
Resize
({
num_cout
,
c_cout
,
h_out
,
w_out
});
}
int
num_cout
=
inputs
[
0
]
->
dims
()[
0
];
int
c_cout
=
inputs
[
0
]
->
dims
()[
1
];
outputs
->
Resize
({
num_cout
,
c_cout
,
out_height_
,
out_width_
});
resize_nearest_align
<
float
>
(
inputs
,
outputs
,
align_corners_
);
}
void
PrepareOpDesc
(
cpp
::
OpDesc
*
op_desc
)
{
op_desc
->
SetType
(
"nearest_interp"
);
op_desc
->
SetInput
(
"X"
,
{
input0_
});
op_desc
->
SetInput
(
"SizeTensor"
,
{
sizetensor0_
,
sizetensor1_
});
op_desc
->
SetInput
(
"Scale"
,
{
input_scale_
});
op_desc
->
SetInput
(
"OutSize"
,
{
input1_
});
op_desc
->
SetOutput
(
"Out"
,
{
output_
});
op_desc
->
SetAttr
(
"scale"
,
height_scale_
);
...
...
@@ -152,6 +209,17 @@ class NearestInterpComputeTester : public arena::TestCase {
SetCommonTensor
(
input0_
,
_dims0_
,
data0
.
data
());
SetCommonTensor
(
input1_
,
_dims1_
,
data1
.
data
());
std
::
vector
<
int
>
sizetensor_data
(
1
);
sizetensor_data
[
0
]
=
out_height_
;
SetCommonTensor
(
sizetensor0_
,
sizetensor_dims_
,
sizetensor_data
.
data
());
sizetensor_data
[
0
]
=
out_width_
;
SetCommonTensor
(
sizetensor1_
,
sizetensor_dims_
,
sizetensor_data
.
data
());
std
::
vector
<
float
>
scale_data
(
1
);
scale_data
[
0
]
=
height_scale_
;
SetCommonTensor
(
input_scale_
,
scale_dims_
,
scale_data
.
data
());
}
};
...
...
lite/tests/kernels/shuffle_channel_compute_test.cc
浏览文件 @
518a87ef
...
...
@@ -12,12 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO(zhengxi)
// shuffle_channel_test can pass on local compilation
// while on ci compilation, the test will be killed immediately.
/*
#include <gtest/gtest.h>
// TODO(FrostML): shaffle_channel cannot pass on CI, but ok in local machine.
// Open this.
/*#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
...
...
@@ -30,8 +27,8 @@ class ShuffleChannelComputeTester : public arena::TestCase {
// common attributes for this op.
std::string input_ = "X";
std::string output_ = "Out";
int group_ =
1
;
DDim dims_{{1
, 2
}};
int group_ =
4
;
DDim dims_{{1
0, 16, 4, 4
}};
public:
ShuffleChannelComputeTester(const Place& place,
...
...
@@ -87,7 +84,7 @@ class ShuffleChannelComputeTester : public arena::TestCase {
};
void test_shuffle_channel(Place place) {
for (int group : {
1, 2, 3
}) {
for (int group : {
4
}) {
std::unique_ptr<arena::TestCase> tester(
new ShuffleChannelComputeTester(place, "def", group));
arena::Arena arena(std::move(tester), place, 2e-5);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录