Commit ef76f664 (unverified)
Authored by Liu-xiandong on Oct 28, 2021; committed via GitHub on Oct 28, 2021
Rewrite Softmax in Kernel Primitive API, test=develop (#36706)
Parent: b151a451

Showing 1 changed file with 191 additions and 210 deletions (+191 −210)
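For orientation: both kernels in this file compute a numerically stable row-wise softmax along axis = -1. Below is a minimal CPU reference of that computation (an illustrative sketch; the names are not from this patch):

#include <algorithm>
#include <cmath>
#include <vector>

// Stable softmax of one row: shift by the row max so std::exp cannot
// overflow, exponentiate, then normalize by the sum.
std::vector<float> SoftmaxRowRef(const std::vector<float>& row) {
  const float max_v = *std::max_element(row.begin(), row.end());
  std::vector<float> out(row.size());
  float sum = 0.0f;
  for (size_t j = 0; j < row.size(); ++j) {
    out[j] = std::exp(row[j] - max_v);
    sum += out[j];
  }
  for (float& v : out) {
    v /= sum;  // LogMode would instead emit (row[j] - max_v) - std::log(sum)
  }
  return out;
}

This max / exp-subtract / sum / divide pipeline is exactly what the diff below re-expresses with kernel-primitive calls.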
paddle/fluid/operators/softmax_cudnn_op.cu.h
...
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/operators/math/math_cuda_utils.h"
 #include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
...
@@ -99,6 +100,97 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) {
   }
 }
+
+namespace kps = paddle::operators::kernel_primitives;
+
+template <typename Tx, typename Ty = Tx>
+struct ReduceMaxFunctor {
+  inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); }
+
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return max(a, b);
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct ExpSubFunctor {
+  HOSTDEVICE inline ExpSubFunctor() { y = static_cast<Tx>(0.0f); }
+
+  HOSTDEVICE explicit inline ExpSubFunctor(Tx y) : y((Tx)(y)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::exp(x - y));
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct ExpMulFunctor {
+  HOSTDEVICE inline ExpMulFunctor() { y = static_cast<Tx>(1.0f); }
+
+  HOSTDEVICE explicit inline ExpMulFunctor(Tx y) : y((Tx)(y)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::exp(x) * y);
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnarySubFunctor {
+  HOSTDEVICE inline UnarySubFunctor() { y = static_cast<Tx>(0.0f); }
+
+  HOSTDEVICE explicit inline UnarySubFunctor(Tx y) : y((Tx)(y)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x - y);
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnaryLogFunctor {
+  HOSTDEVICE inline UnaryLogFunctor() {}
+
+  HOSTDEVICE explicit inline UnaryLogFunctor(int n) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::log(x));
+  }
+};
+
+template <typename Tx, typename Ty>
+struct DataTransFunctor {
+  HOSTDEVICE inline DataTransFunctor() {}
+
+  HOSTDEVICE explicit inline DataTransFunctor(int n) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return x == -std::numeric_limits<Tx>::infinity()
+               ? -std::numeric_limits<Ty>::infinity()
+               : static_cast<Ty>(x);
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnaryDivFunctor {
+  HOSTDEVICE inline UnaryDivFunctor() { n_inv = static_cast<Tx>(1.0f); }
+
+  HOSTDEVICE explicit inline UnaryDivFunctor(Tx n) : n_inv((Tx)(1.0 / n)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x * n_inv);
+  }
+
+ private:
+  Tx n_inv;
+};
+
 /*
 Core function of computing softmax forward for axis=-1.
 The computation includes
...
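The functors added above are the per-element operations that the kernel primitives map over register tiles. A rough sketch of that calling convention (an assumption for illustration; the real kps::ElementwiseUnary visible in this diff carries additional NY/BlockSize template parameters):

// Sketch only: apply a unary functor to NX values held in registers. This is
// the pattern behind calls like ElementwiseUnary(..., ExpSubFunctor<AccT>(max))
// and ElementwiseUnary(..., UnaryDivFunctor<AccT>(sum)) further down.
template <typename InT, typename OutT, int NX, typename OpFunc>
__device__ __forceinline__ void ElementwiseUnarySketch(OutT* out,
                                                       const InT* in,
                                                       OpFunc op) {
#pragma unroll
  for (int i = 0; i < NX; ++i) {
    out[i] = op(in[i]);
  }
}

Carrying the scalar state (row max, row sum, 1/n) inside the functor is what lets one generic primitive replace the hand-written loops deleted below.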
@@ -117,12 +209,14 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
   constexpr int kDimCeil = 1 << Log2Elements;
   constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
   constexpr int kVSize = sizeof(VecT) / sizeof(T);
-  constexpr int kIterations = kDimCeil / kWarpSize;
-  constexpr int kIterationsV =
-      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
+  constexpr int kLoops = kDimCeil / kWarpSize;
+  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
   constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;
   int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
+  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
+  constexpr int kVItem = kLoopsV * kVSize;
+  constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();
+  using kMode = kps::details::ReduceMode;
 
   // max index to read
   int idx_max_v[kBatchSize];
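To make the register-tile shape concrete, here is a worked instantiation of the constants above (my arithmetic, not part of the patch), taking Log2Elements = 10, T = float, VecT = float4:

// One warp handles one row of 1024 floats in this configuration.
constexpr int kDimCeil = 1 << 10;   // 1024
constexpr int kWarpSize = 32;       // kDimCeil >= 32
constexpr int kVSize = 4;           // sizeof(float4) / sizeof(float)
constexpr int kLoops = 1024 / 32;   // 32 elements per lane
constexpr int kLoopsV = 32 / 4;     // 8 vectorized loads per lane
constexpr int kBatchSize = 1;       // forward packs 2 rows per warp only when kDimCeil <= 32
constexpr int kVItem = 8 * 4;       // 32 values per row kept in registers
constexpr int kStep = 1 * 8 * 4;    // 32 register slots touched by kps::Init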
...
@@ -133,146 +227,51 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
   }
 
   // read data from global memory
-  AccT srcdata[kBatchSize][kIterationsV][kVSize];
+  AccT srcdata[kBatchSize][kLoopsV][kVSize];
+  kps::Init<AccT, kStep>(&srcdata[0][0][0], kLowInf);
+  T src_tmp[kBatchSize][kLoopsV][kVSize];
+  kps::Init<T, kStep>(&src_tmp[0][0][0], -std::numeric_limits<T>::infinity());
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
-    // read data
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      int src_idx = threadIdx.x + it * kWarpSize;
-      if (kVSize == 1) {
-        if (src_idx < idx_max_v[i]) {
-          srcdata[i][it][0] =
-              static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
-        } else {
-          srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
-        }
-      } else {
-        const VecT* src_v =
-            reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
-        if (src_idx < idx_max_v[i]) {
-          VecT srctmp = src_v[src_idx];
-          const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
-#pragma unroll
-          for (int s = 0; s < kVSize; s++) {
-            srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
-          }
-        } else {
-#pragma unroll
-          for (int s = 0; s < kVSize; s++) {
-            srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
-          }
-        }
-      }
-    }
+    int ptr = (first_batch + i) * stride;
+    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
+    VecT* reg_v = reinterpret_cast<VecT*>(&src_tmp[i][0][0]);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1);
+    kps::ElementwiseUnary<T, AccT, kVItem, 1, 1, DataTransFunctor<T, AccT>>(
+        &srcdata[i][0][0], &src_tmp[i][0][0], DataTransFunctor<T, AccT>());
   }
 
-  // compute max value
-  AccT max_value[kBatchSize];
-#pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
-    // it = 0
-    AccT valmax = srcdata[i][0][0];
-#pragma unroll
-    for (int s = 1; s < kVSize; ++s) {
-      valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
-    }
-    max_value[i] = valmax;
-
-    // it = 1, 2, ...
-#pragma unroll
-    for (int it = 1; it < kIterationsV; ++it) {
-      AccT valmax = srcdata[i][it][0];
-#pragma unroll
-      for (int s = 1; s < kVSize; ++s) {
-        valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
-      }
-      max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
-    }
-  }
-  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
+  // compute max
+  AccT max[kBatchSize];
+  kps::Init<AccT, kBatchSize>(&max[0], kLowInf);
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, ReduceMaxFunctor<AccT>,
+              kMode::kLocalMode>(&max[0], &srcdata[0][0][0],
+                                 ReduceMaxFunctor<AccT>(), true);
+  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);
 
   // compute sum
-  AccT sum[kBatchSize];
+  AccT sum[kBatchSize] = {0};
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
-    // it = 0
-    if (LogMode) {
-      sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
-    } else {
-      srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
-      sum[i] = srcdata[i][0][0];
-    }
-#pragma unroll
-    for (int s = 1; s < kVSize; ++s) {
-      if (LogMode) {
-        sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
-      } else {
-        srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
-        sum[i] += srcdata[i][0][s];
-      }
-    }
-
-    // it = 1, 2, ...
-#pragma unroll
-    for (int it = 1; it < kIterationsV; ++it) {
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
-        } else {
-          srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
-          sum[i] += srcdata[i][it][s];
-        }
-      }
-    }
-  }
+    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
+        &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
+  }
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
+              kMode::kLocalMode>(&sum[0], &srcdata[0][0][0],
+                                 kps::AddFunctor<AccT>(), true);
   WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
 
   // write result to global memory
+  T out_tmp[kBatchSize][kLoopsV][kVSize];
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
-    if (LogMode) {
-      sum[i] = std::log(sum[i]);
-    }
-
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      int idx = threadIdx.x + it * kWarpSize;
-      if (kVSize == 1) {
-        if (idx < idx_max_v[i]) {
-          if (LogMode) {
-            softmax[(first_batch + i) * stride + idx] =
-                srcdata[i][it][0] - max_value[i] - sum[i];
-          } else {
-            softmax[(first_batch + i) * stride + idx] =
-                srcdata[i][it][0] / sum[i];
-          }
-        } else {
-          break;
-        }
-      } else {
-        VecT* softmax_v =
-            reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
-        VecT tmpdata;
-        T* tmpptr = reinterpret_cast<T*>(&tmpdata);
-#pragma unroll
-        for (int s = 0; s < kVSize; ++s) {
-          if (LogMode) {
-            tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
-          } else {
-            tmpptr[s] = srcdata[i][it][s] / sum[i];
-          }
-        }
-        if (idx < idx_max_v[i]) {
-          softmax_v[idx] = tmpdata;
-        } else {
-          break;
-        }
-      }
-    }
-  }
+    kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
+        &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
+    int softmax_ptr = (first_batch + i) * stride;
+    VecT* softmax_v = reinterpret_cast<VecT*>(&softmax[softmax_ptr]);
+    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
+    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
+  }
 }
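WarpReduceMax and WarpReduceSum (defined earlier in this file, outside this diff) combine the per-lane partials produced by kps::Reduce's kLocalMode across the 32 lanes of a warp. A self-contained sketch of that shuffle-based butterfly pattern, assuming one value per lane (the real helpers also iterate over kBatchSize):

// After log2(kWarpSize) exchange rounds, every lane holds the full sum.
// The max variant is identical with += replaced by max().
template <typename T, int kWarpSize>
__device__ __forceinline__ T WarpReduceSumSketch(T val) {
#pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    val += __shfl_xor_sync(0xffffffff, val, offset, kWarpSize);
  }
  return val;
}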
...
@@ -293,101 +292,82 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
   constexpr int kVSize = sizeof(VecT) / sizeof(T);
   constexpr int kDimCeil = 1 << Log2Elements;
   constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
-  constexpr int kIterations = kDimCeil / kWarpSize;
+  constexpr int kLoops = kDimCeil / kWarpSize;
   constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
-  constexpr int kIterationsV =
-      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
+  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
   int element_count_v = element_count / kVSize;
   int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
-  int local_batches = batch_size - first_batch;
-  if (local_batches > kBatchSize) {
-    local_batches = kBatchSize;
-  }
-
-  // read data from global memory
-  VecT src_reg[kBatchSize][kIterationsV];
-  VecT grad_reg[kBatchSize][kIterationsV];
-  for (int i = 0; i < kBatchSize; ++i) {
-    const VecT* src_v =
-        reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
-    const VecT* grad_v =
-        reinterpret_cast<const VecT*>(&grad[(first_batch + i) * stride]);
-
-    // max index to read
-    int idx_max = (i < local_batches) ? element_count : 0;
-    int idx_max_v = idx_max / kVSize;
-
-    // read data
-    for (int it = 0; it < kIterationsV; ++it) {
-      int src_idx = threadIdx.x + it * kWarpSize;
-      if (src_idx < idx_max_v) {
-        src_reg[i][it] = src_v[src_idx];
-        grad_reg[i][it] = grad_v[src_idx];
-      } else {
-#pragma unroll
-        for (int s = 0; s < kVSize; s++) {
-          reinterpret_cast<T*>(&src_reg[i][it])[s] = 0.0;
-          reinterpret_cast<T*>(&grad_reg[i][it])[s] = 0.0;
-        }
-      }
-    }
-  }
+  int local_batches = min(batch_size - first_batch, kBatchSize);
+
+  // max index to read
+  int idx_max_v[kBatchSize];
+#pragma unroll
+  for (int i = 0; i < kBatchSize; i++) {
+    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
+    idx_max_v[i] = idx_max / kVSize;
+  }
+
+  // read data from global memory
+  VecT src_reg[kBatchSize][kLoopsV];
+  VecT grad_reg[kBatchSize][kLoopsV];
+  VecT k_value;
+  for (int s = 0; s < kVSize; s++) {
+    reinterpret_cast<T*>(&k_value)[s] = 0.0;
+  }
+  kps::Init<VecT, kBatchSize * kLoopsV>(&src_reg[0][0], k_value);
+  kps::Init<VecT, kBatchSize * kLoopsV>(&grad_reg[0][0], k_value);
+#pragma unroll
+  for (int i = 0; i < kBatchSize; ++i) {
+    int flag = i < local_batches ? 1 : 0;
+    int ptr = (first_batch + i) * stride;
+    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
+    const VecT* grad_v = reinterpret_cast<const VecT*>(&grad[ptr]);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &src_reg[i][0], &src_v[0], idx_max_v[i], 0, kWarpSize, flag);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &grad_reg[i][0], &grad_v[0], idx_max_v[i], 0, kWarpSize, flag);
+  }
+
+  // change T to AccT
+  AccT src_tmp[kBatchSize][kLoopsV][kVSize];
+  AccT grad_tmp[kBatchSize][kLoopsV][kVSize];
+  const T* src_ptr = reinterpret_cast<const T*>(&src_reg[0][0]);
+  const T* grad_ptr = reinterpret_cast<const T*>(&grad_reg[0][0]);
+  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
+  constexpr int kVItem = kLoopsV * kVSize;
+  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
+      &src_tmp[0][0][0], &src_ptr[0], DataTransFunctor<T, AccT>());
+  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
+      &grad_tmp[0][0][0], &grad_ptr[0], DataTransFunctor<T, AccT>());
 
   // compute sum
   AccT sum[kBatchSize]{0.0};
-#pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
-      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          sum[i] += static_cast<AccT>(gradptr[s]);
-        } else {
-          sum[i] += static_cast<AccT>(gradptr[s] * srcptr[s]);
-        }
-      }
-    }
-  }
+  AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
+  AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[0][0][0]);
+  AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
+  kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
+      &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
+              kps::details::ReduceMode::kLocalMode>(
+      &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
   WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
 
-  // write result
+  // write result to global memory
+  AccT out[kBatchSize][kLoopsV][kVSize];
+  T out_tmp[kBatchSize][kLoopsV][kVSize];
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
     if (i >= local_batches) break;
-
-    VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
-
-    // max index to write
-    int idx_max = (i < local_batches) ? element_count : 0;
-    int idx_max_v = idx_max / kVSize;
-
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      VecT tmpdata;
-      T* tmpptr = reinterpret_cast<T*>(&tmpdata);
-      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
-      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          tmpptr[s] = static_cast<AccT>(gradptr[s]) -
-                      std::exp(static_cast<AccT>(srcptr[s])) * sum[i];
-        } else {
-          tmpptr[s] = static_cast<AccT>(srcptr[s]) *
-                      (static_cast<AccT>(gradptr[s]) - sum[i]);
-        }
-      }
-      int idx = threadIdx.x + it * kWarpSize;
-      if (idx < idx_max_v) {
-        dst_v[idx] = tmpdata;
-      }
-    }
-  }
+    AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[i][0][0]);
+    AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[i][0][0]);
+    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
+        &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
+    kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
+        &out_tmp[i][0][0], &srcptr[0], &out[i][0][0], kps::MulFunctor<AccT>());
+    VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
+    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
+    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &dst_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
+  }
 }
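For reference, the identity this backward kernel evaluates is the standard softmax gradient: with y = softmax(x), dx_j = y_j * (dy_j - sum_k dy_k * y_k), which is the src * (grad - sum) structure in both the deleted and the new code. A minimal CPU reference (illustrative sketch, not from the patch):

#include <vector>

// Softmax backward for one row: dx = y * (dy - dot(dy, y)).
std::vector<float> SoftmaxGradRowRef(const std::vector<float>& y,
                                     const std::vector<float>& dy) {
  float dot = 0.0f;
  for (size_t j = 0; j < y.size(); ++j) {
    dot += dy[j] * y[j];
  }
  std::vector<float> dx(y.size());
  for (size_t j = 0; j < y.size(); ++j) {
    dx[j] = y[j] * (dy[j] - dot);
  }
  return dx;
}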
...
@@ -493,6 +473,7 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
     // vectorization read/write
     using T4 = typename VecT4<T>::Type;
     using T2 = typename VecT2<T>::Type;
     if (dim % 4 == 0) {
       SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks, threads, dev_ctx,
                                                out_data, x.data<T>(), N, dim,
...
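The driver picks the widest vector type whose width divides dim; the remaining branches are collapsed out of this hunk. Why dim % 4 matters, as a small illustration (an assumption based only on the VecT4/VecT2 aliases shown above):

// With VecT = float4 each lane moves 4 floats per 128-bit transaction, so a
// row must split evenly into 4-wide chunks for this vectorized path.
using T4 = float4;  // VecT4<float>::Type
static_assert(sizeof(T4) / sizeof(float) == 4, "kVSize == 4 for this path");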