Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
396bd65f
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
396bd65f
编写于
8月 16, 2020
作者:
J
Juncheng
提交者:
GitHub
8月 16, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize transpose performance (#3487)
Former-commit-id:
809793c4
上级
ea1d417c
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
208 addition
and
51 deletion
+208
-51
oneflow/core/kernel/util/cuda_arithemetic_interface.cu
oneflow/core/kernel/util/cuda_arithemetic_interface.cu
+71
-18
oneflow/core/kernel/util/cuda_arithemetic_interface.h
oneflow/core/kernel/util/cuda_arithemetic_interface.h
+31
-12
oneflow/core/kernel/util/host_arithemetic_interface.cpp
oneflow/core/kernel/util/host_arithemetic_interface.cpp
+55
-6
oneflow/core/kernel/util/host_arithemetic_interface.h
oneflow/core/kernel/util/host_arithemetic_interface.h
+26
-10
oneflow/user/kernels/transpose_kernel.cpp
oneflow/user/kernels/transpose_kernel.cpp
+25
-5
未找到文件。
oneflow/core/kernel/util/cuda_arithemetic_interface.cu
浏览文件 @
396bd65f
...
...
@@ -33,8 +33,9 @@ template<int32_t NDIMS>
__device__
int32_t
GetXIndex
(
const
int32_t
*
y_shape
,
const
int32_t
*
x_strides
,
int32_t
y_idx
)
{
int32_t
x_idx
=
0
;
for
(
int32_t
i
=
NDIMS
-
1
;
i
>=
0
;
--
i
)
{
x_idx
+=
(
y_idx
%
y_shape
[
i
])
*
x_strides
[
i
];
y_idx
/=
y_shape
[
i
];
const
int32_t
next_y_idx
=
y_idx
/
y_shape
[
i
];
x_idx
+=
(
y_idx
-
next_y_idx
*
y_shape
[
i
])
*
x_strides
[
i
];
y_idx
=
next_y_idx
;
}
return
x_idx
;
}
...
...
@@ -42,16 +43,8 @@ __device__ int32_t GetXIndex(const int32_t* y_shape, const int32_t* x_strides, i
template
<
int32_t
NDIMS
,
typename
T
>
__global__
void
TransposeGpu
(
const
Int32Array
<
NDIMS
>
y_shape
,
const
Int32Array
<
NDIMS
>
x_strides
,
const
int32_t
elem_cnt
,
const
T
*
x
,
T
*
y
)
{
__shared__
int32_t
x_strides_shared
[
NDIMS
];
__shared__
int32_t
y_dims_shared
[
NDIMS
];
const
int32_t
tid
=
threadIdx
.
x
;
if
(
tid
<
NDIMS
)
{
y_dims_shared
[
tid
]
=
y_shape
.
val
[
tid
];
x_strides_shared
[
tid
]
=
x_strides
.
val
[
tid
];
}
__syncthreads
();
CUDA_1D_KERNEL_LOOP
(
y_idx
,
elem_cnt
)
{
const
int32_t
x_idx
=
GetXIndex
<
NDIMS
>
(
y_
dims_shared
,
x_strides_shared
,
y_idx
);
const
int32_t
x_idx
=
GetXIndex
<
NDIMS
>
(
y_
shape
.
val
,
x_strides
.
val
,
y_idx
);
#if __CUDA_ARCH__ >= 350
y
[
y_idx
]
=
__ldg
(
x
+
x_idx
);
#else
...
...
@@ -62,7 +55,8 @@ __global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<N
template
<
int32_t
NDIMS
,
typename
T
>
void
TransposeImpl
(
DeviceCtx
*
ctx
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
T
*
x
,
T
*
y
)
{
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
T
*
x
,
T
*
y
)
{
CHECK_LE
(
y_shape
.
elem_cnt
(),
GetMaxVal
<
int32_t
>
());
Int32Array
<
NDIMS
>
y_shape_struct
;
FOR_RANGE
(
int32_t
,
i
,
0
,
NDIMS
)
{
y_shape_struct
.
val
[
i
]
=
y_shape
.
At
(
i
);
}
...
...
@@ -95,7 +89,7 @@ struct TransposeUtil final {
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
float
*
x
,
float
*
y
)
{
TRANSPOSE_CHECK
;
TransposeUtil
<
float
>::
SwitchTransposeImpl
(
SwitchCase
(
num_axis
),
ctx
,
x_shape
,
y_shape
,
...
...
@@ -104,7 +98,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
double
*
x
,
double
*
y
)
{
TRANSPOSE_CHECK
;
...
...
@@ -114,7 +108,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
float16
*
x
,
float16
*
y
)
{
TRANSPOSE_CHECK
;
...
...
@@ -125,7 +119,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int8_t
*
x
,
int8_t
*
y
)
{
TRANSPOSE_CHECK
;
...
...
@@ -135,7 +129,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int32_t
*
x
,
int32_t
*
y
)
{
TRANSPOSE_CHECK
;
...
...
@@ -145,7 +139,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int64_t
*
x
,
int64_t
*
y
)
{
TRANSPOSE_CHECK
;
...
...
@@ -155,6 +149,65 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
#undef TRANSPOSE_CHECK
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
float
*
x
,
float
*
y
)
{
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
double
*
x
,
double
*
y
)
{
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
float16
*
x
,
float16
*
y
)
{
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int8_t
*
x
,
int8_t
*
y
)
{
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int32_t
*
x
,
int32_t
*
y
)
{
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int64_t
*
x
,
int64_t
*
y
)
{
ArithemeticIf
<
DeviceType
::
kGPU
>::
Transpose
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kGPU
>::
InitializeWithConstConf
(
DeviceCtx
*
ctx
,
const
ConstantInitializerConf
&
initializer_conf
,
Blob
*
blob
)
{
WithHostBlobAndStreamSynchronizeEnv
(
ctx
,
blob
,
[
&
](
Blob
*
host_blob
)
{
...
...
oneflow/core/kernel/util/cuda_arithemetic_interface.h
浏览文件 @
396bd65f
...
...
@@ -29,24 +29,43 @@ class ConstantInitializerConf;
template
<
>
struct
ArithemeticIf
<
DeviceType
::
kGPU
>
{
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
float
*
x
,
float
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
double
*
x
,
double
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
float16
*
x
,
float16
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
int8_t
*
x
,
int8_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
int32_t
*
x
,
int32_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
int64_t
*
x
,
int64_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
float
*
x
,
float
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
int64_t
elem_cnt
,
const
float
*
x
,
float
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
double
*
x
,
double
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
int64_t
elem_cnt
,
const
double
*
x
,
double
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
float16
*
x
,
float16
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
int64_t
elem_cnt
,
const
float16
*
x
,
float16
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int8_t
*
x
,
int8_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
int64_t
elem_cnt
,
const
int8_t
*
x
,
int8_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int32_t
*
x
,
int32_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
int64_t
elem_cnt
,
const
int32_t
*
x
,
int32_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int64_t
*
x
,
int64_t
*
y
);
int64_t
elem_cnt
,
const
int64_t
*
x
,
int64_t
*
y
);
static
void
InitializeWithConstConf
(
DeviceCtx
*
ctx
,
const
ConstantInitializerConf
&
initializer_conf
,
Blob
*
blob
);
...
...
oneflow/core/kernel/util/host_arithemetic_interface.cpp
浏览文件 @
396bd65f
...
...
@@ -46,7 +46,7 @@ void IncreaseIndex(const int64_t* shape, DimVector& index) {
template
<
typename
T
>
void
TransposeImpl
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
T
*
x
,
T
*
y
)
{
int64_t
block_size
=
1
;
int32_t
shared_idxs_num
=
0
;
...
...
@@ -87,14 +87,14 @@ void ConstantInitializer(const T& value, Blob* blob) {
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
float
*
x
,
float
*
y
)
{
TransposeImpl
<
float
>
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
permutation
,
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
double
*
x
,
double
*
y
)
{
TransposeImpl
<
double
>
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
permutation
,
elem_cnt
,
x
,
y
);
...
...
@@ -102,7 +102,7 @@ void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int8_t
*
x
,
int8_t
*
y
)
{
TransposeImpl
<
int8_t
>
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
permutation
,
elem_cnt
,
x
,
y
);
...
...
@@ -110,7 +110,7 @@ void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int32_t
*
x
,
int32_t
*
y
)
{
TransposeImpl
<
int32_t
>
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
permutation
,
elem_cnt
,
x
,
y
);
...
...
@@ -118,12 +118,61 @@ void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
std
::
vector
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int64_t
*
x
,
int64_t
*
y
)
{
TransposeImpl
<
int64_t
>
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
permutation
,
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
float
*
x
,
float
*
y
)
{
TransposeImpl
<
float
>
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
double
*
x
,
double
*
y
)
{
TransposeImpl
<
double
>
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int8_t
*
x
,
int8_t
*
y
)
{
TransposeImpl
<
int8_t
>
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int32_t
*
x
,
int32_t
*
y
)
{
TransposeImpl
<
int32_t
>
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int64_t
*
x
,
int64_t
*
y
)
{
TransposeImpl
<
int64_t
>
(
ctx
,
num_axis
,
x_shape
,
y_shape
,
std
::
vector
<
int32_t
>
({
permutation
.
cbegin
(),
permutation
.
cend
()}),
elem_cnt
,
x
,
y
);
}
void
ArithemeticIf
<
DeviceType
::
kCPU
>::
InitializeWithConstConf
(
DeviceCtx
*
ctx
,
const
ConstantInitializerConf
&
initializer_conf
,
Blob
*
blob
)
{
DataType
dtype
=
blob
->
data_type
();
...
...
oneflow/core/kernel/util/host_arithemetic_interface.h
浏览文件 @
396bd65f
...
...
@@ -27,21 +27,37 @@ class ConstantInitializerConf;
template
<
>
struct
ArithemeticIf
<
DeviceType
::
kCPU
>
{
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
float
*
x
,
float
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
double
*
x
,
double
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
int8_t
*
x
,
int8_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
int32_t
*
x
,
int32_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
std
::
vector
<
int32_t
>&
permutation
,
int64_t
elem_cnt
,
const
int64_t
*
x
,
int64_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
float
*
x
,
float
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
int64_t
elem_cnt
,
const
float
*
x
,
float
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
double
*
x
,
double
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
int64_t
elem_cnt
,
const
double
*
x
,
double
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int8_t
*
x
,
int8_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
int64_t
elem_cnt
,
const
int8_t
*
x
,
int8_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int32_t
*
x
,
int32_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
const
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
int64_t
elem_cnt
,
const
int32_t
*
x
,
int32_t
*
y
);
static
void
Transpose
(
DeviceCtx
*
ctx
,
int32_t
num_axis
,
const
ShapeView
&
x_shape
,
const
ShapeView
&
y_shape
,
const
PbRf
<
int32_t
>&
permutation
,
const
int64_t
elem_cnt
,
const
int64_t
*
x
,
int64_t
*
y
);
int64_t
elem_cnt
,
const
int64_t
*
x
,
int64_t
*
y
);
static
void
InitializeWithConstConf
(
DeviceCtx
*
ctx
,
const
ConstantInitializerConf
&
initializer_conf
,
Blob
*
blob
);
...
...
oneflow/user/kernels/transpose_kernel.cpp
浏览文件 @
396bd65f
...
...
@@ -25,17 +25,37 @@ template<DeviceType device_type, typename T>
class
TransposeKernel
final
:
public
OpKernel
{
public:
TransposeKernel
()
=
default
;
~
TransposeKernel
()
=
default
;
~
TransposeKernel
()
override
=
default
;
private:
void
Compute
(
KernelComputeContext
*
ctx
)
const
override
{
const
Tensor
*
tensor_in
=
ctx
->
Tensor4ArgNameAndIndex
(
"input"
,
0
);
Tensor
*
tensor_out
=
ctx
->
Tensor4ArgNameAndIndex
(
"output"
,
0
);
const
auto
&
perm
=
ctx
->
Attr
<
std
::
vector
<
int32_t
>>
(
"perm"
);
NewKernelUtil
<
device_type
>::
Transpose
(
ctx
->
device_ctx
(),
tensor_in
->
shape
().
NumAxes
(),
tensor_in
->
shape
(),
tensor_out
->
shape
(),
StdVec2PbRf
(
perm
),
tensor_in
->
shape
().
elem_cnt
(),
tensor_in
->
dptr
<
T
>
(),
tensor_out
->
mut_dptr
<
T
>
());
using
PackType
=
int64_t
;
const
size_t
num_elem_per_pack
=
sizeof
(
PackType
)
/
sizeof
(
T
);
const
ShapeView
&
in_shape
=
tensor_in
->
shape
();
const
ShapeView
&
out_shape
=
tensor_out
->
shape
();
if
(
num_elem_per_pack
!=
1
&&
perm
.
back
()
==
perm
.
size
()
-
1
&&
in_shape
.
At
(
in_shape
.
NumAxes
()
-
1
)
%
num_elem_per_pack
==
0
)
{
CHECK_EQ
(
in_shape
.
At
(
in_shape
.
NumAxes
()
-
1
),
out_shape
.
At
(
out_shape
.
NumAxes
()
-
1
));
DimVector
packed_in_dim_vec
;
in_shape
.
ToDimVector
(
&
packed_in_dim_vec
);
packed_in_dim_vec
.
back
()
/=
num_elem_per_pack
;
const
Shape
packed_in_shape
(
packed_in_dim_vec
);
DimVector
packed_out_dim_vec
;
out_shape
.
ToDimVector
(
&
packed_out_dim_vec
);
packed_out_dim_vec
.
back
()
/=
num_elem_per_pack
;
const
Shape
packed_out_shape
(
packed_out_dim_vec
);
NewKernelUtil
<
device_type
>::
Transpose
(
ctx
->
device_ctx
(),
packed_in_shape
.
NumAxes
(),
packed_in_shape
,
packed_out_shape
,
perm
,
packed_in_shape
.
elem_cnt
(),
reinterpret_cast
<
const
PackType
*>
(
tensor_in
->
dptr
<
T
>
()),
reinterpret_cast
<
PackType
*>
(
tensor_out
->
mut_dptr
<
T
>
()));
}
else
{
NewKernelUtil
<
device_type
>::
Transpose
(
ctx
->
device_ctx
(),
in_shape
.
NumAxes
(),
in_shape
,
tensor_out
->
shape
(),
perm
,
in_shape
.
elem_cnt
(),
tensor_in
->
dptr
<
T
>
(),
tensor_out
->
mut_dptr
<
T
>
());
}
}
bool
AlwaysComputeWhenAllOutputsEmpty
()
const
override
{
return
false
;
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录