Commit cba5736f (unverified)
Authored Sep 12, 2019 by Wilber; committed via GitHub on Sep 12, 2019
add transpose kernel for cuda test=develop (#1997)
add transpose kernel for cuda
Parent: 83d4b0e8
Showing 7 changed files with 654 additions and 0 deletions (+654 −0)
lite/backends/cuda/math/CMakeLists.txt        +2   −0
lite/backends/cuda/math/transpose.cu          +191 −0
lite/backends/cuda/math/transpose.h           +44  −0
lite/kernels/cuda/CMakeLists.txt              +3   −0
lite/kernels/cuda/transpose_compute.cu        +86  −0
lite/kernels/cuda/transpose_compute.h         +38  −0
lite/kernels/cuda/transpose_compute_test.cc   +290 −0
lite/backends/cuda/math/CMakeLists.txt

@@ -5,6 +5,7 @@ endif()
 nv_library(cuda_activation SRCS activation.cu)
 nv_library(cuda_scale SRCS scale.cu)
 nv_library(cuda_type_trans SRCS type_trans.cu)
+nv_library(cuda_transpose SRCS transpose.cu)
 nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale cuda_type_trans)
@@ -14,6 +15,7 @@ set (
   cuda_activation
   cuda_scale
   cuda_type_trans
+  cuda_transpose
 )
 set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
lite/backends/cuda/math/transpose.cu (new file, mode 100644)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/cuda/math/transpose.h"
#include "lite/backends/cuda/math/utils.h"

namespace paddle {
namespace lite {
namespace cuda {
namespace math {

constexpr int kTileDim = 32;
constexpr int kBlockRows = 8;
constexpr int CUDA_NUM_THREADS = 128;

// Splits the original matrix into submatrices with size 32 * 32.
// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/
template <typename T>
__global__ void BatchTranspose2DCUDAKernel(const int N,
                                           const int H,
                                           const int W,
                                           const int dh,
                                           const int dw,
                                           const T* input,
                                           T* out) {
  __shared__ T tile[kTileDim][kTileDim + 1];  // plus 1 to prevent bank conflict.
  const int n = blockIdx.x / (dh * dw);
  const int k = blockIdx.x % (dh * dw);
  const int r = k / dw;
  const int c = k % dw;
  const int offset = n * H * W;
  int x = c * kTileDim + threadIdx.x;
  int y = r * kTileDim + threadIdx.y;
  if (x < W) {
    for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) {
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
      tile[threadIdx.y + i][threadIdx.x] = __ldg(input + offset + (y + i) * W + x);
#else
      tile[threadIdx.y + i][threadIdx.x] = input[offset + (y + i) * W + x];
#endif
    }
  }
  __syncthreads();
  x = r * kTileDim + threadIdx.x;
  y = c * kTileDim + threadIdx.y;
  if (x < H) {
    for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) {
      out[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i];
    }
  }
}

template <typename T>
void BatchTranspose2DCUDAImpl(const int N,
                              const int H,
                              const int W,
                              const T* input,
                              T* out,
                              CUDAContext* ctx) {
  const int dh = (H + kTileDim - 1) / kTileDim;
  const int dw = (W + kTileDim - 1) / kTileDim;
  BatchTranspose2DCUDAKernel<T><<<N * dh * dw,
                                  dim3(kTileDim, kBlockRows),
                                  0,
                                  ctx->exec_stream()>>>(
      N, H, W, dh, dw, input, out);
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}

#define TYPE_SPECIALIZED_CUDA_NCHW2NHWC(T)             \
  template <>                                          \
  void NCHW2NHWC<T>(const int N,                       \
                    const int C,                       \
                    const int HxW,                     \
                    const T* X,                        \
                    T* Y,                              \
                    CUDAContext* ctx) {                \
    BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, ctx); \
  }
TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC

#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T)             \
  template <>                                          \
  void NHWC2NCHW<T>(const int N,                       \
                    const int C,                       \
                    const int HxW,                     \
                    const T* X,                        \
                    T* Y,                              \
                    CUDAContext* ctx) {                \
    BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, ctx); \
  }
TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float)
#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW

template <typename T>
__global__ void TransposeCUDAKernel(const int size,
                                    const int ndim,
                                    const int* X_strides,
                                    const int* Y_dims,
                                    const T* X,
                                    T* Y) {
  const int Y_index = blockIdx.x * CUDA_NUM_THREADS + threadIdx.x;
  if (Y_index < size) {
    int X_index = 0;
    int v = Y_index;
#pragma unroll
    for (int i = ndim - 1; i >= 0; --i) {
      X_index += v % Y_dims[i] * X_strides[i];
      v /= Y_dims[i];
    }
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
    Y[Y_index] = __ldg(X + X_index);
#else
    Y[Y_index] = X[X_index];
#endif
  }
}

template <typename T>
void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
                       const std::vector<int>& axes,
                       const T* X,
                       T* Y,
                       CUDAContext* ctx) {
  CHECK_EQ(X_dims.size(), axes.size()) << "dimension size should be equal";
  int ndim = X_dims.size();
  std::vector<int> strides(ndim, 0);
  std::vector<int> Y_dims(ndim, 0);
  std::vector<int> buf(ndim, 0);
  int cur_stride = 1;
  for (int i = ndim - 1; i >= 0; --i) {
    buf[i] = cur_stride;
    cur_stride *= X_dims[i];
  }
  for (int i = 0; i < ndim; ++i) {
    strides[i] = buf[axes[i]];
  }
  int size = 1;
  for (int i = 0; i < ndim; ++i) {
    Y_dims[i] = static_cast<int>(X_dims[axes[i]]);
    size *= X_dims[i];
  }

  lite::Tensor Y_dims_, strides_;
  Y_dims_.Resize(std::vector<int64_t>({ndim}));
  int* d_y_dims = Y_dims_.mutable_data<int>(TARGET(kCUDA));
  CopySync<TARGET(kCUDA)>(d_y_dims,
                          Y_dims.data(),
                          sizeof(int) * Y_dims.size(),
                          IoDirection::HtoD);

  strides_.Resize(std::vector<int64_t>({ndim}));
  int* d_strides = strides_.mutable_data<int>(TARGET(kCUDA));
  CopySync<TARGET(kCUDA)>(d_strides,
                          strides.data(),
                          sizeof(int) * strides.size(),
                          IoDirection::HtoD);

  const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
  TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, ctx->exec_stream()>>>(
      size, ndim, d_strides, d_y_dims, X, Y);
  // cudaError_t error = cudaGetLastError();
  // if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}

#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T)              \
  template <>                                           \
  void Transpose<T>(const std::vector<int64_t>& X_dims, \
                    const std::vector<int>& axes,       \
                    const T* X,                         \
                    T* Y,                               \
                    CUDAContext* ctx) {                 \
    TransposeCUDAImpl<T>(X_dims, axes, X, Y, ctx);      \
  }
TYPE_SPECIALIZED_CUDA_TRANSPOSE(float)
#undef TYPE_SPECIALIZED_CUDA_TRANSPOSE

}  // namespace math
}  // namespace cuda
}  // namespace lite
}  // namespace paddle
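A note on the index arithmetic: TransposeCUDAKernel assigns one thread per output element, peels the linear output index digit by digit against Y_dims, and weights each digit by the input stride of the source axis (TransposeCUDAImpl prepares strides[i] = x_strides[axes[i]]). The host-side sketch below is not part of the commit; it replays the same loop on a made-up 2 x 3 matrix to show the mapping.

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> x_dims = {2, 3};  // input shape (hypothetical)
  const std::vector<int> axes = {1, 0};    // plain 2-D transpose
  const int ndim = static_cast<int>(x_dims.size());

  // Row-major input strides, built back-to-front as in TransposeCUDAImpl.
  std::vector<int> x_strides(ndim);
  int cur_stride = 1;
  for (int i = ndim - 1; i >= 0; --i) {
    x_strides[i] = cur_stride;
    cur_stride *= x_dims[i];
  }

  // Permuted output dims and, per output axis, the stride of its source axis.
  std::vector<int> y_dims(ndim), strides(ndim);
  for (int i = 0; i < ndim; ++i) {
    y_dims[i] = x_dims[axes[i]];
    strides[i] = x_strides[axes[i]];
  }

  const std::vector<float> x = {0, 1, 2, 3, 4, 5};  // [[0,1,2],[3,4,5]]
  std::vector<float> y(x.size());
  for (int y_index = 0; y_index < static_cast<int>(x.size()); ++y_index) {
    int x_index = 0;
    int v = y_index;
    for (int i = ndim - 1; i >= 0; --i) {  // same loop as the device kernel
      x_index += v % y_dims[i] * strides[i];
      v /= y_dims[i];
    }
    y[y_index] = x[x_index];
  }
  for (float f : y) printf("%g ", f);  // prints: 0 3 1 4 2 5, i.e. [[0,3],[1,4],[2,5]]
  printf("\n");
}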
lite/backends/cuda/math/transpose.h (new file, mode 100644)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <string>
#include <vector>
#include "lite/core/context.h"
#include "lite/core/tensor.h"

namespace paddle {
namespace lite {
namespace cuda {
namespace math {

template <typename T>
void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);

template <typename T>
void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);

template <typename T>
void Transpose(const std::vector<int64_t>& X_dims,
               const std::vector<int>& axes,
               const T* X,
               T* Y,
               CUDAContext* ctx);

}  // namespace math
}  // namespace cuda
}  // namespace lite
}  // namespace paddle
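NCHW2NHWC and NHWC2NCHW both reduce to the batched 2-D transpose in transpose.cu: an NCHW tensor is viewed as N matrices of shape C x (H*W), so swapping the two matrix axes yields NHWC, and vice versa. A back-of-the-envelope sketch of the resulting launch geometry, assuming the constants kTileDim = 32 and kBlockRows = 8 from transpose.cu and made-up sizes:

#include <cstdio>

int main() {
  const int kTileDim = 32, kBlockRows = 8;
  const int N = 5, C = 6, H = 100, W = 70;  // hypothetical NCHW tensor
  const int rows = C, cols = H * W;         // the matrix seen by the kernel

  const int dh = (rows + kTileDim - 1) / kTileDim;  // tiles down the rows
  const int dw = (cols + kTileDim - 1) / kTileDim;  // tiles across the cols

  // One block per 32x32 tile per batch item; each 32x8 thread block covers
  // its tile in kTileDim / kBlockRows = 4 row-strided passes.
  printf("grid = %d blocks, block = (%d, %d), %d elements per thread\n",
         N * dh * dw, kTileDim, kBlockRows, kTileDim / kBlockRows);
}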
lite/kernels/cuda/CMakeLists.txt

@@ -10,6 +10,7 @@ lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_
 nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
 nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
+nv_library(transpose_compute_cuda SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
 nv_library(nearest_interp_compute_cuda SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
 nv_library(conv2d_cuda SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
 nv_library(concat_compute_cuda SRCS concat_compute.cu DEPS ${lite_kernel_deps})
@@ -19,6 +20,7 @@ nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
 nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
 nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
 nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
+nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
 nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
 nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
@@ -34,6 +36,7 @@ nearest_interp_compute_cuda
 concat_compute_cuda
 elementwise_add_compute_cuda
 yolo_box_compute_cuda
+transpose_compute_cuda
 )
 set(cuda_kernels "${cuda_kernels}" CACHE GLOBAL "cuda kernels")
lite/kernels/cuda/transpose_compute.cu (new file, mode 100644)

/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/transpose_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

void TransposeCompute::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<CUDAContext>();

  const lite::Tensor* X = param.x;
  lite::Tensor* Out = param.output;
  std::vector<int> axes = param.axis;

  const float* in = X->data<float>();
  float* out = Out->mutable_data<float>(TARGET(kCUDA));

  int ndim = X->dims().size();
  std::vector<int64_t> dims = X->dims().data();

  // NCHW -> NHWC
  if (axes.size() == 4 && axes[0] == 0 && axes[1] == 2 && axes[2] == 3 &&
      axes[3] == 1) {
    lite::cuda::math::NCHW2NHWC(
        dims[0], dims[1], dims[2] * dims[3], in, out, &ctx);
    cudaError_t error = cudaGetLastError();
    if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
    return;
  }

  // NHWC -> NCHW
  if (axes.size() == 4 && axes[0] == 0 && axes[1] == 3 && axes[2] == 1 &&
      axes[3] == 2) {
    lite::cuda::math::NHWC2NCHW(
        dims[0], dims[3], dims[1] * dims[2], in, out, &ctx);
    cudaError_t error = cudaGetLastError();
    if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
    return;
  }

  lite::cuda::math::Transpose(dims, axes, in, out, &ctx);
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

REGISTER_LITE_KERNEL(transpose,
                     kCUDA,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::cuda::TransposeCompute,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .Finalize();

// REGISTER_LITE_KERNEL(transpose2,
//                      kCUDA,
//                      kFloat,
//                      kNCHW,
//                      paddle::lite::kernels::cuda::TransposeCompute,
//                      def)
//     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
//     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
//     .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kCUDA))})
//     .Finalize();
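Run() dispatches on the axis pattern: the two 4-D layout conversions take the tiled shared-memory path, and every other permutation falls back to the generic stride kernel. A standalone sketch of that rule follows; pick_path is a hypothetical helper, not part of the commit.

#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the checks in TransposeCompute::Run().
const char* pick_path(const std::vector<int>& axes) {
  if (axes == std::vector<int>({0, 2, 3, 1})) return "NCHW2NHWC (tiled 32x32)";
  if (axes == std::vector<int>({0, 3, 1, 2})) return "NHWC2NCHW (tiled 32x32)";
  return "generic Transpose (one thread per output element)";
}

int main() {
  printf("%s\n", pick_path({0, 2, 3, 1}));  // fast path
  printf("%s\n", pick_path({2, 0, 1}));     // fallback
}

The special cases are worth having because NCHW/NHWC conversion is a bandwidth-bound copy, and the 32 x 32 shared-memory tiling keeps both the loads and the stores coalesced, whereas the generic kernel scatters one of the two.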
lite/kernels/cuda/transpose_compute.h (new file, mode 100644)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/backends/cuda/math/transpose.h"
#include "lite/core/kernel.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

class TransposeCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
 public:
  using param_t = operators::TransposeParam;

  void Run() override;
  virtual ~TransposeCompute() = default;

 private:
  lite::Tensor axes_, dims_;
};

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle
lite/kernels/cuda/transpose_compute_test.cc (new file, mode 100644)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/cuda/transpose_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

namespace {

#define IN(n, c, h, w)                                 \
  input_data[w + h * input_w + c * input_h * input_w + \
             n * input_c * input_h * input_w]
#define OUT(n, c, h, w)                                    \
  output_data[w + h * output_w + c * output_h * output_w + \
              n * output_c * output_h * output_w]
void nchw2nhwc_ref(lite::Tensor* input,
                   lite::Tensor* output,
                   const std::vector<int> axes) {
  auto* input_data = input->data<float>();
  auto* output_data = output->mutable_data<float>();

  int input_n = input->dims()[0];
  int input_c = input->dims()[1];
  int input_h = input->dims()[2];
  int input_w = input->dims()[3];
  int output_n = output->dims()[0];
  int output_c = output->dims()[1];
  int output_h = output->dims()[2];
  int output_w = output->dims()[3];

  for (int n = 0; n < input_n; ++n) {
    for (int c = 0; c < input_c; ++c) {
      for (int h = 0; h < input_h; ++h) {
        for (int w = 0; w < input_w; ++w) {
          OUT(n, h, w, c) = IN(n, c, h, w);
        }
      }
    }
  }
}
#undef IN
#undef OUT

#define IN(n, h, w, c)                                 \
  input_data[c + w * input_c + h * input_w * input_c + \
             n * input_h * input_w * input_c]
#define OUT(n, h, w, c)                                    \
  output_data[c + w * output_c + h * output_w * output_c + \
              n * output_h * output_w * output_c]
void nhwc2nchw_ref(lite::Tensor* input,
                   lite::Tensor* output,
                   const std::vector<int> axes) {
  auto* input_data = input->data<float>();
  auto* output_data = output->mutable_data<float>();

  int input_n = input->dims()[0];
  int input_h = input->dims()[1];
  int input_w = input->dims()[2];
  int input_c = input->dims()[3];
  int output_n = output->dims()[0];
  int output_h = output->dims()[1];
  int output_w = output->dims()[2];
  int output_c = output->dims()[3];

  for (int n = 0; n < input_n; ++n) {
    for (int c = 0; c < input_c; ++c) {
      for (int h = 0; h < input_h; ++h) {
        for (int w = 0; w < input_w; ++w) {
          OUT(n, c, h, w) = IN(n, h, w, c);
        }
      }
    }
  }
}

void transpose_ref(lite::Tensor* input,
                   lite::Tensor* output,
                   const std::vector<int> axes) {
  auto* input_data = input->data<float>();
  auto* output_data = output->mutable_data<float>();
  int ndim = input->dims().size();
  auto dims = input->dims();

  std::vector<int> strides(ndim, 0);
  std::vector<int> buf(ndim, 0);
  int cur_stride = 1;
  for (int i = ndim - 1; i >= 0; --i) {
    buf[i] = cur_stride;
    cur_stride *= dims[i];
  }
  for (int i = 0; i < ndim; ++i) {
    strides[i] = buf[axes[i]];
  }
  auto y_dims = output->dims();
  int size = input->dims().production();
  for (int i = 0; i < size; ++i) {
    int idx = 0;
    int v = i;
    for (int j = ndim - 1; j >= 0; --j) {
      idx += v % y_dims[j] * strides[j];
      v /= y_dims[j];
    }
    output_data[i] = input_data[idx];
  }
}

}  // namespace

TEST(transpose_nchw, normal) {
  TransposeCompute transpose_kernel;
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  auto& context = ctx->As<CUDAContext>();

  operators::TransposeParam param;

  lite::Tensor x, x_cpu, x_ref;
  lite::Tensor out, out_cpu, out_ref;

  int N = 5, C = 6, H = 7, W = 8;
  std::vector<int> axes({0, 2, 3, 1});
  x.Resize({N, C, H, W});
  out.Resize({N, H, W, C});

  x_cpu.Resize({N, C, H, W});
  out_cpu.Resize({N, H, W, C});
  x_ref.Resize({N, C, H, W});
  out_ref.Resize({N, H, W, C});

  auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
  auto* x_cpu_data = x_cpu.mutable_data<float>();
  auto* out_cpu_data = out_cpu.mutable_data<float>();
  auto* x_ref_data = x_ref.mutable_data<float>();

  for (int i = 0; i < x_cpu.numel(); ++i) {
    x_cpu_data[i] = i + 1;
    x_ref_data[i] = i + 1;
  }

  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());

  param.x = &x;
  param.output = &out;
  param.axis = axes;
  transpose_kernel.SetParam(param);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  context.SetExecStream(stream);

  transpose_kernel.SetContext(std::move(ctx));
  transpose_kernel.Launch();
  cudaDeviceSynchronize();

  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
  CopySync<TARGET(kCUDA)>(
      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);

  nchw2nhwc_ref(&x_ref, &out_ref, axes);
  auto* out_ref_data = out_ref.mutable_data<float>();
  // transpose_ref(&x_ref, &out_ref, axes);
  for (int i = 0; i < out.numel(); i++) {
    EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
  }
}

TEST(transpose_nhwc, normal) {
  TransposeCompute transpose_kernel;
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  auto& context = ctx->As<CUDAContext>();

  operators::TransposeParam param;

  lite::Tensor x, x_cpu, x_ref;
  lite::Tensor out, out_cpu, out_ref;

  int N = 5, C = 6, H = 7, W = 8;
  std::vector<int> axes({0, 3, 1, 2});
  x.Resize({N, H, W, C});
  out.Resize({N, C, H, W});

  x_cpu.Resize({N, H, W, C});
  out_cpu.Resize({N, C, H, W});
  x_ref.Resize({N, H, W, C});
  out_ref.Resize({N, C, H, W});

  auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
  auto* x_cpu_data = x_cpu.mutable_data<float>();
  auto* out_cpu_data = out_cpu.mutable_data<float>();
  auto* x_ref_data = x_ref.mutable_data<float>();

  for (int i = 0; i < x_cpu.numel(); ++i) {
    x_cpu_data[i] = i + 1;
    x_ref_data[i] = i + 1;
  }

  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());

  param.x = &x;
  param.output = &out;
  param.axis = axes;
  transpose_kernel.SetParam(param);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  context.SetExecStream(stream);

  transpose_kernel.SetContext(std::move(ctx));
  transpose_kernel.Launch();
  cudaDeviceSynchronize();

  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
  CopySync<TARGET(kCUDA)>(
      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);

  nhwc2nchw_ref(&x_ref, &out_ref, axes);
  // transpose_ref(&x_ref, &out_ref, axes);
  auto* out_ref_data = out_ref.mutable_data<float>();
  for (int i = 0; i < out.numel(); i++) {
    EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
  }
}

TEST(transpose, normal) {
  TransposeCompute transpose_kernel;
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  auto& context = ctx->As<CUDAContext>();

  operators::TransposeParam param;

  lite::Tensor x, x_cpu, x_ref;
  lite::Tensor out, out_cpu, out_ref;

  int C = 6, H = 7, W = 8;
  std::vector<int> axes({2, 0, 1});
  x.Resize({C, H, W});
  out.Resize({W, C, H});

  x_cpu.Resize({C, H, W});
  out_cpu.Resize({W, C, H});
  x_ref.Resize({C, H, W});
  out_ref.Resize({W, C, H});

  auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
  auto* x_cpu_data = x_cpu.mutable_data<float>();
  auto* out_cpu_data = out_cpu.mutable_data<float>();
  auto* x_ref_data = x_ref.mutable_data<float>();

  for (int i = 0; i < x_cpu.numel(); ++i) {
    x_cpu_data[i] = i + 1;
    x_ref_data[i] = i + 1;
  }

  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());

  param.x = &x;
  param.output = &out;
  param.axis = axes;
  transpose_kernel.SetParam(param);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  context.SetExecStream(stream);

  transpose_kernel.SetContext(std::move(ctx));
  transpose_kernel.Launch();
  cudaDeviceSynchronize();

  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
  CopySync<TARGET(kCUDA)>(
      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);

  transpose_ref(&x_ref, &out_ref, axes);
  auto* out_ref_data = out_ref.mutable_data<float>();
  for (int i = 0; i < out.numel(); i++) {
    EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
  }
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle
登录