Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
11cbd50e
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
11cbd50e
编写于
7月 03, 2020
作者:
W
Wilber
提交者:
GitHub
7月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update transpose cuda kernel. test=develop (#3879)
上级
f35f8dac
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
170 addition
and
124 deletion
+170
-124
lite/backends/cuda/math/transpose.cu
lite/backends/cuda/math/transpose.cu
+1
-16
lite/kernels/cuda/assign_value_compute_test.cc
lite/kernels/cuda/assign_value_compute_test.cc
+2
-2
lite/kernels/cuda/sequence_mask_compute_test.cc
lite/kernels/cuda/sequence_mask_compute_test.cc
+2
-2
lite/kernels/cuda/sequence_pad_compute_test.cc
lite/kernels/cuda/sequence_pad_compute_test.cc
+2
-2
lite/kernels/cuda/sequence_unpad_compute_test.cc
lite/kernels/cuda/sequence_unpad_compute_test.cc
+2
-2
lite/kernels/cuda/transpose_compute.cu
lite/kernels/cuda/transpose_compute.cu
+27
-27
lite/kernels/cuda/transpose_compute.h
lite/kernels/cuda/transpose_compute.h
+3
-2
lite/kernels/cuda/transpose_compute_test.cc
lite/kernels/cuda/transpose_compute_test.cc
+131
-41
lite/operators/transpose_op.cc
lite/operators/transpose_op.cc
+0
-30
未找到文件。
lite/backends/cuda/math/transpose.cu
浏览文件 @
11cbd50e
...
...
@@ -174,24 +174,9 @@ void Transpose<T>::transpose(T* dst,
TransposeCUDAImpl
<
T
>
(
src_dims
,
axes
,
src
,
dst
,
&
Y_dims_
,
&
strides_
,
stream
);
}
// template <typename T>
// void Transpose<T>::transpose(T* dst,
// const T* src,
// const std::vector<int>& src_dims,
// const std::vector<int>& axes,
// cudaStream_t* stream) {
// std::vector<int64_t> _src_dims(src_dims.size(), 0);
// std::transform(
// src_dims.begin(),
// src_dims.end(),
// _src_dims.begin(),
// [](int data) -> int64_t { return static_cast<int64_t>(data); });
// TransposeCUDAImpl<T>(_src_dims, axes, src, dst, &Y_dims_, &strides_,
// stream);
//}
template
class
Transpose
<
int8_t
>;
template
class
Transpose
<
float
>;
template
class
Transpose
<
half
>;
}
// namespace math
}
// namespace cuda
...
...
lite/kernels/cuda/assign_value_compute_test.cc
浏览文件 @
11cbd50e
...
...
@@ -60,6 +60,8 @@ class AssignValueTest : public ::testing::Test {
void
device_init
()
{
ctx
.
reset
(
new
KernelContext
);
cudaStreamCreate
(
&
stream
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
param
.
shape
=
shape
;
param
.
dtype
=
dtype
;
param
.
fp32_values
=
fp32_values
;
...
...
@@ -113,8 +115,6 @@ class AssignValueTest : public ::testing::Test {
TEST_F
(
AssignValueTest
,
fp32
)
{
float_data_init
();
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
AssignValueCompute
kernel
;
kernel
.
SetParam
(
param
);
kernel
.
SetContext
(
std
::
move
(
ctx
));
...
...
lite/kernels/cuda/sequence_mask_compute_test.cc
浏览文件 @
11cbd50e
...
...
@@ -57,6 +57,8 @@ class SequenceMaskTest : public ::testing::Test {
void
device_init
()
{
ctx
.
reset
(
new
KernelContext
);
cudaStreamCreate
(
&
stream
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
param
.
X
=
&
X_gpu
;
param
.
Y
=
&
Out_gpu
;
param
.
maxlen
=
maxlen
;
...
...
@@ -94,8 +96,6 @@ class SequenceMaskTest : public ::testing::Test {
TEST_F
(
SequenceMaskTest
,
fp32
)
{
float_data_init
();
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
SequenceMaskCompute
<
float
,
PRECISION
(
kFloat
)
>
kernel
;
kernel
.
SetParam
(
param
);
kernel
.
SetContext
(
std
::
move
(
ctx
));
...
...
lite/kernels/cuda/sequence_pad_compute_test.cc
浏览文件 @
11cbd50e
...
...
@@ -74,6 +74,8 @@ class SequencePadTest : public ::testing::Test {
void
device_init
()
{
ctx
.
reset
(
new
KernelContext
);
cudaStreamCreate
(
&
stream
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
param
.
X
=
&
X_gpu
;
param
.
PadValue
=
&
PadValue_gpu
;
param
.
Length
=
&
Length_gpu
;
...
...
@@ -125,8 +127,6 @@ class SequencePadTest : public ::testing::Test {
TEST_F
(
SequencePadTest
,
fp32
)
{
float_data_init
();
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
SequencePadCompute
<
float
,
PRECISION
(
kFloat
)
>
kernel
;
kernel
.
SetParam
(
param
);
kernel
.
SetContext
(
std
::
move
(
ctx
));
...
...
lite/kernels/cuda/sequence_unpad_compute_test.cc
浏览文件 @
11cbd50e
...
...
@@ -74,6 +74,8 @@ class SequenceUnpadTest : public ::testing::Test {
void
device_init
()
{
ctx
.
reset
(
new
KernelContext
);
cudaStreamCreate
(
&
stream
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
param
.
X
=
&
X_gpu
;
param
.
Length
=
&
Length_gpu
;
param
.
Out
=
&
Out_gpu
;
...
...
@@ -116,8 +118,6 @@ class SequenceUnpadTest : public ::testing::Test {
TEST_F
(
SequenceUnpadTest
,
fp32
)
{
float_data_init
();
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
SequenceUnpadCompute
<
float
,
PRECISION
(
kFloat
)
>
kernel
;
kernel
.
SetParam
(
param
);
kernel
.
SetContext
(
std
::
move
(
ctx
));
...
...
lite/kernels/cuda/transpose_compute.cu
浏览文件 @
11cbd50e
...
...
@@ -13,17 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "lite/kernels/cuda/transpose_compute.h"
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/transpose_compute.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
void
TransposeCompute
::
Run
()
{
auto
&
param
=
this
->
Param
<
param_t
>
();
template
<
typename
T
,
PrecisionType
Ptype
>
void
TransposeCompute
<
T
,
Ptype
>::
Run
()
{
auto
&
param
=
this
->
template
Param
<
param_t
>();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
...
...
@@ -31,8 +34,8 @@ void TransposeCompute::Run() {
lite
::
Tensor
*
Out
=
param
.
output
;
std
::
vector
<
int
>
axes
=
param
.
axis
;
const
float
*
in
=
X
->
data
<
float
>
();
float
*
out
=
Out
->
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
const
T
*
in
=
X
->
template
data
<
T
>();
T
*
out
=
Out
->
mutable_data
<
T
>
(
TARGET
(
kCUDA
));
int
ndim
=
X
->
dims
().
size
();
std
::
vector
<
int64_t
>
dims
=
X
->
dims
().
data
();
...
...
@@ -65,34 +68,31 @@ void TransposeCompute::Run() {
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
transpose
,
kCUDA
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
TransposeCompute
,
def
)
using
TransFp32
=
paddle
::
lite
::
kernels
::
cuda
::
TransposeCompute
<
float
,
PRECISION
(
kFloat
)
>
;
using
TransFp16
=
paddle
::
lite
::
kernels
::
cuda
::
TransposeCompute
<
half
,
PRECISION
(
kFP16
)
>
;
REGISTER_LITE_KERNEL
(
transpose
,
kCUDA
,
kFloat
,
kNCHW
,
TransFp32
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
transpose2
,
kCUDA
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
cuda
::
TransposeCompute
,
def
)
REGISTER_LITE_KERNEL
(
transpose2
,
kCUDA
,
kFloat
,
kNCHW
,
TransFp32
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"XShape"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
// REGISTER_LITE_KERNEL(transpose2,
// kCUDA,
// kFloat,
// kNCHW,
// paddle::lite::kernels::cuda::TransposeCompute,
// def)
// .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
// .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kCUDA))})
// .Finalize();
REGISTER_LITE_KERNEL
(
transpose
,
kCUDA
,
kFP16
,
kNCHW
,
TransFp16
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
transpose2
,
kCUDA
,
kFP16
,
kNCHW
,
TransFp16
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
BindOutput
(
"XShape"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kFP16
))})
.
Finalize
();
lite/kernels/cuda/transpose_compute.h
浏览文件 @
11cbd50e
...
...
@@ -21,7 +21,8 @@ namespace lite {
namespace
kernels
{
namespace
cuda
{
class
TransposeCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
{
template
<
typename
Dtype
,
PrecisionType
Ptype
>
class
TransposeCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
Ptype
>
{
public:
using
param_t
=
operators
::
TransposeParam
;
...
...
@@ -29,7 +30,7 @@ class TransposeCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
virtual
~
TransposeCompute
()
=
default
;
private:
lite
::
cuda
::
math
::
Transpose
<
float
>
trans
;
lite
::
cuda
::
math
::
Transpose
<
Dtype
>
trans
;
};
}
// namespace cuda
...
...
lite/kernels/cuda/transpose_compute_test.cc
浏览文件 @
11cbd50e
...
...
@@ -13,11 +13,16 @@
// limitations under the License.
#include "lite/kernels/cuda/transpose_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/utils/float16.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
...
...
@@ -89,7 +94,7 @@ void nhwc2nchw_ref(lite::Tensor* input,
}
}
void
transpose_ref
(
lite
::
Tensor
*
input
,
void
transpose_ref
(
const
lite
::
Tensor
*
input
,
lite
::
Tensor
*
output
,
const
std
::
vector
<
int
>
axes
)
{
auto
*
input_data
=
input
->
data
<
float
>
();
...
...
@@ -123,7 +128,7 @@ void transpose_ref(lite::Tensor* input,
}
// namespace
TEST
(
transpose_nchw
,
normal
)
{
TransposeCompute
transpose_kernel
;
TransposeCompute
<
float
,
PRECISION
(
kFloat
)
>
transpose_kernel
;
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
...
...
@@ -177,7 +182,7 @@ TEST(transpose_nchw, normal) {
}
TEST
(
transpose_nhwc
,
normal
)
{
TransposeCompute
transpose_kernel
;
TransposeCompute
<
float
,
PRECISION
(
kFloat
)
>
transpose_kernel
;
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
...
...
@@ -228,54 +233,139 @@ TEST(transpose_nhwc, normal) {
}
}
TEST
(
transpose
,
normal
)
{
TransposeCompute
transpose_kernel
;
std
::
unique_ptr
<
KernelContext
>
ctx
(
new
KernelContext
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
class
TransposeTest
:
public
::
testing
::
Test
{
protected:
TransposeTest
()
:
C
(
3
),
H
(
128
),
W
(
64
),
axes
({
1
,
2
,
0
}),
x_shape
({
C
,
H
,
W
}),
out_shape
({
H
,
W
,
C
})
{
X_ref
.
Resize
(
lite
::
DDim
(
x_shape
));
X_gpu
.
Resize
(
X_ref
.
dims
());
auto
x_ref_data
=
X_ref
.
mutable_data
<
float
>
();
// prepare input
for
(
int64_t
i
=
0
;
i
<
X_ref
.
numel
();
i
++
)
{
x_ref_data
[
i
]
=
static_cast
<
float
>
(
i
);
}
operators
::
TransposeParam
param
;
Out_ref
.
Resize
(
lite
::
DDim
(
out_shape
));
Out_gpu
.
Resize
(
Out_ref
.
dims
());
Out_cpu
.
Resize
(
Out_ref
.
dims
());
cpu_base
(
&
X_ref
,
&
Out_ref
);
lite
::
Tensor
x
,
x_cpu
,
x_ref
;
lite
::
Tensor
out
,
out_cpu
,
out_ref
;
int
C
=
3
,
H
=
128
,
W
=
128
;
std
::
vector
<
int
>
axes
({
2
,
0
,
1
});
x
.
Resize
({
C
,
H
,
W
});
out
.
Resize
({
W
,
C
,
H
});
device_init
();
}
x_cpu
.
Resize
({
C
,
H
,
W
});
out_cpu
.
Resize
({
W
,
C
,
H
});
void
device_init
()
{
ctx
.
reset
(
new
KernelContext
);
cudaStreamCreate
(
&
stream
);
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
param
.
x
=
&
X_gpu
;
param
.
output
=
&
Out_gpu
;
param
.
axis
=
axes
;
}
x_ref
.
Resize
({
C
,
H
,
W
});
out_ref
.
Resize
({
W
,
C
,
H
});
void
float_data_init
()
{
X_gpu
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
X_ref
.
data
<
float
>
(),
X_gpu
.
dims
());
}
auto
*
x_cpu_data
=
x_cpu
.
mutable_data
<
float
>
();
auto
*
out_cpu_data
=
out_cpu
.
mutable_data
<
float
>
();
auto
*
x_ref_data
=
x_ref
.
mutable_data
<
float
>
();
void
half_data_init
()
{
X_half
.
Resize
(
lite
::
DDim
(
X_ref
.
dims
()));
auto
x_half_data
=
X_half
.
mutable_data
<
half
>
();
for
(
int64_t
i
=
0
;
i
<
X_half
.
numel
();
i
++
)
{
x_half_data
[
i
]
=
half
(
lite
::
float16
(
X_ref
.
data
<
float
>
()[
i
]));
}
X_gpu
.
Assign
<
half
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_half_data
,
X_gpu
.
dims
());
}
for
(
int
i
=
0
;
i
<
x_cpu
.
numel
();
++
i
)
{
x_cpu_data
[
i
]
=
i
+
1
;
x_ref_data
[
i
]
=
i
+
1
;
void
cpu_base
(
const
lite
::
Tensor
*
X
,
lite
::
Tensor
*
Out
)
{
transpose_ref
(
X
,
Out
,
axes
);
}
x
.
Assign
<
float
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
x_cpu_data
,
x_cpu
.
dims
());
param
.
x
=
&
x
;
param
.
output
=
&
out
;
param
.
axis
=
axes
;
transpose_kernel
.
SetParam
(
param
);
int
C
,
H
,
W
;
std
::
vector
<
int
>
axes
;
std
::
vector
<
int64_t
>
x_shape
,
out_shape
;
lite
::
Tensor
X_ref
,
Out_ref
;
lite
::
Tensor
X_gpu
,
Out_gpu
;
lite
::
Tensor
X_half
;
lite
::
Tensor
Out_cpu
;
operators
::
TransposeParam
param
;
std
::
unique_ptr
<
KernelContext
>
ctx
;
cudaStream_t
stream
;
cudaStreamCreate
(
&
stream
);
context
.
SetExecStream
(
stream
);
transpose_kernel
.
SetContext
(
std
::
move
(
ctx
));
transpose_kernel
.
Launch
();
};
TEST_F
(
TransposeTest
,
fp32
)
{
float_data_init
();
TransposeCompute
<
float
,
PRECISION
(
kFloat
)
>
kernel
;
kernel
.
SetParam
(
param
);
kernel
.
SetContext
(
std
::
move
(
ctx
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
kernel
.
Launch
();
cudaDeviceSynchronize
();
}
auto
start
=
GetCurrentUS
();
kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
*
out_data
=
out
.
mutable_data
<
float
>
(
TARGET
(
kCUDA
));
CopySync
<
TARGET
(
kCUDA
)
>
(
out_cpu_data
,
out_data
,
sizeof
(
float
)
*
out
.
numel
(),
IoDirection
::
DtoH
);
transpose_ref
(
&
x_ref
,
&
out_ref
,
axes
);
auto
*
out_ref_data
=
out_ref
.
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
out
.
numel
();
i
++
)
{
EXPECT_NEAR
(
out_cpu_data
[
i
],
out_ref_data
[
i
],
1e-5
);
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp32, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
CopySync
<
TARGET
(
kCUDA
)
>
(
Out_cpu
.
mutable_data
<
float
>
(),
Out_gpu
.
data
<
float
>
(),
sizeof
(
float
)
*
Out_gpu
.
numel
(),
IoDirection
::
DtoH
);
for
(
int
i
=
0
;
i
<
Out_gpu
.
numel
();
++
i
)
{
EXPECT_NEAR
(
Out_cpu
.
data
<
float
>
()[
i
],
Out_ref
.
data
<
float
>
()[
i
],
1e-5
);
}
}
TEST_F
(
TransposeTest
,
TestFP16
)
{
half_data_init
();
TransposeCompute
<
half
,
PRECISION
(
kFP16
)
>
kernel
;
kernel
.
SetParam
(
param
);
kernel
.
SetContext
(
std
::
move
(
ctx
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
kernel
.
Launch
();
cudaDeviceSynchronize
();
}
auto
start
=
GetCurrentUS
();
kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp16, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
const
half
*
out_gpu_data
=
Out_gpu
.
data
<
half
>
();
half
*
out_cpu_data
=
Out_cpu
.
mutable_data
<
half
>
();
CopySync
<
TARGET
(
kCUDA
)
>
(
out_cpu_data
,
out_gpu_data
,
sizeof
(
half
)
*
Out_gpu
.
numel
(),
IoDirection
::
DtoH
);
for
(
int
i
=
0
;
i
<
Out_cpu
.
numel
();
++
i
)
{
float
res
=
static_cast
<
float
>
(
lite
::
float16
(
out_cpu_data
[
i
]));
float
ref
=
Out_ref
.
data
<
float
>
()[
i
];
EXPECT_NEAR
(
fabs
(
res
-
ref
)
/
(
ref
+
1e-5
),
0.
,
1e-2
);
}
}
...
...
lite/operators/transpose_op.cc
浏览文件 @
11cbd50e
...
...
@@ -43,24 +43,9 @@ bool TransposeOp::CheckShape() const {
}
bool
TransposeOp
::
InferShapeImpl
()
const
{
CHECK_OR_FALSE
(
param_
.
x
);
CHECK_OR_FALSE
(
param_
.
output
);
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_rank
=
x_dims
.
size
();
std
::
vector
<
int
>
axis
=
param_
.
axis
;
size_t
axis_size
=
axis
.
size
();
// "The input tensor's rank(%d) should be equal to the axis's size(%d)",
// x_rank, axis_size
CHECK_OR_FALSE
(
x_rank
==
axis_size
);
std
::
vector
<
int
>
count
(
axis_size
,
0
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
// Each element of Attribute axis should be a unique value
// range from 0 to (dims - 1),
// where the dims is the axis's size
CHECK_OR_FALSE
(
axis
[
i
]
<
static_cast
<
int
>
(
axis_size
)
&&
++
count
[
axis
[
i
]]
==
1
);
}
lite
::
DDim
out_dims
(
x_dims
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
out_dims
[
i
]
=
x_dims
[
axis
[
i
]];
...
...
@@ -113,24 +98,9 @@ bool Transpose2Op::CheckShape() const {
}
bool
Transpose2Op
::
InferShapeImpl
()
const
{
CHECK_OR_FALSE
(
param_
.
x
);
CHECK_OR_FALSE
(
param_
.
output
);
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_rank
=
x_dims
.
size
();
std
::
vector
<
int
>
axis
=
param_
.
axis
;
size_t
axis_size
=
axis
.
size
();
// "The input tensor's rank(%d) should be equal to the axis's size(%d)",
// x_rank, axis_size
CHECK_OR_FALSE
(
x_rank
==
axis_size
);
std
::
vector
<
int
>
count
(
axis_size
,
0
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
// Each element of Attribute axis should be a unique value
// range from 0 to (dims - 1),
// where the dims is the axis's size
CHECK_OR_FALSE
(
axis
[
i
]
<
static_cast
<
int
>
(
axis_size
)
&&
++
count
[
axis
[
i
]]
==
1
);
}
lite
::
DDim
out_dims
(
x_dims
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
out_dims
[
i
]
=
x_dims
[
axis
[
i
]];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录