Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
b65a6dc9
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b65a6dc9
编写于
12月 15, 2019
作者:
W
Wilber
提交者:
GitHub
12月 15, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize search_grnn test=develop (#2608)
optimize search_grnn
上级
1dbcd51d
变更
10
展开全部
隐藏空白更改
内联
并排
Showing
10 changed file
with
593 addition
and
324 deletion
+593
-324
lite/backends/cuda/math/transpose.cu
lite/backends/cuda/math/transpose.cu
+59
-54
lite/backends/cuda/math/transpose.h
lite/backends/cuda/math/transpose.h
+19
-9
lite/kernels/cuda/CMakeLists.txt
lite/kernels/cuda/CMakeLists.txt
+1
-1
lite/kernels/cuda/layout_compute.cc
lite/kernels/cuda/layout_compute.cc
+5
-3
lite/kernels/cuda/layout_compute.h
lite/kernels/cuda/layout_compute.h
+13
-0
lite/kernels/cuda/search_grnn_compute.cu
lite/kernels/cuda/search_grnn_compute.cu
+423
-247
lite/kernels/cuda/search_grnn_compute.h
lite/kernels/cuda/search_grnn_compute.h
+67
-3
lite/kernels/cuda/transpose_compute.cu
lite/kernels/cuda/transpose_compute.cu
+4
-5
lite/kernels/cuda/transpose_compute.h
lite/kernels/cuda/transpose_compute.h
+1
-1
lite/kernels/cuda/transpose_compute_test.cc
lite/kernels/cuda/transpose_compute_test.cc
+1
-1
未找到文件。
lite/backends/cuda/math/transpose.cu
浏览文件 @
b65a6dc9
...
...
@@ -69,44 +69,16 @@ void BatchTranspose2DCUDAImpl(const int N,
const
int
W
,
const
T
*
input
,
T
*
out
,
CUDAContext
*
ctx
)
{
cudaStream_t
*
stream
)
{
const
int
dh
=
(
H
+
kTileDim
-
1
)
/
kTileDim
;
const
int
dw
=
(
W
+
kTileDim
-
1
)
/
kTileDim
;
BatchTranspose2DCUDAKernel
<
T
><<<
N
*
dh
*
dw
,
dim3
(
kTileDim
,
kBlockRows
),
0
,
ctx
->
exec_stream
()
>>>
(
T
><<<
N
*
dh
*
dw
,
dim3
(
kTileDim
,
kBlockRows
),
0
,
*
stream
>>>
(
N
,
H
,
W
,
dh
,
dw
,
input
,
out
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
}
#define TYPE_SPECIALIZED_CUDA_NCHW2NHWC(T) \
template <> \
void NCHW2NHWC<T>(const int N, \
const int C, \
const int HxW, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NCHW2NHWC
(
float
)
TYPE_SPECIALIZED_CUDA_NCHW2NHWC
(
int8_t
)
#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC
#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T) \
template <> \
void NHWC2NCHW<T>(const int N, \
const int C, \
const int HxW, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NHWC2NCHW
(
float
)
TYPE_SPECIALIZED_CUDA_NHWC2NCHW
(
int8_t
)
#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW
template
<
typename
T
>
__global__
void
TransposeCUDAKernel
(
const
int
size
,
const
int
ndim
,
...
...
@@ -136,7 +108,9 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
const
std
::
vector
<
int
>&
axes
,
const
T
*
X
,
T
*
Y
,
CUDAContext
*
ctx
)
{
lite
::
Tensor
*
Y_dims_
,
lite
::
Tensor
*
strides_
,
cudaStream_t
*
stream
)
{
CHECK_EQ
(
X_dims
.
size
(),
axes
.
size
())
<<
"dimension size should be equal"
;
int
ndim
=
X_dims
.
size
();
std
::
vector
<
int
>
strides
(
ndim
,
0
);
...
...
@@ -156,37 +130,68 @@ void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
size
*=
X_dims
[
i
];
}
lite
::
Tensor
Y_dims_
,
strides_
;
Y_dims_
.
Resize
(
std
::
vector
<
int64_t
>
({
ndim
}));
int
*
d_y_dims
=
Y_dims_
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
));
CopySync
<
TARGET
(
kCUDA
)
>
(
d_y_dims
,
Y_dims
.
data
(),
sizeof
(
int
)
*
Y_dims
.
size
(),
IoDirection
::
HtoD
);
Y_dims_
->
Resize
(
std
::
vector
<
int64_t
>
({
ndim
}));
int
*
d_y_dims
=
Y_dims_
->
mutable_data
<
int
>
(
TARGET
(
kCUDA
));
TargetWrapperCuda
::
MemcpyAsync
(
d_y_dims
,
Y_dims
.
data
(),
sizeof
(
int
)
*
Y_dims
.
size
(),
IoDirection
::
HtoD
,
*
stream
);
strides_
.
Resize
(
std
::
vector
<
int64_t
>
({
ndim
}));
int
*
d_strides
=
strides_
.
mutable_data
<
int
>
(
TARGET
(
kCUDA
));
CopySync
<
TARGET
(
kCUDA
)
>
(
d_strides
,
strides
.
data
(),
sizeof
(
int
)
*
strides
.
size
(),
IoDirection
::
HtoD
);
strides_
->
Resize
(
std
::
vector
<
int64_t
>
({
ndim
}));
int
*
d_strides
=
strides_
->
mutable_data
<
int
>
(
TARGET
(
kCUDA
));
TargetWrapperCuda
::
MemcpyAsync
(
d_strides
,
strides
.
data
(),
sizeof
(
int
)
*
strides
.
size
(),
IoDirection
::
HtoD
,
*
stream
);
const
int
M
=
(
size
+
CUDA_NUM_THREADS
-
1
)
/
CUDA_NUM_THREADS
;
TransposeCUDAKernel
<<<
M
,
CUDA_NUM_THREADS
,
0
,
ctx
->
exec_stream
()
>>>
(
TransposeCUDAKernel
<<<
M
,
CUDA_NUM_THREADS
,
0
,
*
stream
>>>
(
size
,
ndim
,
d_strides
,
d_y_dims
,
X
,
Y
);
auto
e
=
cudaGetLastError
();
CHECK_EQ
(
e
,
cudaSuccess
)
<<
" CUDA: "
<<
cudaGetErrorString
(
e
);
}
#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T) \
template <> \
void Transpose<T>(const std::vector<int64_t>& X_dims, \
const std::vector<int>& axes, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
TransposeCUDAImpl<T>(X_dims, axes, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_TRANSPOSE
(
float
)
#undef TYPE_SPECIALIZED_CUDA_TRANSPOSEF
template
<
typename
T
>
void
Transpose
<
T
>::
NCHW2NHWC
(
int
N
,
int
C
,
int
HxW
,
const
T
*
X
,
T
*
Y
,
cudaStream_t
*
stream
)
{
BatchTranspose2DCUDAImpl
<
T
>
(
N
,
C
,
HxW
,
X
,
Y
,
stream
);
}
template
<
typename
T
>
void
Transpose
<
T
>::
NHWC2NCHW
(
int
N
,
int
C
,
int
HxW
,
const
T
*
X
,
T
*
Y
,
cudaStream_t
*
stream
)
{
BatchTranspose2DCUDAImpl
<
T
>
(
N
,
HxW
,
C
,
X
,
Y
,
stream
);
}
template
<
typename
T
>
void
Transpose
<
T
>::
transpose
(
T
*
dst
,
const
T
*
src
,
const
std
::
vector
<
int64_t
>&
src_dims
,
const
std
::
vector
<
int
>&
axes
,
cudaStream_t
*
stream
)
{
TransposeCUDAImpl
<
T
>
(
src_dims
,
axes
,
src
,
dst
,
&
Y_dims_
,
&
strides_
,
stream
);
}
// template <typename T>
// void Transpose<T>::transpose(T* dst,
// const T* src,
// const std::vector<int>& src_dims,
// const std::vector<int>& axes,
// cudaStream_t* stream) {
// std::vector<int64_t> _src_dims(src_dims.size(), 0);
// std::transform(
// src_dims.begin(),
// src_dims.end(),
// _src_dims.begin(),
// [](int data) -> int64_t { return static_cast<int64_t>(data); });
// TransposeCUDAImpl<T>(_src_dims, axes, src, dst, &Y_dims_, &strides_,
// stream);
//}
template
class
Transpose
<
int8_t
>;
template
class
Transpose
<
float
>;
}
// namespace math
}
// namespace cuda
...
...
lite/backends/cuda/math/transpose.h
浏览文件 @
b65a6dc9
...
...
@@ -26,17 +26,27 @@ namespace cuda {
namespace
math
{
template
<
typename
T
>
void
NCHW2NHWC
(
int
N
,
int
C
,
int
HxW
,
const
T
*
X
,
T
*
Y
,
CUDAContext
*
context
);
class
Transpose
{
public:
void
NCHW2NHWC
(
int
N
,
int
C
,
int
HxW
,
const
T
*
X
,
T
*
Y
,
cudaStream_t
*
stream
);
template
<
typename
T
>
void
NHWC2NCHW
(
int
N
,
int
C
,
int
HxW
,
const
T
*
X
,
T
*
Y
,
CUDAContext
*
context
);
void
NHWC2NCHW
(
int
N
,
int
C
,
int
HxW
,
const
T
*
X
,
T
*
Y
,
cudaStream_t
*
stream
);
template
<
typename
T
>
void
Transpose
(
const
std
::
vector
<
int64_t
>&
X_dims
,
const
std
::
vector
<
int
>&
axes
,
const
T
*
X
,
T
*
Y
,
CUDAContext
*
ctx
);
void
transpose
(
T
*
dst
,
const
T
*
src
,
const
std
::
vector
<
int64_t
>&
src_dims
,
const
std
::
vector
<
int
>&
axes
,
cudaStream_t
*
stream
);
// void transpose(T* dst,
// const T* src,
// const std::vector<int>& src_dims,
// const std::vector<int>& axes,
// cudaStream_t* stream);
private:
lite
::
Tensor
Y_dims_
,
strides_
;
// for transpose.
};
}
// namespace math
}
// namespace cuda
...
...
lite/kernels/cuda/CMakeLists.txt
浏览文件 @
b65a6dc9
...
...
@@ -26,7 +26,7 @@ add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS
${
lite_kernel_deps
}
cudnn_pool
)
add_kernel
(
bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS
${
lite_kernel_deps
}
cuda_gemm
)
add_kernel
(
search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS
${
lite_kernel_deps
}
cuda_gemm
${
math_cuda
}
)
add_kernel
(
sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS
${
lite_kernel_deps
}
)
...
...
lite/kernels/cuda/layout_compute.cc
浏览文件 @
b65a6dc9
...
...
@@ -29,7 +29,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
}
std
::
vector
<
int64_t
>
trim_dims
;
trim_dims
.
resize
(
actual_dims_size
);
for
(
in
t
i
=
0
;
i
<
actual_dims_size
;
++
i
)
{
for
(
size_
t
i
=
0
;
i
<
actual_dims_size
;
++
i
)
{
trim_dims
[
i
]
=
dims
[
i
];
}
if
(
trim_dims
.
size
()
==
0
)
{
...
...
@@ -41,6 +41,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
#define NCHWTONHWC(type) \
auto& param = this->template Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
DDim input_trim_dim = trim_singular_dims(input_dim); \
...
...
@@ -56,11 +57,12 @@ inline DDim trim_singular_dims(const DDim& dims) {
int w = input_dim[3]; \
param.y->Resize({n, h, w, c}); \
auto output = param.y->template mutable_data<type>(TARGET(kCUDA)); \
lite::cuda::math::NCHW2NHWC<type>(n, c, h * w, input, output, &ctx
);
trans.NCHW2NHWC(n, c, h* w, input, output, &stream
);
#define NHWCTONCHW(type) \
auto& param = this->template Param<param_t>(); \
auto& ctx = this->ctx_->template As<CUDAContext>(); \
auto stream = ctx.exec_stream(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
DDim input_trim_dim = trim_singular_dims(input_dim); \
...
...
@@ -76,7 +78,7 @@ inline DDim trim_singular_dims(const DDim& dims) {
int c = input_dim[3]; \
param.y->Resize({n, c, h, w}); \
auto output = param.y->template mutable_data<type>(TARGET(kCUDA)); \
lite::cuda::math::NHWC2NCHW<type>(n, c, h * w, input, output, &ctx
);
trans.NHWC2NCHW(n, c, h* w, input, output, &stream
);
void
NCHWToNHWCCompute
::
Run
()
{
NCHWTONHWC
(
float
)
}
...
...
lite/kernels/cuda/layout_compute.h
浏览文件 @
b65a6dc9
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "lite/backends/cuda/math/transpose.h"
#include "lite/core/kernel.h"
namespace
paddle
{
...
...
@@ -25,6 +26,9 @@ class NCHWToNHWCCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
using
param_t
=
operators
::
LayoutParam
;
void
Run
()
override
;
virtual
~
NCHWToNHWCCompute
()
=
default
;
private:
lite
::
cuda
::
math
::
Transpose
<
float
>
trans
;
};
class
NCHWToNHWCComputeInt8
...
...
@@ -33,6 +37,9 @@ class NCHWToNHWCComputeInt8
using
param_t
=
operators
::
LayoutParam
;
void
Run
()
override
;
virtual
~
NCHWToNHWCComputeInt8
()
=
default
;
private:
lite
::
cuda
::
math
::
Transpose
<
int8_t
>
trans
;
};
class
NHWCToNCHWCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
{
...
...
@@ -40,6 +47,9 @@ class NHWCToNCHWCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
using
param_t
=
operators
::
LayoutParam
;
void
Run
()
override
;
virtual
~
NHWCToNCHWCompute
()
=
default
;
private:
lite
::
cuda
::
math
::
Transpose
<
float
>
trans
;
};
class
NHWCToNCHWComputeInt8
...
...
@@ -48,6 +58,9 @@ class NHWCToNCHWComputeInt8
using
param_t
=
operators
::
LayoutParam
;
void
Run
()
override
;
virtual
~
NHWCToNCHWComputeInt8
()
=
default
;
private:
lite
::
cuda
::
math
::
Transpose
<
int8_t
>
trans
;
};
}
// namespace cuda
...
...
lite/kernels/cuda/search_grnn_compute.cu
浏览文件 @
b65a6dc9
此差异已折叠。
点击以展开。
lite/kernels/cuda/search_grnn_compute.h
浏览文件 @
b65a6dc9
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <vector>
#include "lite/backends/cuda/blas.h"
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/kernel.h"
...
...
@@ -23,6 +24,53 @@ namespace lite {
namespace
kernels
{
namespace
cuda
{
class
SeqSortedseqTranseUtil
{
public:
explicit
SeqSortedseqTranseUtil
(
bool
is_reverse
=
false
,
bool
is_bi
=
false
)
:
_is_reverse
(
is_reverse
),
_is_bi
(
is_bi
),
_dev_map_vec
(
nullptr
),
_dev_map_vec_length
(
0
)
{}
~
SeqSortedseqTranseUtil
()
{
if
(
_dev_map_vec
!=
nullptr
)
{
TargetWrapperCuda
::
Free
(
static_cast
<
void
*>
(
_dev_map_vec
));
}
}
std
::
vector
<
int
>&
get_length_index
()
{
return
_length_index
;
}
std
::
vector
<
int
>&
get_emit_offset_vec
()
{
return
_emit_offset_vec
;
}
std
::
vector
<
int
>&
get_map_vec
()
{
return
_map_vec
;
}
int
*
get_dev_map_vec
()
{
return
_dev_map_vec
;
}
int
get_emit_length
()
{
return
_emit_length
;
}
template
<
typename
Dtype
>
void
seq_2_sorted_seq
(
const
Dtype
*
input
,
Dtype
*
output
,
int
word_size
,
cudaStream_t
stream
);
template
<
typename
Dtype
>
void
sorted_seq_2_seq
(
const
Dtype
*
input
,
Dtype
*
output
,
int
hidden_size
,
cudaStream_t
stream
);
bool
get_sorted_map
(
const
std
::
vector
<
int
>&
offset_vec
,
cudaStream_t
stream_id
);
private:
std
::
vector
<
int
>
_length_index
;
std
::
vector
<
int
>
_emit_offset_vec
;
std
::
vector
<
int
>
_map_vec
;
int
_emit_length
;
bool
_is_reverse
;
bool
_is_bi
;
int
*
_dev_map_vec
;
int
_dev_map_vec_length
;
};
class
SearchGrnnCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)
>
{
public:
...
...
@@ -34,10 +82,26 @@ class SearchGrnnCompute
virtual
~
SearchGrnnCompute
()
=
default
;
private:
std
::
shared_ptr
<
Tensor
>
idx_sorted_by_width_cpu
;
// Weights preprocess:
// wi need to be transpose, the axes should be (2, 0, 1)
// wh0 should transpose, {wh1 wh2} need be transpose, the axes should be {2,
// 0, 1}
void
WeightsPreprocess
();
private:
std
::
unique_ptr
<
lite
::
cuda
::
math
::
Gemm
<
float
,
float
>>
gemm_impl_
;
void
PrepareLayout
(
const
Tensor
*
input
);
void
CopyBack
(
float
*
from
,
float
*
to
,
int
step
);
lite
::
Tensor
_temp_tensor_in
;
lite
::
Tensor
_temp_tensor_out
;
lite
::
Tensor
_temp_wx
;
lite
::
Tensor
_temp_wh
;
lite
::
Tensor
_temp_zero
;
lite
::
Tensor
_temp_weights_h2h
;
lite
::
Tensor
_wi
;
lite
::
Tensor
_wh
;
SeqSortedseqTranseUtil
_seq_util
;
};
}
// namespace cuda
...
...
lite/kernels/cuda/transpose_compute.cu
浏览文件 @
b65a6dc9
...
...
@@ -25,6 +25,7 @@ namespace cuda {
void
TransposeCompute
::
Run
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
const
lite
::
Tensor
*
X
=
param
.
x
;
lite
::
Tensor
*
Out
=
param
.
output
;
...
...
@@ -39,8 +40,7 @@ void TransposeCompute::Run() {
// NCHW -> NHWC
if
(
axes
.
size
()
==
4
&&
axes
[
0
]
==
0
&&
axes
[
1
]
==
2
&&
axes
[
2
]
==
3
&&
axes
[
3
]
==
1
)
{
lite
::
cuda
::
math
::
NCHW2NHWC
(
dims
[
0
],
dims
[
1
],
dims
[
2
]
*
dims
[
3
],
in
,
out
,
&
ctx
);
trans
.
NCHW2NHWC
(
dims
[
0
],
dims
[
1
],
dims
[
2
]
*
dims
[
3
],
in
,
out
,
&
stream
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
return
;
...
...
@@ -49,14 +49,13 @@ void TransposeCompute::Run() {
// NHWC -> NCHW
if
(
axes
.
size
()
==
4
&&
axes
[
0
]
==
0
&&
axes
[
1
]
==
3
&&
axes
[
2
]
==
1
&&
axes
[
3
]
==
2
)
{
lite
::
cuda
::
math
::
NHWC2NCHW
(
dims
[
0
],
dims
[
3
],
dims
[
1
]
*
dims
[
2
],
in
,
out
,
&
ctx
);
trans
.
NHWC2NCHW
(
dims
[
0
],
dims
[
3
],
dims
[
1
]
*
dims
[
2
],
in
,
out
,
&
stream
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
return
;
}
lite
::
cuda
::
math
::
Transpose
(
dims
,
axes
,
in
,
out
,
&
ctx
);
trans
.
transpose
(
out
,
in
,
dims
,
axes
,
&
stream
);
cudaError_t
error
=
cudaGetLastError
();
if
(
error
!=
cudaSuccess
)
LOG
(
INFO
)
<<
cudaGetErrorString
(
error
);
}
...
...
lite/kernels/cuda/transpose_compute.h
浏览文件 @
b65a6dc9
...
...
@@ -29,7 +29,7 @@ class TransposeCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
virtual
~
TransposeCompute
()
=
default
;
private:
lite
::
Tensor
axes_
,
dims_
;
lite
::
cuda
::
math
::
Transpose
<
float
>
trans
;
};
}
// namespace cuda
...
...
lite/kernels/cuda/transpose_compute_test.cc
浏览文件 @
b65a6dc9
...
...
@@ -238,7 +238,7 @@ TEST(transpose, normal) {
lite
::
Tensor
x
,
x_cpu
,
x_ref
;
lite
::
Tensor
out
,
out_cpu
,
out_ref
;
int
C
=
6
,
H
=
7
,
W
=
8
;
int
C
=
3
,
H
=
128
,
W
=
12
8
;
std
::
vector
<
int
>
axes
({
2
,
0
,
1
});
x
.
Resize
({
C
,
H
,
W
});
out
.
Resize
({
W
,
C
,
H
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录